Skip to content

Commit

Permalink
Add concurrent scan combinators
Browse files Browse the repository at this point in the history
  • Loading branch information
harendra-kumar committed Aug 27, 2024
1 parent e026ad8 commit 32350bc
Show file tree
Hide file tree
Showing 5 changed files with 405 additions and 24 deletions.
87 changes: 85 additions & 2 deletions src/Streamly/Internal/Data/Fold/Channel/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ module Streamly.Internal.Data.Fold.Channel.Type

-- ** Operations
, newChannelWith
, newChannelWithScan
, newChannel
, sendToWorker
, sendToWorker_
, checkFoldStatus -- XXX collectFoldOutput
, dumpChannel
, cleanup
, finalize
)
where

Expand All @@ -33,17 +36,19 @@ where
import Control.Concurrent (ThreadId, myThreadId, tryPutMVar)
import Control.Concurrent.MVar (MVar, newEmptyMVar, takeMVar)
import Control.Exception (SomeException(..))
import Control.Monad (void)
import Control.Monad (void, when)
import Control.Monad.Catch (throwM)
import Control.Monad.IO.Class (MonadIO(..))
import Data.IORef (IORef, newIORef, readIORef)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.List (intersperse)
import Streamly.Internal.Control.Concurrent
(MonadAsync, MonadRunInIO, askRunInIO)
import Streamly.Internal.Control.ForkLifted (doForkWith)
import Streamly.Internal.Data.Fold (Fold(..))
import Streamly.Internal.Data.Scanl (Scanl(..))
import Streamly.Internal.Data.Channel.Dispatcher (dumpSVarStats)
import Streamly.Internal.Data.Channel.Worker (sendEvent)
import Streamly.Internal.Data.Time.Clock (Clock(Monotonic), getTime)

import qualified Streamly.Internal.Data.Fold as Fold
import qualified Streamly.Internal.Data.Stream as D
Expand All @@ -58,6 +63,7 @@ import Streamly.Internal.Data.Channel.Types

data OutEvent b =
FoldException ThreadId SomeException
| FoldPartial b
| FoldDone ThreadId b

-- | The fold driver thread queues the input of the fold in the 'inputQueue'
Expand Down Expand Up @@ -201,6 +207,10 @@ sendYieldToDriver sv res = liftIO $ do
tid <- myThreadId
void $ sendToDriver sv (FoldDone tid res)

sendPartialToDriver :: MonadIO m => Channel m a b -> b -> m ()
sendPartialToDriver sv res = liftIO $ do
void $ sendToDriver sv (FoldPartial res)

