{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
module Clash.Rewrite.Combinators
( allR
, (!->)
, (>-!)
, (>-!->)
, (>->)
, bottomupR
, repeatR
, topdownR
, whenR
, bottomupWhenR
)
where
import Control.DeepSeq (deepseq)
import Control.Monad ((>=>))
import qualified Control.Monad.Writer as Writer
import qualified Data.Monoid as Monoid
import Clash.Core.Term (Term (..), CoreContext (..), primArg)
import Clash.Core.Util (patIds)
import Clash.Core.VarEnv
(extendInScopeSet, extendInScopeSetList)
import Clash.Rewrite.Types
-- | Apply a transformation to every immediate subterm of a term (one level
-- deep) and rebuild the term from the transformed subterms.
--
-- For each subterm the 'TransformContext' handed to the transformation is
-- adjusted:
--
-- * the 'CoreContext' list is extended (consed on the front) with an entry
--   describing where the subterm sits in its parent, e.g. 'LamBody',
--   'AppFun', 'LetBinding', 'CaseAlt';
--
-- * the in-scope set is extended with the variables the parent binds over
--   that subterm: lambda/type-lambda binders, all let binders, and the type
--   and term variables bound by a case pattern.
--
-- Types carried by 'TyApp', 'Cast' and 'Case' are passed through unchanged.
-- Any other term is returned as-is.
allR
  :: forall m
   . Monad m
  => Transform m
  -- ^ The transformation to apply to each immediate subterm
  -> Transform m
allR trans (TransformContext is c) (Lam v e) =
  Lam v <$> trans (TransformContext (extendInScopeSet is v) (LamBody v : c)) e

allR trans (TransformContext is c) (TyLam tv e) =
  TyLam tv <$> trans (TransformContext (extendInScopeSet is tv) (TyLamBody tv : c)) e

allR trans (TransformContext is c) (App e1 e2) = do
  e1' <- trans (TransformContext is (AppFun : c)) e1
  -- The argument's context entry is computed by 'primArg' from the
  -- *already transformed* function, so the function is rewritten first.
  e2' <- trans (TransformContext is (AppArg (primArg e1') : c)) e2
  pure (App e1' e2')

allR trans (TransformContext is c) (TyApp e ty) =
  TyApp <$> trans (TransformContext is (TyAppC : c)) e <*> pure ty

allR trans (TransformContext is c) (Cast e ty1 ty2) =
  Cast <$> trans (TransformContext is (CastBody : c)) e <*> pure ty1 <*> pure ty2

