module Clash.Rewrite.Combinators
( allR
, (!->)
, (>-!)
, (>-!->)
, (>->)
, bottomupR
, repeatR
, topdownR
) where
import Control.DeepSeq (deepseq)
import Control.Monad ((>=>))
import qualified Control.Monad.Writer as Writer
import qualified Data.Monoid as Monoid
import Clash.Core.Term (Term (..), CoreContext (..), primArg, patIds)
import Clash.Core.VarEnv
(extendInScopeSet, extendInScopeSetList)
import Clash.Rewrite.Types
-- | Apply a transformation to each /immediate/ subterm of a term.
--
-- 'allR' itself does not recurse: it calls the given transformation once per
-- direct child and rebuilds the node from the results. Recursive traversals
-- are obtained by passing a recursive transformation, as 'topdownR' and
-- 'bottomupR' do.
--
-- For every child that is visited, the 'TransformContext' is adjusted:
--
-- * binders introduced by the node ('Lam', 'TyLam', the 'Letrec' binders and
--   the variables bound by a 'Case' pattern) are added to the in-scope set;
--
-- * a 'CoreContext' frame describing the child's position is pushed on top of
--   the context list.
--
-- The types stored in 'TyApp', 'Cast' and 'Case' are kept as-is. Any other
-- term is returned unchanged. Children are visited left to right: function
-- before argument, let-bindings before the let-body, and the scrutinee before
-- the alternatives.
allR
  :: forall m
   . Monad m
  => Transform m
  -- ^ Transformation to apply to each immediate subterm
  -> Transform m
allR trans (TransformContext is c) term = case term of
  Lam v e ->
    Lam v <$> descend (extendInScopeSet is v) (LamBody v) e

  TyLam tv e ->
    TyLam tv <$> descend (extendInScopeSet is tv) (TyLamBody tv) e

  App fun arg -> do
    fun' <- descend is AppFun fun
    -- The argument's frame is computed from the function *after* it has been
    -- transformed.
    arg' <- descend is (AppArg (primArg fun')) arg
    pure (App fun' arg')

  TyApp e ty -> do
    e' <- descend is TyAppC e
    pure (TyApp e' ty)

  Cast e ty1 ty2 -> do
    e' <- descend is CastBody e
    pure (Cast e' ty1 ty2)

  Letrec binds body -> do
    -- Recursive bindings: every binder is in scope in every right-hand side
    -- as well as in the body.
    let bndrs = map fst binds
        isLet = extendInScopeSetList is bndrs
        rewriteBind (b, rhs) =
          (,) b <$> descend isLet (LetBinding b bndrs) rhs
    binds' <- traverse rewriteBind binds
    -- NOTE(review): the body's frame carries the bindings as they were
    -- *before* rewriting (binds, not binds'); confirm this is intended.
    body' <- descend isLet (LetBody binds) body
    pure (Letrec binds' body')

  Case scrut ty alts -> do
    scrut' <- descend is CaseScrut scrut
    alts' <- traverse rewriteAlt alts
    pure (Case scrut' ty alts')

  Tick sp e ->
    Tick sp <$> descend is (TickC sp) e

  _ ->
    pure term
 where
  -- Run the transformation on a child term, under the given in-scope set and
  -- with the given frame pushed onto the context.
  descend scope frame = trans (TransformContext scope (frame : c))

  -- The type variables and term variables bound by the pattern are in scope
  -- in the alternative's right-hand side.
  rewriteAlt (p, rhs) =
    let (tvs, ids) = patIds p
        isAlt = extendInScopeSetList (extendInScopeSetList is tvs) ids
    in  (,) p <$> descend isAlt (CaseAlt p) rhs
infixr 6 >->
-- | Left-to-right composition of two transformations. The second one is run
-- on the result of the first, and both receive the same 'TransformContext'.
(>->) :: Monad m => Transform m -> Transform m -> Transform m
-- Presumably written as a lambda (nothing to the left of '=') so that the
-- INLINE pragma also applies where the operator is only partially applied.
(>->) = \first second ctx e -> first ctx e >>= second ctx
{-# INLINE (>->) #-}
infixr 6 >-!->
-- | Like '>->', but the intermediate term produced by the first
-- transformation is fully evaluated (with 'deepseq') before the second
-- transformation is run on it.
(>-!->) :: Monad m => Transform m -> Transform m -> Transform m
(>-!->) = \first second ctx e ->
  first ctx e >>= \e' -> e' `deepseq` second ctx e'
{-# INLINE (>-!->) #-}
-- | Top-down traversal. At each node, the rewrite is applied repeatedly (via
-- 'repeatR') until it no longer reports a change. The traversal then descends
-- into the immediate subterms of the resulting term and continues there.
topdownR :: Rewrite m -> Rewrite m
topdownR r = visit
 where
  visit = repeatR r >-> allR visit
-- | Bottom-up traversal. All immediate subterms are transformed first, each
-- recursively bottom-up. The transformation is then applied once to the
-- rebuilt term.
bottomupR :: Monad m => Transform m -> Transform m
bottomupR r = visit
 where
  visit = allR visit >-> r
infixr 5 !->
-- | Conditional sequencing. The first rewrite is run; if its writer output
-- reports a change (@Any True@), the second rewrite is run on the result.
-- Otherwise the first rewrite's result is returned as-is.
(!->) :: Rewrite m -> Rewrite m -> Rewrite m
(!->) = \first second ctx e -> do
  (e', changed) <- Writer.listen (first ctx e)
  case changed of
    Monoid.Any True  -> second ctx e'
    Monoid.Any False -> return e'
{-# INLINE (!->) #-}
infixr 5 >-!
-- | Fallback sequencing. The first rewrite is run; if its writer output
-- reports a change (@Any True@), its result is returned. Only when it reports
-- no change is the second rewrite run, on the first rewrite's result.
(>-!) :: Rewrite m -> Rewrite m -> Rewrite m
(>-!) = \first second ctx e -> do
  (e', changed) <- Writer.listen (first ctx e)
  case changed of
    Monoid.Any True  -> return e'
    Monoid.Any False -> second ctx e'
{-# INLINE (>-!) #-}
-- | Keep applying a rewrite, each time to the previous result, for as long as
-- it reports a change. The first application that reports no change ends the
-- loop, and its result is returned.
repeatR :: Rewrite m -> Rewrite m
-- The loop is tied through the local 'go' rather than through 'repeatR'
-- itself, so 'repeatR' is not self-recursive. GHC does not inline a
-- self-recursive binding (it is chosen as a loop breaker), which would make
-- the INLINE pragma below ineffective. The lambda form (nothing to the left
-- of '=') lets the pragma apply wherever 'repeatR' is used.
repeatR = \r -> let go = r !-> go in go
{-# INLINE repeatR #-}