{-# NOINLINE sendExceptionToDriver #-}
sendExceptionToDriver :: Channel m a b -> SomeException -> IO ()
sendExceptionToDriver sv e = do
Expand Down Expand Up @@ -319,6 +329,61 @@ newChannelWith outq outqDBell modifier f = do
let f1 = Fold.rmapM (void . sendYieldToDriver chan) f
in D.fold f1 $ fromInputQueue chan

{-# INLINE scanToChannel #-}
scanToChannel :: MonadIO m => Channel m a b -> Scanl m a b -> Scanl m a ()
scanToChannel chan (Scanl step initial extract final) =
Scanl step1 initial1 extract1 final1

where

initial1 = do
r <- initial
case r of
Fold.Partial s -> do
return $ Fold.Partial s
Fold.Done b ->
Fold.Done <$> void (sendYieldToDriver chan b)

step1 st x = do
r <- step st x
case r of
Fold.Partial s -> do
b <- extract s
void $ sendPartialToDriver chan b
return $ Fold.Partial s
Fold.Done b ->
Fold.Done <$> void (sendYieldToDriver chan b)

extract1 _ = return ()

final1 st = void (final st)

{-# INLINABLE newChannelWithScan #-}
{-# SPECIALIZE newChannelWithScan ::
IORef ([OutEvent b], Int)
-> MVar ()
-> (Config -> Config)
-> Scanl IO a b
-> IO (Channel IO a b, ThreadId) #-}
newChannelWithScan :: (MonadRunInIO m) =>
IORef ([OutEvent b], Int)
-> MVar ()
-> (Config -> Config)
-> Scanl m a b
-> m (Channel m a b, ThreadId)
newChannelWithScan outq outqDBell modifier f = do
let config = modifier defaultConfig
sv <- liftIO $ mkNewChannelWith outq outqDBell config
mrun <- askRunInIO
tid <- doForkWith
(getBound config) (work sv) mrun (sendExceptionToDriver sv)
return (sv, tid)

where

{-# NOINLINE work #-}
work chan = D.drain $ D.scanl (scanToChannel chan f) $ fromInputQueue chan

{-# INLINABLE newChannel #-}
{-# SPECIALIZE newChannel ::
(Config -> Config) -> Fold IO a b -> IO (Channel IO a b) #-}
Expand Down Expand Up @@ -362,6 +427,7 @@ checkFoldStatus sv = do
case ev of
FoldException _ e -> throwM e
FoldDone _ b -> return (Just b)
FoldPartial _ -> undefined

{-# INLINE isBufferAvailable #-}
isBufferAvailable :: MonadIO m => Channel m a b -> m Bool
Expand Down Expand Up @@ -434,3 +500,20 @@ sendToWorker_ chan a = go
-- Block for space
-- () <- liftIO $ takeMVar (inputSpaceDoorBell chan)
-- go

-- XXX Cleanup the fold if the stream is interrupted. Add a GC hook.

cleanup :: MonadIO m => Channel m a b -> m ()
cleanup chan = do
when (svarInspectMode chan) $ liftIO $ do
t <- getTime Monotonic
writeIORef (svarStopTime (svarStats chan)) (Just t)
printSVar (dumpChannel chan) "Scan channel done"

finalize :: MonadIO m => Channel m a b -> m ()
finalize chan = do
liftIO $ void
$ sendEvent
(inputQueue chan)
(inputItemDoorBell chan)
ChildStopChannel
29 changes: 7 additions & 22 deletions src/Streamly/Internal/Data/Fold/Concurrent.hs
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,15 @@ where

import Control.Concurrent (newEmptyMVar, takeMVar, throwTo)
import Control.Monad.Catch (throwM)
import Control.Monad (void, when)
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.IORef (newIORef, readIORef)
import Fusion.Plugin.Types (Fuse(..))
import Streamly.Internal.Control.Concurrent (MonadAsync)
import Streamly.Internal.Data.Channel.Worker (sendEvent)
import Streamly.Internal.Data.Fold (Fold(..), Step (..))
import Streamly.Internal.Data.Stream (Stream(..), Step(..))
import Streamly.Internal.Data.SVar.Type (adaptState)
import Streamly.Internal.Data.Time.Clock (Clock(Monotonic), getTime)

import qualified Data.Map.Strict as Map
import qualified Streamly.Internal.Data.Fold as Fold
Expand All @@ -83,15 +82,6 @@ import Streamly.Internal.Data.Channel.Types
-- Evaluating a Fold
-------------------------------------------------------------------------------

-- XXX Cleanup the fold if the stream is interrupted. Add a GC hook.

cleanup :: MonadIO m => Channel m a b -> m ()
cleanup chan = do
when (svarInspectMode chan) $ liftIO $ do
t <- getTime Monotonic
writeIORef (svarStopTime (svarStats chan)) (Just t)
printSVar (dumpChannel chan) "Fold channel done"

-- | 'parEval' introduces a concurrent stage at the input of the fold. The
-- inputs are asynchronously queued in a buffer and evaluated concurrently with
-- the evaluation of the source stream. On finalization, 'parEval' waits for
Expand Down Expand Up @@ -291,14 +281,6 @@ parUnzipWithM cfg f c1 c2 = Fold.unzipWithM f (parEval cfg c1) (parEval cfg c2)
-- 2. A monolithic implementation of concurrent Stream->Stream scan, using a
-- custom implementation of the scan and the driver.

finalize :: MonadIO m => Channel m a b -> m ()
finalize chan = do
liftIO $ void
$ sendEvent
(inputQueue chan)
(inputItemDoorBell chan)
ChildStopChannel

{-# ANN type ScanState Fuse #-}
data ScanState s q db f =
ScanInit
Expand Down Expand Up @@ -344,6 +326,7 @@ parDistributeScan cfg getFolds (Stream sstep state) =
FoldDone tid b ->
let ch = filter (\(_, t) -> t /= tid) chans
in processOutputs ch xs (b:done)
FoldPartial _ -> undefined

collectOutputs qref chans = do
(_, n) <- liftIO $ readIORef qref
Expand Down Expand Up @@ -418,8 +401,9 @@ data DemuxState s q db f =
-- fold again because some inputs would be lost in between, or (2) have a
-- FoldYield constructor to yield repeatedly so that we can restart the
-- existing fold itself when it is done. But in that case we cannot change the
-- fold once it is started. Whatever we do we should keep the non-concurrent
-- fold as well consistent with that.
-- fold once it is started. Also the Map would keep on increasing in size as we
-- never delete a key. Whatever we do we should keep the non-concurrent fold as
-- well consistent with that.

-- | Evaluate a stream and send its outputs to the selected fold. The fold is
-- dynamically selected using a key at the time of the first input seen for
Expand Down Expand Up @@ -466,6 +450,7 @@ parDemuxScan cfg getKey getFold (Stream sstep state) =
FoldDone _tid o@(k, _) ->
let ch = Map.delete k keyToChan
in processOutputs ch xs (o:done)
FoldPartial _ -> undefined

collectOutputs qref keyToChan = do
(_, n) <- liftIO $ readIORef qref
Expand Down
Loading

0 comments on commit 32350bc

Please sign in to comment.