{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TupleSections #-}
{-# HLINT ignore "Use gets" #-}
{-# HLINT ignore "Use asks" #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
module Try.Monads.MonadBaseControl where
import Control.Monad.Base (MonadBase)
import Control.Monad.Reader (MonadReader (ask), ReaderT (runReaderT))
import Control.Monad.State (MonadState (get, put), MonadTrans (lift), StateT (runStateT))
import Control.Monad.Trans.Writer
import Data.Functor ((<&>))
-- Reference: https://lexi-lambda.github.io/blog/2019/09/07/demystifying-monadbasecontrol/
--
-- The essence of MonadBaseControl:
--
--   - Capture the action’s input state and close over it.
--   - Package up the action’s output state with its result and run it.
--   - Restore the action’s output state into the enclosing transformer.
--   - Return the action’s result.
-- | Monads @m@ whose actions can be run inside the base monad @b@ by
-- explicitly threading @m@'s own context through @b@.
--
-- A lifted operation proceeds in three steps:
--
--   1. snapshot the context with 'captureInputState';
--   2. run the action in @b@ with 'closeOverInputState', which returns the
--      action's result together with the context it produced;
--   3. write that produced context back with 'restoreOutputState' and
--      return the result.
class (MonadBase b m) => MonadBaseControl b m | m -> b where
  -- | The context an @m@ action reads from its surroundings (for example,
  -- the current state or environment of each transformer layer).
  type InputState m

  -- | The context an @m@ action produces that must be written back after it
  -- has run in @b@ (for example, the final state or accumulated output).
  type OutputState m

  -- | Take a snapshot of the input state of the current @m@ context.
  captureInputState :: m (InputState m)

  -- | Run an @m@ action in the base monad @b@, starting from the supplied
  -- input state, and pair the action's result with the output state it
  -- produced.
  closeOverInputState :: m a -> InputState m -> b (a, OutputState m)

  -- | Write a previously produced output state back into @m@.
  restoreOutputState :: OutputState m -> m ()
-- | Base case: 'IO' is its own base monad and carries no extra context, so
-- both the input and the output state are the unit value.
instance MonadBaseControl IO IO where
  type InputState IO = ()
  type OutputState IO = ()

  captureInputState = return ()

  -- Run the action as-is and attach the trivial output state.
  closeOverInputState action () = (, ()) <$> action

  restoreOutputState () = return ()
-- | 'StateT' contributes its current state to the input state, and its final
-- state to the output state, on top of whatever the inner monad needs.
instance (MonadBaseControl b m) => MonadBaseControl b (StateT s m) where
  type InputState (StateT s m) = (s, InputState m)
  type OutputState (StateT s m) = (s, OutputState m)

  -- Read this layer's state first, then the inner monad's input state.
  captureInputState = do
    st <- get
    inner <- lift captureInputState
    pure (st, inner)

  -- Run the StateT layer from the captured state, let the inner monad close
  -- over its own part of the input, then regroup the final state with the
  -- inner output state.
  closeOverInputState action (st, inner) =
    closeOverInputState (runStateT action st) inner
      >>= \((result, st'), inner') -> pure (result, (st', inner'))

  -- Restore the inner monad's output first, then overwrite this layer's
  -- state with the final state the action produced.
  restoreOutputState (st, inner) = do
    lift (restoreOutputState inner)
    put st
-- | 'ReaderT' only adds its environment to the input state. It produces no
-- output of its own, so its output state is exactly the inner monad's.
instance (MonadBaseControl b m) => MonadBaseControl b (ReaderT r m) where
  type InputState (ReaderT r m) = (r, InputState m)
  type OutputState (ReaderT r m) = OutputState m

  -- Read the environment first, then the inner monad's input state.
  captureInputState = do
    env <- ask
    inner <- lift captureInputState
    pure (env, inner)

  -- Supply the captured environment, then delegate to the inner monad.
  closeOverInputState action (env, inner) =
    closeOverInputState (runReaderT action env) inner

  -- This layer has nothing of its own to restore; only the inner monad does.
  restoreOutputState = lift . restoreOutputState
-- | 'WriterT' adds nothing to the input state. The output it accumulates
-- while running in the base monad becomes part of the output state and is
-- re-emitted with 'tell' on restore.
instance
  (MonadBaseControl b m, Monoid w) =>
  MonadBaseControl b (WriterT w m)
  where
  type InputState (WriterT w m) = InputState m
  type OutputState (WriterT w m) = (w, OutputState m)

  captureInputState = lift captureInputState

  -- Run the writer layer to get its accumulated output alongside the result,
  -- let the inner monad close over its own input, then regroup the output.
  closeOverInputState action inner =
    closeOverInputState (runWriterT action) inner
      >>= \((result, written), inner') -> pure (result, (written, inner'))

  -- Restore the inner monad's output first, then replay the captured
  -- writer output into the enclosing WriterT.
  restoreOutputState (written, inner) =
    lift (restoreOutputState inner) *> tell written