diff --git a/src/Streamly/Internal/Data/Fold/Channel/Type.hs b/src/Streamly/Internal/Data/Fold/Channel/Type.hs index 2f2802b62c..fdde964e7f 100644 --- a/src/Streamly/Internal/Data/Fold/Channel/Type.hs +++ b/src/Streamly/Internal/Data/Fold/Channel/Type.hs @@ -20,11 +20,14 @@ module Streamly.Internal.Data.Fold.Channel.Type -- ** Operations , newChannelWith + , newChannelWithScan , newChannel , sendToWorker , sendToWorker_ , checkFoldStatus -- XXX collectFoldOutput , dumpChannel + , cleanup + , finalize ) where @@ -33,17 +36,19 @@ where import Control.Concurrent (ThreadId, myThreadId, tryPutMVar) import Control.Concurrent.MVar (MVar, newEmptyMVar, takeMVar) import Control.Exception (SomeException(..)) -import Control.Monad (void) +import Control.Monad (void, when) import Control.Monad.Catch (throwM) import Control.Monad.IO.Class (MonadIO(..)) -import Data.IORef (IORef, newIORef, readIORef) +import Data.IORef (IORef, newIORef, readIORef, writeIORef) import Data.List (intersperse) import Streamly.Internal.Control.Concurrent (MonadAsync, MonadRunInIO, askRunInIO) import Streamly.Internal.Control.ForkLifted (doForkWith) import Streamly.Internal.Data.Fold (Fold(..)) +import Streamly.Internal.Data.Scanl (Scanl(..)) import Streamly.Internal.Data.Channel.Dispatcher (dumpSVarStats) import Streamly.Internal.Data.Channel.Worker (sendEvent) +import Streamly.Internal.Data.Time.Clock (Clock(Monotonic), getTime) import qualified Streamly.Internal.Data.Fold as Fold import qualified Streamly.Internal.Data.Stream as D @@ -58,6 +63,7 @@ import Streamly.Internal.Data.Channel.Types data OutEvent b = FoldException ThreadId SomeException + | FoldPartial b | FoldDone ThreadId b -- | The fold driver thread queues the input of the fold in the 'inputQueue' @@ -201,6 +207,10 @@ sendYieldToDriver sv res = liftIO $ do tid <- myThreadId void $ sendToDriver sv (FoldDone tid res) +sendPartialToDriver :: MonadIO m => Channel m a b -> b -> m () +sendPartialToDriver sv res = liftIO $ do + void $ 
sendToDriver sv (FoldPartial res) + {-# NOINLINE sendExceptionToDriver #-} sendExceptionToDriver :: Channel m a b -> SomeException -> IO () sendExceptionToDriver sv e = do @@ -319,6 +329,61 @@ newChannelWith outq outqDBell modifier f = do let f1 = Fold.rmapM (void . sendYieldToDriver chan) f in D.fold f1 $ fromInputQueue chan +{-# INLINE scanToChannel #-} +scanToChannel :: MonadIO m => Channel m a b -> Scanl m a b -> Scanl m a () +scanToChannel chan (Scanl step initial extract final) = + Scanl step1 initial1 extract1 final1 + + where + + initial1 = do + r <- initial + case r of + Fold.Partial s -> do + return $ Fold.Partial s + Fold.Done b -> + Fold.Done <$> void (sendYieldToDriver chan b) + + step1 st x = do + r <- step st x + case r of + Fold.Partial s -> do + b <- extract s + void $ sendPartialToDriver chan b + return $ Fold.Partial s + Fold.Done b -> + Fold.Done <$> void (sendYieldToDriver chan b) + + extract1 _ = return () + + final1 st = void (final st) + +{-# INLINABLE newChannelWithScan #-} +{-# SPECIALIZE newChannelWithScan :: + IORef ([OutEvent b], Int) + -> MVar () + -> (Config -> Config) + -> Scanl IO a b + -> IO (Channel IO a b, ThreadId) #-} +newChannelWithScan :: (MonadRunInIO m) => + IORef ([OutEvent b], Int) + -> MVar () + -> (Config -> Config) + -> Scanl m a b + -> m (Channel m a b, ThreadId) +newChannelWithScan outq outqDBell modifier f = do + let config = modifier defaultConfig + sv <- liftIO $ mkNewChannelWith outq outqDBell config + mrun <- askRunInIO + tid <- doForkWith + (getBound config) (work sv) mrun (sendExceptionToDriver sv) + return (sv, tid) + + where + + {-# NOINLINE work #-} + work chan = D.drain $ D.scanl (scanToChannel chan f) $ fromInputQueue chan + {-# INLINABLE newChannel #-} {-# SPECIALIZE newChannel :: (Config -> Config) -> Fold IO a b -> IO (Channel IO a b) #-} @@ -362,6 +427,7 @@ checkFoldStatus sv = do case ev of FoldException _ e -> throwM e FoldDone _ b -> return (Just b) + FoldPartial _ -> undefined {-# INLINE 
isBufferAvailable #-} isBufferAvailable :: MonadIO m => Channel m a b -> m Bool @@ -434,3 +500,20 @@ sendToWorker_ chan a = go -- Block for space -- () <- liftIO $ takeMVar (inputSpaceDoorBell chan) -- go + +-- XXX Cleanup the fold if the stream is interrupted. Add a GC hook. + +cleanup :: MonadIO m => Channel m a b -> m () +cleanup chan = do + when (svarInspectMode chan) $ liftIO $ do + t <- getTime Monotonic + writeIORef (svarStopTime (svarStats chan)) (Just t) + printSVar (dumpChannel chan) "Scan channel done" + +finalize :: MonadIO m => Channel m a b -> m () +finalize chan = do + liftIO $ void + $ sendEvent + (inputQueue chan) + (inputItemDoorBell chan) + ChildStopChannel diff --git a/src/Streamly/Internal/Data/Fold/Concurrent.hs b/src/Streamly/Internal/Data/Fold/Concurrent.hs index fd69f2fd4a..53cf3f51e1 100644 --- a/src/Streamly/Internal/Data/Fold/Concurrent.hs +++ b/src/Streamly/Internal/Data/Fold/Concurrent.hs @@ -62,16 +62,15 @@ where import Control.Concurrent (newEmptyMVar, takeMVar, throwTo) import Control.Monad.Catch (throwM) -import Control.Monad (void, when) +import Control.Monad (void) import Control.Monad.IO.Class (MonadIO(liftIO)) -import Data.IORef (newIORef, readIORef, writeIORef) +import Data.IORef (newIORef, readIORef) import Fusion.Plugin.Types (Fuse(..)) import Streamly.Internal.Control.Concurrent (MonadAsync) import Streamly.Internal.Data.Channel.Worker (sendEvent) import Streamly.Internal.Data.Fold (Fold(..), Step (..)) import Streamly.Internal.Data.Stream (Stream(..), Step(..)) import Streamly.Internal.Data.SVar.Type (adaptState) -import Streamly.Internal.Data.Time.Clock (Clock(Monotonic), getTime) import qualified Data.Map.Strict as Map import qualified Streamly.Internal.Data.Fold as Fold @@ -83,15 +82,6 @@ import Streamly.Internal.Data.Channel.Types -- Evaluating a Fold ------------------------------------------------------------------------------- --- XXX Cleanup the fold if the stream is interrupted. Add a GC hook. 
- -cleanup :: MonadIO m => Channel m a b -> m () -cleanup chan = do - when (svarInspectMode chan) $ liftIO $ do - t <- getTime Monotonic - writeIORef (svarStopTime (svarStats chan)) (Just t) - printSVar (dumpChannel chan) "Fold channel done" - -- | 'parEval' introduces a concurrent stage at the input of the fold. The -- inputs are asynchronously queued in a buffer and evaluated concurrently with -- the evaluation of the source stream. On finalization, 'parEval' waits for @@ -291,14 +281,6 @@ parUnzipWithM cfg f c1 c2 = Fold.unzipWithM f (parEval cfg c1) (parEval cfg c2) -- 2. A monolithic implementation of concurrent Stream->Stream scan, using a -- custom implementation of the scan and the driver. -finalize :: MonadIO m => Channel m a b -> m () -finalize chan = do - liftIO $ void - $ sendEvent - (inputQueue chan) - (inputItemDoorBell chan) - ChildStopChannel - {-# ANN type ScanState Fuse #-} data ScanState s q db f = ScanInit @@ -344,6 +326,7 @@ parDistributeScan cfg getFolds (Stream sstep state) = FoldDone tid b -> let ch = filter (\(_, t) -> t /= tid) chans in processOutputs ch xs (b:done) + FoldPartial _ -> undefined collectOutputs qref chans = do (_, n) <- liftIO $ readIORef qref @@ -418,8 +401,9 @@ data DemuxState s q db f = -- fold again because some inputs would be lost in between, or (2) have a -- FoldYield constructor to yield repeatedly so that we can restart the -- existing fold itself when it is done. But in that case we cannot change the --- fold once it is started. Whatever we do we should keep the non-concurrent --- fold as well consistent with that. +-- fold once it is started. Also the Map would keep on increasing in size as we +-- never delete a key. Whatever we do we should keep the non-concurrent fold as +-- well consistent with that. -- | Evaluate a stream and send its outputs to the selected fold. 
The fold is -- dynamically selected using a key at the time of the first input seen for @@ -466,6 +450,7 @@ parDemuxScan cfg getKey getFold (Stream sstep state) = FoldDone _tid o@(k, _) -> let ch = Map.delete k keyToChan in processOutputs ch xs (o:done) + FoldPartial _ -> undefined collectOutputs qref keyToChan = do (_, n) <- liftIO $ readIORef qref diff --git a/src/Streamly/Internal/Data/Scanl/Concurrent.hs b/src/Streamly/Internal/Data/Scanl/Concurrent.hs new file mode 100644 index 0000000000..14614e755a --- /dev/null +++ b/src/Streamly/Internal/Data/Scanl/Concurrent.hs @@ -0,0 +1,292 @@ +-- | +-- Module : Streamly.Internal.Data.Scanl.Concurrent +-- Copyright : (c) 2024 Composewell Technologies +-- License : BSD-3-Clause +-- Maintainer : streamly@composewell.com +-- Stability : experimental +-- Portability : GHC + +module Streamly.Internal.Data.Scanl.Concurrent + ( + parDistributeScan + , parDemuxScan + ) +where + +#include "inline.hs" + +import Control.Concurrent (newEmptyMVar, takeMVar, throwTo) +import Control.Monad.Catch (throwM) +import Control.Monad.IO.Class (MonadIO(liftIO)) +import Data.IORef (newIORef, readIORef) +import Fusion.Plugin.Types (Fuse(..)) +import Streamly.Internal.Control.Concurrent (MonadAsync) +import Streamly.Internal.Data.Scanl (Scanl(..)) +import Streamly.Internal.Data.Stream (Stream(..), Step(..)) +import Streamly.Internal.Data.SVar.Type (adaptState) + +import qualified Data.Map.Strict as Map + +import Streamly.Internal.Data.Fold.Channel.Type +import Streamly.Internal.Data.Channel.Types + +------------------------------------------------------------------------------- +-- Concurrent scans +------------------------------------------------------------------------------- + +-- There are two ways to implement a concurrent scan. +-- +-- 1. Make the scan itself asynchronous, add the input to the queue, and then +-- extract the output. Extraction will have to be asynchronous, which will +-- require changes to the scan driver. 
This will require a different Scanl +-- type. +-- +-- 2. A monolithic implementation of concurrent Stream->Stream scan, using a +-- custom implementation of the scan and the driver. + +{-# ANN type ScanState Fuse #-} +data ScanState s q db f = + ScanInit + | ScanGo s q db [f] + | ScanDrain q db [f] + | ScanStop + +-- XXX return [b] or just b? +-- XXX We can use a one way mailbox type abstraction instead of using an IORef +-- for adding new folds dynamically. + +-- | Evaluate a stream and scan its outputs using zero or more dynamically +-- generated parallel scans. It checks for any new folds at each input +-- generation step. Any new fold is added to the list of folds which are +-- currently running. If there are no folds available, the input is discarded. +-- If a fold completes its output is emitted in the output of the scan. The +-- outputs of the parallel scans are merged in the output stream. +-- +-- >>> import Data.IORef +-- >>> ref <- newIORef [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl IO Int Int] +-- >>> gen = atomicModifyIORef ref (\xs -> ([], xs)) +-- >>> Stream.toList $ Scanl.parDistributeScan id gen (Stream.enumerateFromTo 1 10) +-- ... 
+-- +{-# INLINE parDistributeScan #-} +parDistributeScan :: MonadAsync m => + (Config -> Config) -> m [Scanl m a b] -> Stream m a -> Stream m [b] +parDistributeScan cfg getFolds (Stream sstep state) = + Stream step ScanInit + + where + + -- XXX can be written as a fold + processOutputs chans events done = do + case events of + [] -> return (chans, done) + (x:xs) -> + case x of + FoldException _tid ex -> do + -- XXX report the fold that threw the exception + liftIO $ mapM_ (`throwTo` ThreadAbort) (fmap snd chans) + mapM_ cleanup (fmap fst chans) + liftIO $ throwM ex + FoldDone tid b -> + let ch = filter (\(_, t) -> t /= tid) chans + in processOutputs ch xs (b:done) + FoldPartial b -> + processOutputs chans xs (b:done) + + collectOutputs qref chans = do + (_, n) <- liftIO $ readIORef qref + if n > 0 + then do + r <- fmap fst $ liftIO $ readOutputQBasic qref + processOutputs chans r [] + else return (chans, []) + + step _ ScanInit = do + q <- liftIO $ newIORef ([], 0) + db <- liftIO newEmptyMVar + return $ Skip (ScanGo state q db []) + + step gst (ScanGo st q db chans) = do + -- merge any new channels added since last input + fxs <- getFolds + newChans <- Prelude.mapM (newChannelWithScan q db cfg) fxs + let allChans = chans ++ newChans + + -- Collect outputs from running channels + (running, outputs) <- collectOutputs q allChans + + -- Send input to running folds + res <- sstep (adaptState gst) st + next <- case res of + Yield x s -> do + -- XXX We might block forever if some folds are already + -- done but we have not read the output queue yet. To + -- avoid that we have to either (1) precheck if space + -- is available in the input queues of all folds so + -- that this does not block, or (2) we have to use a + -- non-blocking read and track progress so that we can + -- restart from where we left. + -- + -- If there is no space available then we should block + -- on doorbell db or inputSpaceDoorBell of the relevant + -- channel. 
To avoid deadlock the output space can be + -- kept unlimited. However, the blocking will delay the + -- processing of outputs. We should yield the outputs + -- before blocking. + Prelude.mapM_ (`sendToWorker_` x) (fmap fst running) + return $ ScanGo s q db running + Skip s -> do + return $ ScanGo s q db running + Stop -> do + Prelude.mapM_ finalize (fmap fst running) + return $ ScanDrain q db running + if null outputs + then return $ Skip next + else return $ Yield outputs next + step _ (ScanDrain q db chans) = do + (running, outputs) <- collectOutputs q chans + case running of + [] -> return $ Yield outputs ScanStop + _ -> do + if null outputs + then do + liftIO $ takeMVar db + return $ Skip (ScanDrain q db running) + else return $ Yield outputs (ScanDrain q db running) + step _ ScanStop = return Stop + +{-# ANN type DemuxState Fuse #-} +data DemuxState s q db f = + DemuxInit + | DemuxGo s q db f + | DemuxDrain q db f + | DemuxStop + +-- XXX We need to either (1) remember a key when done so that we do not add the +-- fold again because some inputs would be lost in between, or (2) have a +-- FoldYield constructor to yield repeatedly so that we can restart the +-- existing fold itself when it is done. But in that case we cannot change the +-- fold once it is started. Also the Map would keep on increasing in size as we +-- never delete a key. Whatever we do we should keep the non-concurrent fold as +-- well consistent with that. + +-- | Evaluate a stream and send its outputs to the selected scan. The scan is +-- dynamically selected using a key at the time of the first input seen for +-- that key. Any new scan is added to the list of scans which are currently +-- running. If there are no scans available for a given key, the input is +-- discarded. If a constituent scan completes its output is emitted in the +-- output of the composed scan. 
+-- +-- >>> import qualified Data.Map.Strict as Map +-- >>> import Data.Maybe (fromJust) +-- >>> f1 = ("even", Scanl.take 5 Scanl.sum) +-- >>> f2 = ("odd", Scanl.take 5 Scanl.sum) +-- >>> kv = Map.fromList [f1, f2] +-- >>> getScan k = return (fromJust $ Map.lookup k kv) +-- >>> getKey x = if even x then "even" else "odd" +-- >>> input = Stream.enumerateFromTo 1 10 +-- >>> Stream.toList $ Scanl.parDemuxScan id getKey getScan input +-- ... +-- +{-# INLINE parDemuxScan #-} +parDemuxScan :: (MonadAsync m, Ord k) => + (Config -> Config) + -> (a -> k) + -> (k -> m (Scanl m a b)) + -> Stream m a + -> Stream m [(k, b)] +parDemuxScan cfg getKey getFold (Stream sstep state) = + Stream step DemuxInit + + where + + -- XXX can be written as a fold + processOutputs keyToChan events done = do + case events of + [] -> return (keyToChan, done) + (x:xs) -> + case x of + FoldException _tid ex -> do + -- XXX report the fold that threw the exception + let chans = fmap snd $ Map.toList keyToChan + liftIO $ mapM_ (`throwTo` ThreadAbort) (fmap snd chans) + mapM_ cleanup (fmap fst chans) + liftIO $ throwM ex + FoldDone _tid o@(k, _) -> + let ch = Map.delete k keyToChan + in processOutputs ch xs (o:done) + FoldPartial b -> + processOutputs keyToChan xs (b:done) + + collectOutputs qref keyToChan = do + (_, n) <- liftIO $ readIORef qref + if n > 0 + then do + r <- fmap fst $ liftIO $ readOutputQBasic qref + processOutputs keyToChan r [] + else return (keyToChan, []) + + step _ DemuxInit = do + q <- liftIO $ newIORef ([], 0) + db <- liftIO newEmptyMVar + return $ Skip (DemuxGo state q db Map.empty) + + step gst (DemuxGo st q db keyToChan) = do + -- Collect outputs from running channels + (keyToChan1, outputs) <- collectOutputs q keyToChan + + -- Send input to the selected fold + res <- sstep (adaptState gst) st + + next <- case res of + Yield x s -> do + -- XXX If the fold for a particular key is done and we see that + -- key again. 
If we have not yet collected the done event we + -- cannot restart the fold because the previous key is already + -- installed. Therefore, restarting the fold for the same key + -- is fraught with races. + let k = getKey x + (keyToChan2, ch) <- + case Map.lookup k keyToChan1 of + Nothing -> do + fld <- getFold k + r@(chan, _) <- newChannelWithScan q db cfg (fmap (k,) fld) + return (Map.insert k r keyToChan1, chan) + Just (chan, _) -> return (keyToChan1, chan) + -- XXX We might block forever if some folds are already + -- done but we have not read the output queue yet. To + -- avoid that we have to either (1) precheck if space + -- is available in the input queues of all folds so + -- that this does not block, or (2) we have to use a + -- non-blocking read and track progress so that we can + -- restart from where we left. + -- + -- If there is no space available then we should block + -- on doorbell db or inputSpaceDoorBell of the relevant + -- channel. To avoid deadlock the output space can be + -- kept unlimited. However, the blocking will delay the + -- processing of outputs. We should yield the outputs + -- before blocking. + sendToWorker_ ch x + return $ DemuxGo s q db keyToChan2 + Skip s -> + return $ DemuxGo s q db keyToChan1 + Stop -> do + let chans = fmap (fst . 
snd) $ Map.toList keyToChan1 + Prelude.mapM_ finalize chans + return $ DemuxDrain q db keyToChan1 + if null outputs + then return $ Skip next + else return $ Yield outputs next + step _ (DemuxDrain q db keyToChan) = do + (keyToChan1, outputs) <- collectOutputs q keyToChan + if Map.null keyToChan1 + -- XXX null outputs case + then return $ Yield outputs DemuxStop + else do + if null outputs + then do + liftIO $ takeMVar db + return $ Skip (DemuxDrain q db keyToChan1) + else return $ Yield outputs (DemuxDrain q db keyToChan1) + step _ DemuxStop = return Stop diff --git a/src/Streamly/Internal/Data/Scanl/Prelude.hs b/src/Streamly/Internal/Data/Scanl/Prelude.hs new file mode 100644 index 0000000000..b0f150e16a --- /dev/null +++ b/src/Streamly/Internal/Data/Scanl/Prelude.hs @@ -0,0 +1,19 @@ +-- | +-- Module : Streamly.Internal.Data.Scanl.Prelude +-- Copyright : (c) 2022 Composewell Technologies +-- License : BSD-3-Clause +-- Maintainer : streamly@composewell.com +-- Stability : experimental +-- Portability : GHC +-- +module Streamly.Internal.Data.Scanl.Prelude + ( + -- * Channel + module Streamly.Internal.Data.Fold.Channel + -- * Concurrency + , module Streamly.Internal.Data.Scanl.Concurrent + ) +where + +import Streamly.Internal.Data.Fold.Channel +import Streamly.Internal.Data.Scanl.Concurrent diff --git a/streamly.cabal b/streamly.cabal index 0021a27009..456d52c1bc 100644 --- a/streamly.cabal +++ b/streamly.cabal @@ -385,6 +385,7 @@ library , Streamly.Internal.Data.Stream.Prelude , Streamly.Internal.Data.Unfold.Prelude , Streamly.Internal.Data.Fold.Prelude + , Streamly.Internal.Data.Scanl.Prelude -- streamly-unicode (depends on unicode-data) , Streamly.Internal.Unicode.Utf8 @@ -469,6 +470,7 @@ library , Streamly.Internal.Data.Fold.Channel.Type , Streamly.Internal.Data.Fold.Channel , Streamly.Internal.Data.Fold.Concurrent + , Streamly.Internal.Data.Scanl.Concurrent , Streamly.Internal.Data.Unfold.Exception , Streamly.Internal.Data.Unfold.SVar