Skip to content

Commit 13abd4a

Browse files
committed
Implement intersectBySorted API
1 parent ef0a299 commit 13abd4a

File tree

4 files changed

+124
-5
lines changed

4 files changed

+124
-5
lines changed

src/Streamly/Internal/Data/Stream/IsStream/Top.hs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ module Streamly.Internal.Data.Stream.IsStream.Top
2828
-- | These are not exactly set operations because streams are not
2929
-- necessarily sets, they may have duplicated elements.
3030
, intersectBy
31-
, mergeIntersectBy
31+
, intersectBySorted
3232
, differenceBy
3333
, mergeDifferenceBy
3434
, unionBy
@@ -65,6 +65,7 @@ import Streamly.Internal.Data.Stream.IsStream.Common (concatM)
6565
import Streamly.Internal.Data.Stream.IsStream.Type
6666
(IsStream(..), adapt, foldl', fromList)
6767
import Streamly.Internal.Data.Stream.Serial (SerialT)
68+
--import Streamly.Internal.Data.Stream.StreamD (fromStreamD, toStreamD)
6869
import Streamly.Internal.Data.Time.Units (NanoSecond64(..), toRelTime64)
6970

7071
import qualified Data.List as List
@@ -78,6 +79,7 @@ import qualified Streamly.Internal.Data.Stream.IsStream.Expand as Stream
7879
import qualified Streamly.Internal.Data.Stream.IsStream.Reduce as Stream
7980
import qualified Streamly.Internal.Data.Stream.IsStream.Transform as Stream
8081
import qualified Streamly.Internal.Data.Stream.IsStream.Type as IsStream
82+
import qualified Streamly.Internal.Data.Stream.StreamD as StreamD
8183

8284
import Prelude hiding (filter, zipWith, concatMap, concat)
8385

@@ -514,11 +516,12 @@ intersectBy eq s1 s2 =
514516
--
515517
-- Time: O(m+n)
516518
--
517-
-- /Unimplemented/
518-
{-# INLINE mergeIntersectBy #-}
519-
mergeIntersectBy :: -- (IsStream t, Monad m) =>
519+
-- /Pre-release/
520+
{-# INLINE intersectBySorted #-}
521+
intersectBySorted :: (IsStream t, MonadIO m, Eq a) =>
520522
(a -> a -> Ordering) -> t m a -> t m a -> t m a
521-
mergeIntersectBy _eq _s1 _s2 = undefined
523+
intersectBySorted eq s1 =
524+
IsStream.fromStreamD . StreamD.intersectBySorted eq (IsStream.toStreamD s1) . IsStream.toStreamD
522525

523526
-- Roughly leftJoin s1 s2 = s1 `difference` s2 + s1 `intersection` s2
524527

src/Streamly/Internal/Data/Stream/StreamD/Nesting.hs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ module Streamly.Internal.Data.Stream.StreamD.Nesting
142142
-- | Opposite to compact in ArrayStream
143143
, splitInnerBy
144144
, splitInnerBySuffix
145+
, intersectBySorted
145146
)
146147
where
147148

@@ -481,6 +482,59 @@ mergeBy
481482
=> (a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
482483
mergeBy cmp = mergeByM (\a b -> return $ cmp a b)
483484

485+
-------------------------------------------------------------------------------
486+
-- Intersection of sorted streams ---------------------------------------------
487+
-------------------------------------------------------------------------------
488+
{-# INLINE_NORMAL intersectBySorted #-}
489+
intersectBySorted
490+
:: (MonadIO m, Eq a)
491+
=> (a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
492+
intersectBySorted cmp (Stream stepa ta) (Stream stepb tb) =
493+
Stream step (Just ta, Just tb, Nothing, Nothing, Nothing)
494+
495+
where
496+
{-# INLINE_LATE step #-}
497+
498+
-- step 1
499+
step gst (Just sa, sb, Nothing, b, Nothing) = do
500+
r <- stepa gst sa
501+
return $ case r of
502+
Yield a sa' -> Skip (Just sa', sb, Just a, b, Nothing)
503+
Skip sa' -> Skip (Just sa', sb, Nothing, b, Nothing)
504+
Stop -> Stop
505+
506+
-- step 2
507+
step gst (sa, Just sb, a, Nothing, Nothing) = do
508+
r <- stepb gst sb
509+
return $ case r of
510+
Yield b sb' -> Skip (sa, Just sb', a, Just b, Nothing)
511+
Skip sb' -> Skip (sa, Just sb', a, Nothing, Nothing)
512+
Stop -> Stop
513+
514+
-- step 3
515+
-- both the values are available compare it
516+
step _ (sa, sb, Just a, Just b, Nothing) = do
517+
let res = cmp a b
518+
return $ case res of
519+
GT -> Skip (sa, sb, Just a, Nothing, Nothing)
520+
LT -> Skip (sa, sb, Nothing, Just b, Nothing)
521+
EQ -> Yield a (sa, sb, Nothing, Just a, Just b) -- step 4
522+
523+
-- step 4
524+
-- Matching element
525+
step gst (Just sa, Just sb, Nothing, Just _, Just b) = do
526+
r1 <- stepa gst sa
527+
return $ case r1 of
528+
Yield a' sa' -> do
529+
if a' == b -- match with prev a
530+
then Yield a' (Just sa', Just sb, Nothing, Just b, Just b) --step 1
531+
else Skip (Just sa', Just sb, Just a', Nothing, Nothing)
532+
533+
Skip sa' -> Skip (Just sa', Just sb, Nothing, Nothing, Nothing)
534+
Stop -> Stop
535+
536+
step _ (_, _, _, _, _) = return Stop
537+
484538
------------------------------------------------------------------------------
485539
-- Combine N Streams - unfoldMany
486540
------------------------------------------------------------------------------
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
module Main (main)
2+
where
3+
4+
import Data.List (intersect, sort)
5+
import Test.QuickCheck
6+
( Gen
7+
, Property
8+
, choose
9+
, forAll
10+
, listOf
11+
)
12+
import Test.QuickCheck.Monadic (monadicIO, assert, run)
13+
import qualified Streamly.Prelude as S
14+
import qualified Streamly.Internal.Data.Stream.IsStream.Top as Top
15+
16+
import Prelude hiding
17+
(maximum, minimum, elem, notElem, null, product, sum, head, last, take)
18+
import Test.Hspec as H
19+
import Test.Hspec.QuickCheck
20+
21+
min_value :: Int
22+
min_value = 0
23+
24+
max_value :: Int
25+
max_value = 10000
26+
27+
chooseInt :: (Int, Int) -> Gen Int
28+
chooseInt = choose
29+
30+
intersectBySorted :: Property
31+
intersectBySorted =
32+
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
33+
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
34+
monadicIO $ action (sort ls0) (sort ls1)
35+
36+
where
37+
38+
action ls0 ls1 = do
39+
v1 <-
40+
run
41+
$ S.toList
42+
$ Top.intersectBySorted
43+
compare
44+
(S.fromList ls0)
45+
(S.fromList ls1)
46+
let v2 = intersect ls0 ls1
47+
assert (v1 == sort v2)
48+
49+
-------------------------------------------------------------------------------
50+
moduleName :: String
51+
moduleName = "Data.Stream.Top"
52+
53+
main :: IO ()
54+
main = hspec $ do
55+
describe moduleName $ do
56+
-- intersect
57+
prop "intersectBySorted" Main.intersectBySorted

test/streamly-tests.cabal

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,3 +430,8 @@ test-suite version-bounds
430430
import: test-options
431431
type: exitcode-stdio-1.0
432432
main-is: version-bounds.hs
433+
434+
test-suite Data.Stream.Top
435+
import: test-options
436+
type: exitcode-stdio-1.0
437+
main-is: Streamly/Test/Data/Stream/Top.hs

0 commit comments

Comments
 (0)