allR trans (TransformContext is c) (Letrec xes e) = do
  -- Right-hand sides are rewritten first, then the body.
  xes' <- traverse rewriteBind xes
  e'   <- trans (TransformContext is' (LetBody bndrs : c)) e
  pure (Letrec xes' e')
  where
    bndrs = map fst xes
    -- Every let binder is put in scope for the body AND for every
    -- right-hand side; reuse 'bndrs' instead of recomputing the list.
    is' = extendInScopeSetList is bndrs
    rewriteBind (b, rhs) =
      (b,) <$> trans (TransformContext is' (LetBinding b bndrs : c)) rhs

allR trans (TransformContext is c) (Case scrut ty alts) =
  Case <$> trans (TransformContext is (CaseScrut : c)) scrut
       <*> pure ty
       <*> traverse rewriteAlt alts
  where
    -- The type variables and term variables bound by the pattern are in
    -- scope in the alternative's right-hand side.
    rewriteAlt (p, e) =
      let (tvs, ids) = patIds p
          is'        = extendInScopeSetList (extendInScopeSetList is tvs) ids
      in  (p,) <$> trans (TransformContext is' (CaseAlt p : c)) e

allR trans (TransformContext is c) (Tick sp e) =
  Tick sp <$> trans (TransformContext is (TickC sp : c)) e

-- Any other term is returned unchanged.
allR _ _ tm = pure tm
infixr 6 >->
-- | Sequential composition of transformations: run the first transformation,
-- then feed its result term to the second one. Both run under the same
-- 'TransformContext'.
(>->) :: Monad m => Transform m -> Transform m -> Transform m
(>->) = \first second ctx -> first ctx >=> second ctx
{-# INLINE (>->) #-}
infixr 6 >-!->
-- | Sequential composition of transformations that fully evaluates (with
-- 'deepseq') the intermediate term produced by the first transformation
-- before handing it to the second one.
(>-!->) :: Monad m => Transform m -> Transform m -> Transform m
(>-!->) = \first second ctx term ->
  first ctx term >>= \term' -> term' `deepseq` second ctx term'
{-# INLINE (>-!->) #-}
-- | Apply a rewrite in a top-down traversal.
--
-- At every node the rewrite is first applied repeatedly (see 'repeatR')
-- until it stops signalling a change. Only after that does the traversal
-- descend into the subterms of the resulting term.
topdownR :: Rewrite m -> Rewrite m
topdownR rewrite = repeatR rewrite >-> descend
  where
    descend = allR (topdownR rewrite)
-- | Apply a transformation in a bottom-up traversal: all subterms are
-- transformed first, and then the transformation runs on the rebuilt term.
bottomupR :: Monad m => Transform m -> Transform m
bottomupR transform ctx = allR (bottomupR transform) ctx >=> transform ctx
infixr 5 !->
-- | Conditional sequencing: run the second rewrite on the output of the
-- first only if the first one signalled a change. A change is an
-- 'Monoid.Any' of 'True' in the writer log, observed with 'Writer.listen'.
-- Otherwise the first rewrite's output is returned as-is.
(!->) :: Rewrite m -> Rewrite m -> Rewrite m
(!->) = \first second ctx term -> do
  (term', progress) <- Writer.listen (first ctx term)
  case Monoid.getAny progress of
    True  -> second ctx term'
    False -> return term'
{-# INLINE (!->) #-}
infixr 5 >-!
-- | Fallback sequencing: run the second rewrite (on the first rewrite's
-- output) only if the first one did NOT signal a change. A change is an
-- 'Monoid.Any' of 'True' in the writer log, observed with 'Writer.listen'.
-- If the first rewrite did change something, its output is returned as-is.
(>-!) :: Rewrite m -> Rewrite m -> Rewrite m
(>-!) = \first fallback ctx term -> do
  (term', progress) <- Writer.listen (first ctx term)
  if not (Monoid.getAny progress)
    then fallback ctx term'
    else return term'
{-# INLINE (>-!) #-}
-- | Keep applying a rewrite to the same term for as long as it signals a
-- change. The output of the final application, which signalled no change,
-- is returned.
repeatR :: Rewrite m -> Rewrite m
repeatR = step
  where
    step rewrite = rewrite !-> repeatR rewrite
{-# INLINE repeatR #-}
-- | Guarded transformation: evaluate the predicate on the context and term.
-- If it yields 'True', run the transformation; otherwise return the term
-- unchanged.
whenR :: Monad m
      => (TransformContext -> Term -> m Bool)
      -> Transform m
      -> Transform m
whenR f r1 ctx expr = f ctx expr >>= choose
  where
    choose True  = r1 ctx expr
    choose False = return expr
-- | Bottom-up traversal that only descends into a node's subterms when the
-- predicate holds for that node. The transformation itself is applied at
-- every visited node, whether or not the traversal descended below it.
bottomupWhenR
  :: Monad m
  => (TransformContext -> Term -> m Bool)
  -> Transform m
  -> Transform m
bottomupWhenR f r ctx expr = f ctx expr >>= visit
  where
    -- Descend first (transforming the subterms), then transform this node.
    visit True  = allR (bottomupWhenR f r) ctx expr >>= r ctx
    -- Leave the subterms alone and only transform this node.
    visit False = r ctx expr