{-|
  Copyright  :  (C) 2012-2016, University of Twente
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Rewriting combinators and traversals
-}

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE ViewPatterns        #-}

module Clash.Rewrite.Combinators
  ( allR
  , (!->)
  , (>-!)
  , (>-!->)
  , (>->)
  , bottomupR
  , repeatR
  , topdownR

  , whenR
  , bottomupWhenR
  )
 where

import           Control.DeepSeq             (deepseq)
import           Control.Monad               ((>=>))
import qualified Control.Monad.Writer        as Writer
import qualified Data.Monoid                 as Monoid

import           Clash.Core.Term             (Term (..), CoreContext (..), primArg)
import           Clash.Core.Util             (patIds)
import           Clash.Core.VarEnv
  (extendInScopeSet, extendInScopeSetList)
import           Clash.Rewrite.Types

-- | Apply a transformation to the immediate subterms of a term.
--
-- The transformation is applied once to each direct child term. It is not
-- applied to the term itself, and it does not recurse on its own. Recursion
-- comes from passing a recursive transformation, as 'topdownR' and
-- 'bottomupR' do.
--
-- Every child is visited with an updated 'TransformContext':
--
-- * binders introduced by 'Lam', 'TyLam', 'Letrec' and case-alternative
--   patterns are added to the in-scope set of the subterms they scope over;
-- * a 'CoreContext' frame describing the child's position is pushed onto
--   the context stack.
--
-- Types stored in 'TyApp', 'Cast' and 'Case' are passed through untouched.
-- Any term not matched by one of the equations below is returned as is.
allR
  :: forall m
   . Monad m
  => Transform m
  -- ^ The transformation to apply to the subtrees
  -> Transform m
allR trans (TransformContext is c) (Lam v e) =
  Lam v <$> trans (TransformContext (extendInScopeSet is v) (LamBody v:c)) e

allR trans (TransformContext is c) (TyLam tv e) =
  TyLam tv <$> trans (TransformContext (extendInScopeSet is tv) (TyLamBody tv:c)) e

allR trans (TransformContext is c) (App e1 e2) = do
  e1' <- trans (TransformContext is (AppFun:c)) e1
  -- The argument's frame is built from the *rewritten* function term, so
  -- the function must be transformed before the argument.
  e2' <- trans (TransformContext is (AppArg (primArg e1') : c)) e2
  pure (App e1' e2')

allR trans (TransformContext is c) (TyApp e ty) =
  TyApp <$> trans (TransformContext is (TyAppC:c)) e <*> pure ty

allR trans (TransformContext is c) (Cast e ty1 ty2) =
  Cast <$> trans (TransformContext is (CastBody:c)) e <*> pure ty1 <*> pure ty2

