{-# LANGUAGE ScopedTypeVariables #-}
-- | Rewriting combinators and traversals
module CLaSH.Rewrite.Combinators where

import           Control.Monad               ((<=<), (>=>))
import qualified Control.Monad.Writer        as Writer
import qualified Data.Monoid                 as Monoid
import           Unbound.LocallyNameless     (Embed, Fresh, bind, embed, rec,
                                              unbind, unembed, unrec)
import           Unbound.LocallyNameless.Ops (unsafeUnbind)

import           CLaSH.Core.Term             (Pat, Term (..))
import           CLaSH.Core.Util             (patIds)
import           CLaSH.Core.Var              (Id)
import           CLaSH.Rewrite.Types

-- | Apply a transformation to every direct subterm of a term and rebuild the
-- term from the results. Terms that have no subterms are returned unchanged.
allR :: forall m . (Functor m, Monad m, Fresh m)
     => Bool -- ^ Freshen variable references in abstracted terms
     -> Transform m -- ^ The transformation to apply to the subtrees
     -> Transform m
-- Leaves: nothing to traverse into
allR _ _ _ leaf@(Var _ _)   = return leaf
allR _ _ _ leaf@(Data _)    = return leaf
allR _ _ _ leaf@(Literal _) = return leaf
allR _ _ _ leaf@(Prim _ _)  = return leaf

-- Term abstraction: open the binder, rewrite the body, close it again
allR freshen rewrite ctx (Lam abstraction) = do
  (bndr, body) <- if freshen
                    then unbind abstraction
                    else return (unsafeUnbind abstraction)
  fmap (Lam . bind bndr) (rewrite (LamBody bndr : ctx) body)

-- Type abstraction: same scheme as for term abstraction
allR freshen rewrite ctx (TyLam abstraction) = do
  (tyBndr, body) <- if freshen
                      then unbind abstraction
                      else return (unsafeUnbind abstraction)
  fmap (TyLam . bind tyBndr) (rewrite (TyLamBody tyBndr : ctx) body)

-- Application: the function position is rewritten before the argument
allR _ rewrite ctx (App fun arg) =
  rewrite (AppFun : ctx) fun >>= \fun' ->
    fmap (App fun') (rewrite (AppArg : ctx) arg)

-- Type application: only the term is traversed, the type is kept as-is
allR _ rewrite ctx (TyApp fun ty) =
  fmap (\fun' -> TyApp fun' ty) (rewrite (TyAppC : ctx) fun)

-- Recursive let: the body is rewritten first, then every bound expression
allR freshen rewrite ctx (Letrec letBind) = do
  (recBinds, body) <- if freshen
                        then unbind letBind
                        else return (unsafeUnbind letBind)
  let binds = unrec recBinds
      ids   = map fst binds
  body'  <- rewrite (LetBody ids : ctx) body
  binds' <- mapM (rewriteBind ids) binds
  return (Letrec (bind (rec binds') body'))
  where
    -- Rewrite the right-hand side of a single binding, keeping its binder
    rewriteBind :: [Id] -> (Id,Embed Term) -> m (Id,Embed Term)
    rewriteBind ids (bndr, rhs) =
      fmap (\rhs' -> (bndr, embed rhs'))
           (rewrite (LetBinding ids : ctx) (unembed rhs))

-- Case: the scrutinee is rewritten first, then the alternatives in order
allR freshen rewrite ctx (Case scrut ty alts) = do
  scrut' <- rewrite (CaseScrut : ctx) scrut
  alts'  <- if freshen
              then mapM (closeAlt <=< unbind) alts
              else mapM (closeAlt . unsafeUnbind) alts
  return (Case scrut' ty alts')
  where
    -- Rewrite the expression of an opened alternative, keeping its pattern
    rewriteAlt :: (Pat, Term) -> m (Pat, Term)
    rewriteAlt (pat, body) =
      fmap ((,) pat) (rewrite (CaseAlt (patIds pat) : ctx) body)

    -- Rewrite an opened alternative and bind the pattern again
    closeAlt opened = fmap (uncurry bind) (rewriteAlt opened)

infixr 6 >->
-- | Sequence two transformations: the result of the first is fed to the
-- second, and both run in the same context
(>->) :: (Monad m) => Transform m -> Transform m -> Transform m
(t1 >-> t2) ctx = t1 ctx >=> t2 ctx

-- | Topdown traversal: apply the transformation to a node first, and then
-- traverse into the subterms of the result
topdownR :: (Fresh m, Functor m, Monad m) => Transform m -> Transform m
topdownR r = go
  where
    go = r >-> allR True go

-- | Topdown traversal: apply the transformation to a node first, and then
-- traverse into the subterms of the result. Doesn't freshen bound variables
unsafeTopdownR :: (Fresh m, Functor m, Monad m) => Transform m -> Transform m
unsafeTopdownR r = go
  where
    go = r >-> allR False go

-- | Bottomup traversal: traverse into the subterms of a node first, and then
-- apply the transformation to the rebuilt node
bottomupR :: (Fresh m, Functor m, Monad m) => Transform m -> Transform m
bottomupR r = go
  where
    go = allR True go >-> r

-- | Bottomup traversal: traverse into the subterms of a node first, and then
-- apply the transformation to the rebuilt node. Doesn't freshen bound
-- variables
unsafeBottomupR :: (Fresh m, Functor m, Monad m) => Transform m -> Transform m
unsafeBottomupR r = go
  where
    go = allR False go >-> r

-- | Traverse bottomup; at every node where the transformation succeeds,
-- continue by applying the transformation in a topdown traversal that starts
-- at that node.
upDownR :: (Functor m,Monad m) => Rewrite m -> Rewrite m
upDownR r = bottomupR atNode
  where
    -- What is run at each node visited by the bottomup traversal
    atNode = r !-> topdownR r

-- | Traverse bottomup; at every node where the transformation succeeds,
-- continue by applying the transformation in a topdown traversal that starts
-- at that node. Doesn't freshen bound variables
unsafeUpDownR :: (Functor m,Monad m) => Rewrite m -> Rewrite m
unsafeUpDownR r = unsafeBottomupR atNode
  where
    -- What is run at each node visited by the bottomup traversal
    atNode = r !-> unsafeTopdownR r

infixr 5 !->
-- | Run the first transformation and, only if the 'Monoid.Any' output
-- collected while running it is set, run the second transformation on its
-- result. Otherwise the original, untransformed term is returned.
(!->) :: Monad m => Rewrite m -> Rewrite m -> Rewrite m
(r1 !-> r2) ctx term = R $ do
  (term', flag) <- runR (Writer.listen (r1 ctx term))
  case Monoid.getAny flag of
    True  -> runR (r2 ctx term')
    False -> return term

-- | Apply a transformation over and over, stopping at the first application
-- that does not succeed.
repeatR :: Monad m => Rewrite m -> Rewrite m
repeatR r = again
  where
    again = r !-> again