{-# LANGUAGE ScopedTypeVariables #-}
module Clash.Rewrite.Combinators where
import Control.DeepSeq (deepseq)
import Control.Monad ((<=<), (>=>))
import qualified Control.Monad.Writer as Writer
import qualified Data.Monoid as Monoid
import Unbound.Generics.LocallyNameless (Embed, Fresh, bind, embed,
rec, unbind, unembed, unrec)
import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind)
import Clash.Core.Term (Pat, Term (..))
import Clash.Core.Util (patIds)
import Clash.Core.Var (Id)
import Clash.Rewrite.Types
allR :: forall m . (Monad m, Fresh m)
=> Bool
-> Transform m
-> Transform m
allR _ _ _ (Var t x) = return (Var t x)
allR _ _ _ (Data dc) = return (Data dc)
allR _ _ _ (Literal l) = return (Literal l)
allR _ _ _ (Prim nm t) = return (Prim nm t)
allR rf trans c (Lam b) = do
(v,e) <- if rf then unbind b else return (unsafeUnbind b)
e' <- trans (LamBody v:c) e
return . Lam $ bind v e'
allR rf trans c (TyLam b) = do
(tv, e) <- if rf then unbind b else return (unsafeUnbind b)
e' <- trans (TyLamBody tv:c) e
return . TyLam $ bind tv e'
allR _ trans c (App e1 e2) = do
e1' <- trans (AppFun:c) e1
e2' <- trans (AppArg:c) e2
return $ App e1' e2'
allR _ trans c (TyApp e ty) = do
e' <- trans (TyAppC:c) e
return $ TyApp e' ty
allR _ trans c (Cast e ty1 ty2) = do
e' <- trans (CastBody:c) e
return $ Cast e' ty1 ty2
allR rf trans c (Letrec b) = do
(xesR,e) <- if rf then unbind b else return (unsafeUnbind b)
let xes = unrec xesR
let bndrs = map fst xes
e' <- trans (LetBody bndrs:c) e
xes' <- mapM (rewriteBind bndrs) xes
return . Letrec $ bind (rec xes') e'
where
rewriteBind :: [Id] -> (Id,Embed Term) -> m (Id,Embed Term)
rewriteBind bndrs (b', e) = do
e' <- trans (LetBinding b' bndrs:c) (unembed e)
return (b',embed e')
allR rf trans c (Case scrut ty alts) = do
scrut' <- trans (CaseScrut:c) scrut
alts' <- if rf then mapM (fmap (uncurry bind) . rewriteAlt <=< unbind) alts
else mapM (fmap (uncurry bind) . rewriteAlt . unsafeUnbind) alts
return $ Case scrut' ty alts'
where
rewriteAlt :: (Pat, Term) -> m (Pat, Term)
rewriteAlt (p,e) = do
let (tvs,ids) = patIds p
e' <- trans (CaseAlt tvs ids:c) e
return (p,e')
infixr 6 >->
(>->) :: Monad m => Transform m -> Transform m -> Transform m
(>->) r1 r2 c = r1 c >=> r2 c
infixr 6 >-!->
(>-!->) :: Monad m => Transform m -> Transform m -> Transform m
(>-!->) r1 r2 c e = do
e' <- r1 c e
deepseq e' (r2 c e')
topdownR :: Rewrite m -> Rewrite m
topdownR r = repeatR r >-> allR True (topdownR r)
unsafeTopdownR :: Rewrite m -> Rewrite m
unsafeTopdownR r = repeatR r >-> allR False (unsafeTopdownR r)
bottomupR :: Fresh m => Transform m -> Transform m
bottomupR r = allR True (bottomupR r) >-> r
unsafeBottomupR :: Fresh m => Transform m -> Transform m
unsafeBottomupR r = allR False (unsafeBottomupR r) >-> r
infixr 5 !->
(!->) :: Rewrite m -> Rewrite m -> Rewrite m
(!->) r1 r2 c expr = do
(expr',changed) <- Writer.listen $ r1 c expr
if Monoid.getAny changed
then r2 c expr'
else return expr'
infixr 5 >-!
(>-!) :: Rewrite m -> Rewrite m -> Rewrite m
(>-!) r1 r2 c expr = do
(expr',changed) <- Writer.listen $ r1 c expr
if Monoid.getAny changed
then return expr'
else r2 c expr'
repeatR :: Rewrite m -> Rewrite m
repeatR r = r !-> repeatR r
whenR :: Monad m
=> ([CoreContext] -> Term -> m Bool)
-> Transform m
-> Transform m
whenR f r1 ctx expr = do
b <- f ctx expr
if b
then r1 ctx expr
else return expr
bottomupWhenR :: Fresh m
=> ([CoreContext] -> Term -> m Bool)
-> Transform m
-> Transform m
bottomupWhenR f r ctx expr = do
b <- f ctx expr
if b
then (allR True (bottomupWhenR f r) >-> r) ctx expr
else r ctx expr