Skip to content

Commit 21ca6d4

Browse files
committed
add ListT monad transformer and relevant instances
1 parent e3bd0b4 commit 21ca6d4

File tree

1 file changed

+185
-1
lines changed

1 file changed

+185
-1
lines changed

src/Streaming/Prelude.hs

Lines changed: 185 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@
5353
{-# LANGUAGE DeriveFoldable #-}
5454
{-# LANGUAGE DeriveFunctor #-}
5555
{-# LANGUAGE DeriveTraversable #-}
56+
{-# LANGUAGE FlexibleInstances #-}
57+
{-# LANGUAGE MultiParamTypeClasses #-}
5658
{-# LANGUAGE RankNTypes #-}
5759
{-# LANGUAGE ScopedTypeVariables #-}
5860
{-# LANGUAGE TypeFamilies #-}
61+
{-# LANGUAGE UndecidableInstances #-}
5962

6063
{-# OPTIONS_GHC -Wall #-}
6164

@@ -255,15 +258,26 @@ module Streaming.Prelude (
255258

256259
-- * Basic Type
257260
, Stream
261+
262+
-- * ListT
263+
, ListT(..)
264+
, runListT
258265
) where
259266
import Streaming.Internal
260267

261268
import Control.Monad hiding (filterM, mapM, mapM_, foldM, foldM_, replicateM, sequence)
262269
import Data.Functor.Identity
263270
import Data.Functor.Sum
264271
import Control.Monad.Trans
265-
import Control.Applicative (Applicative (..))
272+
import Control.Applicative (Applicative (..), Alternative (..))
273+
import Control.Monad.Morph
274+
import Control.Monad.Error.Class
275+
import Control.Monad.Reader.Class
276+
import Control.Monad.State.Class
277+
import Control.Monad.Writer.Class
278+
import Control.Monad.Zip
266279
import Data.Functor (Functor (..), (<$))
280+
import Data.Semigroup (Semigroup ((<>)))
267281

268282
import qualified Prelude as Prelude
269283
import qualified Data.Foldable as Foldable
@@ -2921,3 +2935,173 @@ mapMaybeM phi = loop where
29212935
Nothing -> loop snext
29222936
Just b -> Step (b :> loop snext)
29232937
{-#INLINABLE mapMaybeM #-}
2938+
2939+
{-| The list monad transformer.
2940+
'pure' and 'return' correspond to 'yield', yielding a single value.
2941+
('>>=') corresponds to 'for', calling the second computation once for
2942+
each time the first computation 'yield's.
2943+
-}
2944+
newtype ListT m a = Select { enumerate :: Stream (Of a) m () }
2945+
2946+
instance Monad m => Functor (ListT m) where
2947+
fmap f p = Select (for (enumerate p) (\a -> yield (f a)))
2948+
{-# INLINE fmap #-}
2949+
2950+
instance Monad m => Applicative (ListT m) where
2951+
pure a = Select (yield a)
2952+
{-# INLINE pure #-}
2953+
mf <*> mx = Select (
2954+
for (enumerate mf) (\f ->
2955+
for (enumerate mx) (\x ->
2956+
yield (f x) ) ) )
2957+
2958+
instance Monad m => Monad (ListT m) where
2959+
return = pure
2960+
{-# INLINE return #-}
2961+
m >>= f = Select (for (enumerate m) (\a -> enumerate (f a)))
2962+
{-# INLINE (>>=) #-}
2963+
2964+
instance Foldable m => Foldable (ListT m) where
2965+
foldMap f = go . enumerate
2966+
where
2967+
go p = case p of
2968+
Return () -> mempty
2969+
Effect m -> Foldable.foldMap go m
2970+
Step (a :> rest) -> f a `mappend` go rest
2971+
{-# INLINE foldMap #-}
2972+
2973+
instance (Monad m, Traversable m) => Traversable (ListT m) where
2974+
traverse k (Select p) = fmap Select (traverse_ p)
2975+
where
2976+
traverse_ (Return ()) = pure (Return ())
2977+
traverse_ (Effect m) = fmap Effect (traverse traverse_ m)
2978+
traverse_ (Step (a :> rest)) = (\a_ rest_ -> Step (a_ :> rest_)) <$> k a <*> traverse_ rest
2979+
2980+
instance MonadTrans ListT where
2981+
lift m = Select (do
2982+
a <- lift m
2983+
yield a )
2984+
2985+
instance MonadIO m => MonadIO (ListT m) where
2986+
liftIO m = lift (liftIO m)
2987+
{-# INLINE liftIO #-}
2988+
2989+
instance Monad m => Alternative (ListT m) where
2990+
empty = Select (pure ())
2991+
{-# INLINE empty #-}
2992+
p1 <|> p2 = Select (do
2993+
enumerate p1
2994+
enumerate p2 )
2995+
2996+
instance Monad m => MonadPlus (ListT m) where
2997+
mzero = empty
2998+
{-# INLINE mzero #-}
2999+
mplus = (<|>)
3000+
{-# INLINE mplus #-}
3001+
3002+
instance MFunctor ListT where
3003+
hoist morph = Select . hoist morph . enumerate
3004+
{-# INLINE hoist #-}
3005+
3006+
instance MMonad ListT where
3007+
embed f (Select p0) = Select (loop p0)
3008+
where
3009+
loop (Return ()) = Return ()
3010+
loop (Effect m) = for (enumerate (fmap loop (f m))) id
3011+
loop (Step (a :> rest)) = Step (a :> loop rest)
3012+
{-# INLINE embed #-}
3013+
3014+
instance Monad m => Semigroup (ListT m a) where
3015+
(<>) = (<|>)
3016+
{-# INLINE (<>) #-}
3017+
3018+
instance Monad m => Monoid (ListT m a) where
3019+
mempty = empty
3020+
{-# INLINE mempty #-}
3021+
#if !(MIN_VERSION_base(4,11,0))
3022+
mappend = (<|>)
3023+
{-# INLINE mappend #-}
3024+
#endif
3025+
3026+
instance (MonadState s m) => MonadState s (ListT m) where
3027+
get = lift get
3028+
{-# INLINE get #-}
3029+
3030+
put s = lift (put s)
3031+
{-# INLINE put #-}
3032+
3033+
state f = lift (state f)
3034+
{-# INLINE state #-}
3035+
3036+
instance (MonadWriter w m) => MonadWriter w (ListT m) where
3037+
writer = lift . writer
3038+
{-# INLINE writer #-}
3039+
3040+
tell w = lift (tell w)
3041+
{-# INLINE tell #-}
3042+
3043+
--listen :: ListT m a -> ListT m (a, w)
3044+
listen l = Select (go (enumerate l) mempty)
3045+
where
3046+
go p w = case p of
3047+
Return () -> Return ()
3048+
Effect m -> Effect (do
3049+
(p', w') <- listen m
3050+
pure (go p' $! mappend w w') )
3051+
Step (a :> rest) -> Step ( (a,w) :> go rest w)
3052+
3053+
pass l = Select (go (enumerate l) mempty)
3054+
where
3055+
--go :: forall m a w. Stream (Of (w, a)) m () -> (w -> w) -> Stream (Of a) m ()
3056+
go p w = case p of
3057+
Return () -> Return ()
3058+
Effect m -> Effect (do
3059+
(p', w') <- listen m
3060+
pure (go p' $! mappend w w'))
3061+
Step ((b, f) :> rest) -> Effect (pass (return (Step (b :> (go rest (f w))), \_ -> f w) ))
3062+
3063+
instance (MonadReader i m) => MonadReader i (ListT m) where
3064+
ask = lift ask
3065+
{-# INLINE ask #-}
3066+
3067+
local f l = Select (local f (enumerate l))
3068+
{-# INLINE local #-}
3069+
3070+
reader f = lift (reader f)
3071+
{-# INLINE reader #-}
3072+
3073+
instance (MonadError e m) => MonadError e (ListT m) where
3074+
throwError e = lift (throwError e)
3075+
{-# INLINE throwError #-}
3076+
3077+
catchError l k = Select (catchError (enumerate l) (\e -> enumerate (k e)))
3078+
{-# INLINE catchError #-}
3079+
3080+
{- These instances require a dependency on `exceptions`.
3081+
instance MonadThrow m => MonadThrow (ListT m) where
3082+
throwM = Select . throwM
3083+
{-# INLINE throwM #-}
3084+
instance MonadCatch m => MonadCatch (ListT m) where
3085+
catch l k = Select (Control.Monad.Catch.catch (enumerate l) (\e -> enumerate (k e)))
3086+
{-# INLINE catch #-}
3087+
-}
3088+
3089+
instance Monad m => MonadZip (ListT m) where
3090+
mzipWith f = go
3091+
where
3092+
go xs ys = Select $ do
3093+
xres <- lift $ next (enumerate xs)
3094+
case xres of
3095+
Left () -> pure ()
3096+
Right (x, xrest) -> do
3097+
yres <- lift $ next (enumerate ys)
3098+
case yres of
3099+
Left () -> pure ()
3100+
Right (y, yrest) -> do
3101+
yield (f x y)
3102+
enumerate (go (Select xrest) (Select yrest))
3103+
3104+
-- | Run a self-contained 'ListT' computation
3105+
runListT :: Monad m => ListT m a -> m ()
3106+
runListT l = effects (enumerate (l >> mzero))
3107+
{-# INLINABLE runListT #-}

0 commit comments

Comments
 (0)