allR trans (TransformContext is c) (Letrec xes e) = do
  xes' <- traverse rewriteBind xes
  e'   <- trans (TransformContext is' (LetBody bndrs:c)) e
  return (Letrec xes' e')
 where
  bndrs = map fst xes
  -- The extended scope is used for every bound expression as well as for
  -- the body, i.e. all letrec binders are visible everywhere in the letrec.
  is'   = extendInScopeSetList is bndrs
  rewriteBind (b,rhs) =
    (b,) <$> trans (TransformContext is' (LetBinding b bndrs:c)) rhs

allR trans (TransformContext is c) (Case scrut ty alts) =
  Case <$> trans (TransformContext is (CaseScrut:c)) scrut
       <*> pure ty
       <*> traverse rewriteAlt alts
 where
  -- Variables bound by the pattern (type variables first, then term ids)
  -- are in scope only in that alternative's right-hand side.
  rewriteAlt (p,rhs) =
    let (tvs,ids) = patIds p
        isAlt     = extendInScopeSetList (extendInScopeSetList is tvs) ids
    in  (p,) <$> trans (TransformContext isAlt (CaseAlt p : c)) rhs

allR trans (TransformContext is c) (Tick sp e) =
  Tick sp <$> trans (TransformContext is (TickC sp:c)) e

allR _ _ tm = pure tm

infixr 6 >->
-- | Apply two transformations in succession.
--
-- @first >-> second@ runs @first@ on the term, then runs @second@ on its
-- result. Both see the same 'TransformContext'.
(>->) :: Monad m => Transform m -> Transform m -> Transform m
-- Defined as a lambda with no arguments on the left-hand side. GHC only
-- inlines an INLINE function once it is applied to as many arguments as its
-- LHS binds, so this form lets partially applied uses be inlined too.
(>->) = \first second ctx -> (>=>) (first ctx) (second ctx)
{-# INLINE (>->) #-}

infixr 6 >-!->
-- | Apply two transformations in succession, and perform a deepseq in between.
--
-- The intermediate term produced by the first transformation is forced to
-- normal form before the second transformation is run on it.
(>-!->) :: Monad m => Transform m -> Transform m -> Transform m
-- Lambda form (no LHS arguments) so the INLINE pragma also applies to
-- partially applied uses.
(>-!->) = \first second ctx term ->
  first ctx term >>= \term' -> term' `deepseq` second ctx term'
{-# INLINE (>-!->) #-}

{-
Note [topdown repeatR]
~~~~~~~~~~~~~~~~~~~~~~
In a topdown traversal we need to repeat the transformation r, because
if r replaces a parent node with one of its children, we must still apply
r to that child before continuing with the child's own children.

Example: topdownR (inlineBinders (\_ _ -> return True))
on:
> letrec
>   x = 1
> in letrec
>      y = 2
>    in f x y

inlineBinders would inline x and return:
> letrec
>   y = 2
> in f 1 y

Then we must repeat the transformation to let it also inline y.
-}

-- | Apply a transformation in a topdown traversal.
--
-- At each node the transformation is applied repeatedly (with 'repeatR')
-- until it reports no further change. The traversal then continues into the
-- children of the resulting node (with 'allR').
topdownR :: Rewrite m -> Rewrite m
-- See Note [topdown repeatR]
topdownR r = atNode >-> inChildren
 where
  atNode     = repeatR r
  inChildren = allR (topdownR r)

-- | Apply a transformation in a bottomup traversal.
--
-- The children of a node are processed first (recursively, via 'allR').
-- The transformation is then applied once to the rebuilt node.
bottomupR :: Monad m => Transform m -> Transform m
bottomupR r = childrenFirst >-> r
 where
  childrenFirst = allR (bottomupR r)

infixr 5 !->
-- | Only apply the second transformation if the first one succeeds.
--
-- "Succeeds" means the first transformation reported a change through the
-- 'Monoid.Any' writer output of the rewrite monad. If it did not, the result
-- of the first transformation is returned as is.
(!->) :: Rewrite m -> Rewrite m -> Rewrite m
-- Lambda form (no LHS arguments) so the INLINE pragma also applies to
-- partially applied uses.
(!->) = \first second ctx term -> do
  (term', Monoid.Any progressed) <- Writer.listen (first ctx term)
  if progressed then second ctx term' else return term'
{-# INLINE (!->) #-}

infixr 5 >-!
-- | Only apply the second transformation if the first one fails.
--
-- "Fails" means the first transformation reported no change through the
-- 'Monoid.Any' writer output of the rewrite monad. In that case the second
-- transformation is run on the first one's result. Otherwise that result is
-- returned directly.
(>-!) :: Rewrite m -> Rewrite m -> Rewrite m
-- Lambda form (no LHS arguments) so the INLINE pragma also applies to
-- partially applied uses.
(>-!) = \first second ctx term -> do
  (term', Monoid.Any progressed) <- Writer.listen (first ctx term)
  if progressed then return term' else second ctx term'
{-# INLINE (>-!) #-}

-- | Keep applying a transformation until it fails.
--
-- @repeatR r@ runs @r@. If @r@ reports a change, @repeatR r@ is run again on
-- the result (see @(!->)@). Once @r@ reports no change, its result is
-- returned.
repeatR :: Rewrite m -> Rewrite m
repeatR = \rw -> rw !-> repeatR rw
{-# INLINE repeatR #-}

-- | Conditionally apply a transformation.
--
-- @whenR f r1@ first evaluates the monadic predicate @f@ on the current
-- context and term. If it yields 'True', @r1@ is applied. Otherwise the term
-- is returned as is.
whenR :: Monad m
      => (TransformContext -> Term -> m Bool)
      -> Transform m
      -> Transform m
whenR f r1 ctx expr = f ctx expr >>= \applies ->
  if applies then r1 ctx expr else return expr

-- | Only traverse downwards when the assertion evaluates to true.
--
-- The predicate @f@ is evaluated on the current context and term:
--
-- * 'True': the children are processed first with @bottomupWhenR f r@ (via
--   'allR'), and then @r@ is applied to the rebuilt node;
-- * 'False': the children are skipped and @r@ is applied to the node directly.
bottomupWhenR
  :: Monad m
  => (TransformContext -> Term -> m Bool)
  -> Transform m
  -> Transform m
bottomupWhenR f r ctx expr = do
  descend <- f ctx expr
  let visit | descend   = allR (bottomupWhenR f r) >-> r
            | otherwise = r
  visit ctx expr