From dbf921bb13bd4801588c3803261a78b012b760d7 Mon Sep 17 00:00:00 2001 From: Harendra Kumar Date: Tue, 27 Aug 2024 12:28:36 +0530 Subject: [PATCH] Add parTeeWith for folds --- .../Internal/Data/Fold/Channel/Type.hs | 14 +++ .../Internal/Data/Scanl/Concurrent.hs | 85 ++++++++++++++++++- 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/src/Streamly/Internal/Data/Fold/Channel/Type.hs b/src/Streamly/Internal/Data/Fold/Channel/Type.hs index fdde964e7f..7a6c2875be 100644 --- a/src/Streamly/Internal/Data/Fold/Channel/Type.hs +++ b/src/Streamly/Internal/Data/Fold/Channel/Type.hs @@ -22,6 +22,7 @@ module Streamly.Internal.Data.Fold.Channel.Type , newChannelWith , newChannelWithScan , newChannel + , newScanChannel , sendToWorker , sendToWorker_ , checkFoldStatus -- XXX collectFoldOutput @@ -340,6 +341,8 @@ scanToChannel chan (Scanl step initial extract final) = r <- initial case r of Fold.Partial s -> do + b <- extract s + void $ sendPartialToDriver chan b return $ Fold.Partial s Fold.Done b -> Fold.Done <$> void (sendYieldToDriver chan b) @@ -356,6 +359,7 @@ scanToChannel chan (Scanl step initial extract final) = extract1 _ = return () + -- XXX Should we not discard the result? 
final1 st = void (final st)
 
 {-# INLINABLE newChannelWithScan #-}
@@ -394,6 +398,16 @@ newChannel modifier f = do
     outQMvRev <- liftIO newEmptyMVar
     fmap fst (newChannelWith outQRev outQMvRev modifier f)
 
+-- | Like 'newChannel' but for a 'Scanl': wraps 'newChannelWithScan'.
+{-# INLINABLE newScanChannel #-}
+{-# SPECIALIZE newScanChannel ::
+    (Config -> Config) -> Scanl IO a b -> IO (Channel IO a b) #-}
+newScanChannel :: (MonadRunInIO m) =>
+    (Config -> Config) -> Scanl m a b -> m (Channel m a b)
+newScanChannel modifier f = do
+    (qRef, bell) <- liftIO $ (,) <$> newIORef ([], 0) <*> newEmptyMVar
+    fst <$> newChannelWithScan qRef bell modifier f
+
 -------------------------------------------------------------------------------
 -- Process events received by the driver thread from the fold worker side
 -------------------------------------------------------------------------------
diff --git a/src/Streamly/Internal/Data/Scanl/Concurrent.hs b/src/Streamly/Internal/Data/Scanl/Concurrent.hs
index 14614e755a..5898a2623a 100644
--- a/src/Streamly/Internal/Data/Scanl/Concurrent.hs
+++ b/src/Streamly/Internal/Data/Scanl/Concurrent.hs
@@ -8,7 +8,8 @@
 
 module Streamly.Internal.Data.Scanl.Concurrent
     (
-      parDistributeScan
+      parTeeWith
+    , parDistributeScan
     , parDemuxScan
     )
 where
@@ -21,9 +22,12 @@ import Control.Monad.IO.Class (MonadIO(liftIO))
 import Data.IORef (newIORef, readIORef)
 import Fusion.Plugin.Types (Fuse(..))
 import Streamly.Internal.Control.Concurrent (MonadAsync)
+import Streamly.Internal.Data.Atomics (atomicModifyIORefCAS)
+import Streamly.Internal.Data.Fold (Step (..))
 import Streamly.Internal.Data.Scanl (Scanl(..))
 import Streamly.Internal.Data.Stream (Stream(..), Step(..))
 import Streamly.Internal.Data.SVar.Type (adaptState)
+import Streamly.Internal.Data.Tuple.Strict (Tuple3'(..))
 
 import qualified Data.Map.Strict as Map
 
@@ -34,6 +38,85 @@ import Streamly.Internal.Data.Channel.Types
 -- Concurrent scans
 -------------------------------------------------------------------------------
 
+-- | Execute both the scans in a tee
concurrently.
+-- Each input is sent to two channel workers, one per scan, and their outputs
+-- are combined with the given function. The result is 'Done' as soon as
+-- either scan is done; a scan that is still running then is finalized.
+--
+-- Example:
+--
+-- >>> src = Stream.delay 1 (Stream.enumerateFromTo 1 3)
+-- >>> delay x = threadDelay 1000000 >> print x >> return x
+-- >>> c1 = Scanl.lmapM delay Scanl.sum
+-- >>> c2 = Scanl.lmapM delay Scanl.length
+-- >>> dst = Scanl.parTeeWith id (,) c1 c2
+-- >>> Stream.toList $ Stream.scanl dst src
+-- ...
+--
+{-# INLINABLE parTeeWith #-}
+parTeeWith :: MonadAsync m =>
+       (Config -> Config)
+    -> (a -> b -> c)
+    -> Scanl m x a
+    -> Scanl m x b
+    -> Scanl m x c
+parTeeWith cfg f c1 c2 = Scanl step initial extract final
+
+    where
+
+    -- Wait for the message pending on ch1: 'Left' when that scan is done,
+    -- 'Right' when partial. ch2 is only cleaned up on a worker exception.
+    getResponse ch1 ch2 = do
+        -- NOTE: We do not need a queue and doorbell mechanism for this, a single
+        -- MVar should be enough. Also, there is only one writer and it writes
+        -- only once before we read it.
+        (xs1, _) <-
+            liftIO $ atomicModifyIORefCAS (outputQueue ch1) $ \x -> (([],0), x)
+        case xs1 of
+            [] -> liftIO (takeMVar (outputDoorBell ch1)) >> getResponse ch1 ch2
+            [x1] -> case x1 of
+                FoldException _tid ex -> do
+                    -- XXX
+                    -- liftIO $ throwTo ch2Tid ThreadAbort
+                    cleanup ch1
+                    cleanup ch2
+                    liftIO $ throwM ex
+                FoldDone _tid b -> return (Left b)
+                FoldPartial b -> return (Right b)
+            _ -> error "parTeeWith: not expecting more than one msg in q"
+
+    -- Once 'Done' is returned 'final' is not expected to run, so a scan that
+    -- is still running when the other one finishes is finalized right here.
+    processResponses ch1 ch2 r1 r2 =
+        case (r1, r2) of
+            (Right b1, Right b2) -> return $ Partial $ Tuple3' ch1 ch2 (f b1 b2)
+            (Left b1, Left b2) -> return $ Done (f b1 b2)
+            (Left b1, Right b2) -> finalize ch2 >> return (Done (f b1 b2))
+            (Right b1, Left b2) -> finalize ch1 >> return (Done (f b1 b2))
+
+    collectResponses ch1 ch2 = do
+        r1 <- getResponse ch1 ch2
+        r2 <- getResponse ch2 ch1
+        processResponses ch1 ch2 r1 r2
+
+    initial = do
+        ch1 <- newScanChannel cfg c1
+        ch2 <- newScanChannel cfg c2
+        collectResponses ch1 ch2
+
+    step (Tuple3' ch1 ch2 _) x = do
+        sendToWorker_ ch1 x
+        sendToWorker_ ch2 x
+        collectResponses ch1 ch2
+
+    extract (Tuple3' _ _ x) = return x
+
+    final (Tuple3' ch1 ch2 x) = do
+        finalize ch1
+        finalize ch2
+        -- XXX generate the final value?
+        return x
+
 -- There are two ways to implement a concurrent scan.
 --
 -- 1. Make the scan itself asynchronous, add the input to the queue, and then