module Language.Hakaru.Evaluation.EvalMonad
( runPureEvaluate
, pureEvaluate
, ListContext(..), PureAns, Eval(..), runEval
, residualizePureListContext
) where
import Prelude hiding (id, (.))
import Control.Category (Category(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
import Control.Applicative (Applicative(..))
#endif
import qualified Data.Foldable as F
import Language.Hakaru.Syntax.IClasses (Some2(..))
import Language.Hakaru.Syntax.Variable (memberVarSet)
import Language.Hakaru.Syntax.ABT (ABT(..), subst, maxNextFree)
import Language.Hakaru.Syntax.DatumABT
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.Lazy (TermEvaluator, evaluate, defaultCaseEvaluator)
import Language.Hakaru.Evaluation.PEvalMonad (ListContext(..))
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Traversable as T
import Language.Hakaru.Syntax.IClasses (Functor11(..))
import Language.Hakaru.Syntax.Variable (Variable(), toAssocs1)
import Language.Hakaru.Syntax.ABT (caseVarSyn, caseBinds, substs)
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing (Sing, sUnPair)
import Language.Hakaru.Syntax.TypeOf (typeOf)
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Evaluation.Lazy (reifyPair)
#ifdef __TRACE_DISINTEGRATE__
import Debug.Trace (trace)
#endif
-- | Evaluate a term with the pure evaluator and read back the result.
--
-- The term itself is handed to 'runEval' as the set of terms whose
-- variables fresh names must avoid. Any @let@-statements left in the
-- context are residualized around the read-back head.
runPureEvaluate :: (ABT Term abt) => abt '[] a -> abt '[] a
runPureEvaluate term = runEval readBack [Some2 term]
  where
    readBack = fmap fromWhnf (pureEvaluate term)
-- | The lazy evaluator specialised to the pure 'Eval' monad.
--
-- It uses the default case evaluator. The first argument of 'evaluate'
-- is the handler for @perform@; in a pure context that handler must
-- never run, so it is a 'brokenInvariant' error.
pureEvaluate :: (ABT Term abt) => TermEvaluator abt (Eval abt)
pureEvaluate term = evaluate (brokenInvariant "perform") defaultCaseEvaluator term
-- | Final answers of the pure evaluator.
--
-- An answer is a function from the pure statement context to a residual
-- term.
type PureAns abt r = ListContext abt 'Pure -> abt '[] r

-- | Continuation-passing evaluation monad over 'PureAns'.
--
-- A computation producing an @x@ receives the continuation for that
-- @x@ and must build an answer. It is polymorphic in the answer's type
-- index.
newtype Eval abt x = Eval
    { unEval :: forall ans. (x -> PureAns abt ans) -> PureAns abt ans }
-- | Abort with a message naming the location where an 'Eval'
-- invariant was violated.
brokenInvariant :: String -> a
brokenInvariant loc = error $ concat [loc, ": Eval's invariant broken"]
-- | Run an 'Eval' computation that produces a term.
--
-- The computation starts from an empty statement list. Its fresh-name
-- counter is seeded with 'maxNextFree' of @es@; presumably this keeps
-- new variables from clashing with those in @es@ (NOTE(review): confirm
-- against 'maxNextFree'). The final term is passed to
-- 'residualizePureListContext' to re-wrap the remaining statements.
runEval :: (ABT Term abt, F.Foldable f)
    => Eval abt (abt '[] a)
    -> f (Some2 abt)
    -> abt '[] a
runEval comp es = unEval comp residualizePureListContext initialContext
  where
    initialContext = ListContext (maxNextFree es) []
-- | Wrap a term in the @let@-statements recorded in a pure 'ListContext'.
--
-- Statements are stored most-recently-pushed first (pushing conses onto
-- the list). Folding from the head therefore wraps the innermost binding
-- first, and the oldest statement ends up outermost.
--
-- Each @SLet x body@ is handled as follows:
--
-- * if @x@ is not free in the term built so far, the binding is dropped;
-- * if @body@ is a lazy variable or literal, it is substituted for @x@;
-- * otherwise a real @Let_@ node is emitted.
--
-- A strict left fold ('F.foldl'') is used. A lazy 'foldl' would build a
-- chain of unevaluated @step@ thunks as long as the whole context before
-- any of them ran.
residualizePureListContext
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a
    -> ListContext abt 'Pure
    -> abt '[] a
residualizePureListContext e0 =
    F.foldl' step e0 . statements
  where
    step :: abt '[] a -> Statement abt 'Pure -> abt '[] a
    step e s =
        case s of
            SLet x body _
                -- Dead binding: nothing in @e@ refers to @x@.
                | not (x `memberVarSet` freeVars e) -> e
                | otherwise ->
                    case getLazyVariable body of
                        -- Variable-to-variable binding: just rename.
                        Just y -> subst x (var y) e
                        Nothing ->
                            case getLazyLiteral body of
                                -- Literals are cheap to duplicate: inline them.
                                Just v -> subst x (syn (Literal_ v)) e
                                Nothing ->
                                    syn (Let_ :$ fromLazy body :* bind x e :* End)
-- | Mapping applies the function to the result before handing it to the
-- continuation.
instance Functor (Eval abt) where
    fmap f m = Eval (\k -> unEval m (\x -> k (f x)))
-- | 'pure' hands its value straight to the continuation.
--
-- '<*>' runs the function computation first, then the argument
-- computation, then passes the application on.
instance Applicative (Eval abt) where
    pure x = Eval (\k -> k x)
    mf <*> mx = Eval (\k -> unEval mf (\f -> unEval mx (\x -> k (f x))))
-- | Standard continuation-passing bind: the result of @m@ selects the
-- next computation, which receives the original continuation.
instance Monad (Eval abt) where
    m >>= k = Eval (\c -> unEval m (\x -> unEval (k x) c))
    return = pure
-- | The pure evaluator threads a 'ListContext' (a fresh-name counter and
-- a stack of statements, top of stack at the head of the list) through
-- its continuation.
instance (ABT Term abt) => EvaluationMonad abt (Eval abt) 'Pure where
    -- Hand out the current counter value and advance the counter by one.
    freshNat =
        Eval (\k (ListContext n stmts) -> k n (ListContext (n + 1) stmts))

    -- Place a single statement on top of the stack.
    unsafePush stmt =
        Eval (\k (ListContext n stmts) -> k () (ListContext n (stmt : stmts)))

    -- Place a list of statements on the stack. The last element of the
    -- list ends up on top, matching one 'unsafePush' per element in order.
    unsafePushes new =
        Eval (\k (ListContext n old) -> k () (ListContext n (reverse new ++ old)))

    -- Pop statements until one binds @x@ and is accepted by @p@. Run the
    -- action @p@ returns, then restore the statements popped before it
    -- in their original order. The matching statement is not restored;
    -- dealing with it is up to @p@'s action. If nothing matches, every
    -- popped statement is restored and 'Nothing' is returned.
    select x p = go []
      where
        go popped = do
            top <- unsafePop
            case top of
                Nothing -> do
                    unsafePushes popped
                    return Nothing
                Just s ->
                    case isBoundBy x s >> p s of
                        Nothing -> go (s : popped)
                        Just act -> do
                            r <- act
                            unsafePushes popped
                            return (Just r)
-- | Remove and return the top statement of the context.
--
-- If the stack is empty, returns 'Nothing' and leaves the context as is.
unsafePop :: Eval abt (Maybe (Statement abt 'Pure))
unsafePop =
    Eval (\k ctx ->
        case ctx of
            ListContext _ []         -> k Nothing ctx
            ListContext n (s : rest) -> k (Just s) (ListContext n rest))
-- | Allocate a fresh variable of the given type and return it.
--
-- Whatever residual term the rest of the computation builds is bound
-- over that variable and then wrapped by @f@. This lets callers place a
-- binding construct (e.g. @Let_@) around the remainder of the program.
emit
    :: (ABT Term abt)
    => Text
    -> Sing a
    -> (forall r. abt '[a] r -> abt '[] r)
    -> Eval abt (Variable a)
emit hint typ f =
    freshVar hint typ >>= \x ->
        Eval (\k h -> f (bind x (k x h)))
-- | Name a term with a variable.
--
-- A term that already is a variable is returned unchanged. Anything else
-- is bound to a fresh variable by a residual @Let_@.
emitLet :: (ABT Term abt) => abt '[] a -> Eval abt (Variable a)
emitLet e = caseVarSyn e return (const letBind)
  where
    letBind = emit Text.empty (typeOf e) (\f -> syn (Let_ :$ e :* f :* End))
-- | Like 'emitLet', but returns a term.
--
-- Variables and literals are returned as they are. Any other term is
-- let-bound to a fresh variable, and that variable (as a term) is
-- returned.
emitLet' :: (ABT Term abt) => abt '[] a -> Eval abt (abt '[] a)
emitLet' e = caseVarSyn e (\_ -> return e) $ \t ->
    case t of
        Literal_ _ -> return e
        _          ->
            var <$> emit Text.empty (typeOf e) (\f -> syn (Let_ :$ e :* f :* End))
-- | Split a pair in weak-head normal form into its two components.
--
-- A head is taken apart directly with 'reifyPair'. For a neutral term,
-- two fresh variables (first component, then second) are allocated and
-- the work is delegated to 'emitUnpair_'.
emitUnpair
    :: (ABT Term abt)
    => Whnf abt (HPair a b)
    -> Eval abt (abt '[] a, abt '[] b)
emitUnpair w =
    case w of
        Head_ hd  -> return (reifyPair hd)
        Neutral e -> do
            let (fstTyp, sndTyp) = sUnPair (typeOf e)
            x <- freshVar Text.empty fstTyp
            y <- freshVar Text.empty sndTyp
            emitUnpair_ x y e
-- | Split a pair-typed term, using @x@ and @y@ as names for the
-- components when a residual match is needed.
--
-- * A literal pair datum is taken apart directly.
-- * For a @Case_@, the split is pushed into every branch.
-- * Anything else (including a variable) is residualized as a one-branch
--   @Case_@ on @pair x y@, with the rest of the program inside it.
emitUnpair_
    :: forall abt a b
    .  (ABT Term abt)
    => Variable a
    -> Variable b
    -> abt '[] (HPair a b)
    -> Eval abt (abt '[] a, abt '[] b)
emitUnpair_ x y = go
  where
    residualize :: abt '[] (HPair a b) -> Eval abt (abt '[] a, abt '[] b)
    residualize e =
#ifdef __TRACE_DISINTEGRATE__
        trace "-- emitUnpair: done (term is not Datum_ nor Case_)" $
#endif
        Eval (\k h ->
            syn (Case_ e
                [ Branch (pPair PVar PVar)
                    (bind x (bind y (k (var x, var y) h))) ]))

    go :: abt '[] (HPair a b) -> Eval abt (abt '[] a, abt '[] b)
    go e0 = caseVarSyn e0 (\v -> residualize (var v)) $ \t ->
        case t of
            Datum_ d -> do
#ifdef __TRACE_DISINTEGRATE__
                trace "-- emitUnpair: found Datum_" $ return ()
#endif
                return (reifyPair (WDatum d))
            Case_ scrut alts -> do
#ifdef __TRACE_DISINTEGRATE__
                trace "-- emitUnpair: going under Case_" $ return ()
#endif
                emitCaseWith go scrut alts
            _ -> residualize e0
-- | Run every computation in @ms@ with the same continuation and the
-- same starting context, then combine the resulting residual terms
-- with @f@.
emitFork_
    :: (ABT Term abt, T.Traversable t)
    => (forall r. t (abt '[] r) -> abt '[] r)
    -> t (Eval abt a)
    -> Eval abt a
emitFork_ f ms =
    Eval (\k h ->
        let runBranch m = unEval m k h
        in  f (fmap runBranch ms))
-- | Residualize a @Case_@ on @e@ and continue evaluating inside each
-- branch.
--
-- Each branch's pattern variables are first renamed to fresh ones, and
-- the renaming is substituted into the branch body. @f@ is applied to
-- each renamed body. When the result is consumed, every branch's
-- computation runs with the same continuation and context, so the rest
-- of the program is duplicated into each branch.
emitCaseWith
    :: (ABT Term abt)
    => (abt '[] b -> Eval abt r)
    -> abt '[] a
    -> [Branch a abt b]
    -> Eval abt r
emitCaseWith f e bs = do
    gms <- T.traverse
        (\(Branch pat body) ->
            let (vars, body') = caseBinds body
                mkGBranch vars' =
                    GBranch pat vars'
                        (f (substs (toAssocs1 vars (fmap11 var vars')) body'))
            in  fmap mkGBranch (freshenVars vars))
        bs
    Eval (\k h ->
        syn (Case_ e [ fromGBranch (fmap (\m -> unEval m k h) gm) | gm <- gms ]))