{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016     , Myrtle Software Ltd,
                    2017     , Google Inc.,
                    2021-2022, QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>

  Type and instance definitions for Rewrite modules
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Rewrite.Types where

import Control.DeepSeq                       (NFData)
import Control.Lens                          (Lens', use, (.=))
import qualified Control.Lens as Lens
import Control.Monad.Fix                     (MonadFix)
import Control.Monad.State.Strict            (State)
#if MIN_VERSION_transformers(0,5,6)
import Control.Monad.Reader                  (MonadReader (..))
import Control.Monad.State                   (MonadState (..))
import Control.Monad.Trans.RWS.CPS           (RWST)
import qualified Control.Monad.Trans.RWS.CPS as RWS
import Control.Monad.Writer                  (MonadWriter (..))
#else
import Control.Monad.Trans.RWS.Strict        (RWST)
import qualified Control.Monad.Trans.RWS.Strict as RWS
#endif
import Data.Binary                           (Binary)
import Data.HashMap.Strict                   (HashMap)
import Data.IntMap.Strict                    (IntMap)
import Data.Monoid                           (Any)
import Data.Text                             (Text)
import GHC.Generics

import Clash.Core.PartialEval as PE          (Evaluator)
import Clash.Core.Evaluator.Types as WHNF    (Evaluator, PrimHeap)

import Clash.Core.Term           (Term, Context)
import Clash.Core.Type           (Type)
import Clash.Core.TyCon          (TyConMap, TyConName)
import Clash.Core.Var            (Id)
import Clash.Core.VarEnv         (InScopeSet, VarSet, VarEnv)
import Clash.Driver.Types        (ClashEnv(..), ClashOpts(..), BindingMap, DebugOpts)
import Clash.Netlist.Types       (FilteredHWType, HWMap)
import Clash.Primitives.Types    (CompiledPrimMap)
import Clash.Rewrite.WorkFree    (isWorkFree)
import Clash.Util
import Clash.Util.Supply         (Supply, freshId)

import Clash.Annotations.BitRepresentation.Internal (CustomReprs)

-- | State used by the inspection mechanism for recording rewrite steps.
--
-- A value records a single application of a transformation: the context it
-- was applied in, the names of the transformation and the binder, and the
-- term before and after the application.
data RewriteStep
  = RewriteStep
  { t_ctx    :: Context
  -- ^ current context
  , t_name   :: String
  -- ^ Name of the transformation
  , t_bndrS  :: String
  -- ^ Name of the current binder
  , t_before :: Term
  -- ^ Term before `apply`
  , t_after  :: Term
  -- ^ Term after `apply`
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData, Binary)

-- | State of a rewriting session
--
-- The type parameter @extra@ is the type of the additional, caller-supplied
-- state stored in '_extra'. Lenses for all fields (the field names without
-- the leading underscore) are generated by the 'Lens.makeLenses' splice
-- directly below the declaration.
data RewriteState extra
  = RewriteState
    -- TODO Given we now keep transformCounters, this should just be 'fold'
    -- over that map, otherwise the two counts could fall out of sync.
  { _transformCounter :: {-# UNPACK #-} !Word
  -- ^ Total number of applied transformations
  , _transformCounters :: HashMap Text Word
  -- ^ Map that tracks how many times each transformation is applied
  , _bindings         :: !BindingMap
  -- ^ Global binders
  , _uniqSupply       :: !Supply
  -- ^ Supply of unique numbers
  , _curFun           :: (Id,SrcSpan) -- Initially set to undefined: no strictness annotation
  -- ^ Function which is currently normalized
  , _nameCounter      :: {-# UNPACK #-} !Int
  -- ^ Used for 'Fresh'
  , _globalHeap       :: PrimHeap
  -- ^ Used as a heap for compile-time evaluation of primitives that live in I/O
  , _workFreeBinders  :: VarEnv Bool
  -- ^ Map telling whether a binder's definition is work-free
  , _extra            :: !extra
  -- ^ Additional state
  }

Lens.makeLenses ''RewriteState

-- | Read-only environment of a rewriting session
--
-- Lenses for all fields (the field names without the leading underscore) are
-- generated by the 'Lens.makeLenses' splice directly below the declaration.
data RewriteEnv
  = RewriteEnv
  { _clashEnv       :: ClashEnv
  -- ^ The global environment of the compiler
  , _typeTranslator :: CustomReprs
                    -> TyConMap
                    -> Type
                    -> State HWMap (Maybe (Either String FilteredHWType))
  -- ^ Hardcode Type -> FilteredHWType translator
  , _peEvaluator    :: PE.Evaluator
  -- ^ Hardcoded evaluator for partial evaluation
  , _evaluator      :: WHNF.Evaluator
  -- ^ Hardcoded evaluator for WHNF (old evaluator)
  , _topEntities    :: VarSet
  -- ^ Functions that are considered TopEntities
  }

Lens.makeLenses ''RewriteEnv

-- | Getter for the 'DebugOpts' of the session: the 'opt_debug' field of the
-- 'ClashOpts' ('envOpts') held by the 'ClashEnv' of the 'RewriteEnv'.
debugOpts :: Lens.Getter RewriteEnv DebugOpts
debugOpts = clashEnv . Lens.to (opt_debug . envOpts)

-- | Getter for the 'opt_aggressiveXOpt' flag of the 'ClashOpts' ('envOpts')
-- held by the 'ClashEnv' of the 'RewriteEnv'.
aggressiveXOpt :: Lens.Getter RewriteEnv Bool
aggressiveXOpt = clashEnv . Lens.to (opt_aggressiveXOpt . envOpts)

-- | Getter for the 'TyConMap' ('envTyConMap') held by the 'ClashEnv' of the
-- 'RewriteEnv'.
tcCache :: Lens.Getter RewriteEnv TyConMap
tcCache = clashEnv . Lens.to envTyConMap

-- | Getter for the map of tuple type constructor names ('envTupleTyCons')
-- held by the 'ClashEnv' of the 'RewriteEnv'.
tupleTcCache :: Lens.Getter RewriteEnv (IntMap TyConName)
tupleTcCache = clashEnv . Lens.to envTupleTyCons

-- | Getter for the 'CustomReprs' ('envCustomReprs') held by the 'ClashEnv' of
-- the 'RewriteEnv'.
customReprs :: Lens.Getter RewriteEnv CustomReprs
customReprs = clashEnv . Lens.to envCustomReprs

-- | Getter for the 'opt_evaluatorFuelLimit' setting of the 'ClashOpts'
-- ('envOpts') held by the 'ClashEnv' of the 'RewriteEnv'.
fuelLimit :: Lens.Getter RewriteEnv Word
fuelLimit = clashEnv . Lens.to (opt_evaluatorFuelLimit . envOpts)

-- | Getter for the 'CompiledPrimMap' ('envPrimitives') held by the 'ClashEnv'
-- of the 'RewriteEnv'.
primitives :: Lens.Getter RewriteEnv CompiledPrimMap
primitives = clashEnv . Lens.to envPrimitives

-- | Getter for the 'opt_inlineLimit' setting of the 'ClashOpts' ('envOpts')
-- held by the 'ClashEnv' of the 'RewriteEnv'.
inlineLimit :: Lens.Getter RewriteEnv Int
inlineLimit = clashEnv . Lens.to (opt_inlineLimit . envOpts)

-- | Getter for the 'opt_inlineFunctionLimit' setting of the 'ClashOpts'
-- ('envOpts') held by the 'ClashEnv' of the 'RewriteEnv'.
inlineFunctionLimit :: Lens.Getter RewriteEnv Word
inlineFunctionLimit = clashEnv . Lens.to (opt_inlineFunctionLimit . envOpts)

-- | Getter for the 'opt_inlineConstantLimit' setting of the 'ClashOpts'
-- ('envOpts') held by the 'ClashEnv' of the 'RewriteEnv'.
inlineConstantLimit :: Lens.Getter RewriteEnv Word
inlineConstantLimit = clashEnv . Lens.to (opt_inlineConstantLimit . envOpts)

-- | Getter for the 'opt_inlineWFCacheLimit' setting of the 'ClashOpts'
-- ('envOpts') held by the 'ClashEnv' of the 'RewriteEnv'.
inlineWFCacheLimit :: Lens.Getter RewriteEnv Word
inlineWFCacheLimit = clashEnv . Lens.to (opt_inlineWFCacheLimit . envOpts)

-- | Getter for the 'opt_newInlineStrat' flag of the 'ClashOpts' ('envOpts')
-- held by the 'ClashEnv' of the 'RewriteEnv'.
newInlineStrategy :: Lens.Getter RewriteEnv Bool
newInlineStrategy = clashEnv . Lens.to (opt_newInlineStrat . envOpts)

-- | Getter for the 'opt_specLimit' setting of the 'ClashOpts' ('envOpts')
-- held by the 'ClashEnv' of the 'RewriteEnv'.
specializationLimit :: Lens.Getter RewriteEnv Int
specializationLimit = clashEnv . Lens.to (opt_specLimit . envOpts)

-- | Getter for the 'opt_ultra' flag of the 'ClashOpts' ('envOpts') held by
-- the 'ClashEnv' of the 'RewriteEnv'.
normalizeUltra :: Lens.Getter RewriteEnv Bool
normalizeUltra = clashEnv . Lens.to (opt_ultra . envOpts)

-- | Monad that keeps track how many transformations have been applied and can
-- generate fresh variables and unique identifiers. In addition, it keeps track
-- if a transformation/rewrite has been successfully applied.
--
-- It is a newtype around an 'RWST' over 'IO' with a 'RewriteEnv' reader
-- environment, an 'Any' writer output and a @'RewriteState' extra@ state.
--
-- The mtl classes ('MonadState', 'MonadWriter', 'MonadReader') are only
-- derived here when both version checks in the CPP guard below hold; when
-- only the transformers check holds, hand-written instances further down in
-- this module are used instead.
newtype RewriteMonad extra a = R
  { unR :: RWST RewriteEnv Any (RewriteState extra) IO a }
  deriving newtype
    ( Applicative
    , Functor
    , Monad
    , MonadFix
    )
#if MIN_VERSION_transformers(0,5,6) && MIN_VERSION_mtl(2,3,0)
  deriving newtype
    ( MonadState (RewriteState extra)
    , MonadWriter Any
    , MonadReader RewriteEnv
    )
#endif

-- | Run the computation in the RewriteMonad
--
-- Unwraps the newtype and runs the underlying 'RWST' with the given
-- environment and initial state. The result is the triple of the computed
-- value, the final state and the accumulated 'Any' writer output.
runR
  :: RewriteMonad extra a
  -> RewriteEnv
  -> RewriteState extra
  -> IO (a, RewriteState extra, Any)
runR m = RWS.runRWST (unR m)

#if MIN_VERSION_transformers(0,5,6) && !MIN_VERSION_mtl(2,3,0)
-- For Control.Monad.Trans.RWS.Strict these are already defined, however the
-- CPS version of RWS is not included in `mtl` yet.

-- | State access for 'RewriteMonad': every method wraps the corresponding
-- function of the underlying 'RWST' in the 'R' constructor.
instance MonadState (RewriteState extra) (RewriteMonad extra) where
  get = R RWS.get
  {-# INLINE get #-}
  put = R . RWS.put
  {-# INLINE put #-}
  state = R . RWS.state
  {-# INLINE state #-}

-- | Writer access for 'RewriteMonad': every method delegates to the
-- corresponding function of the underlying 'RWST'. 'listen' and 'pass' take a
-- 'RewriteMonad' action as argument, so they unwrap it with 'unR' first and
-- re-wrap the result with 'R'.
instance MonadWriter Any (RewriteMonad extra) where
  writer = R . RWS.writer
  {-# INLINE writer #-}
  tell = R . RWS.tell
  {-# INLINE tell #-}
  listen = R . RWS.listen . unR
  {-# INLINE listen #-}
  pass = R . RWS.pass . unR
  {-# INLINE pass #-}

-- | Reader access for 'RewriteMonad': every method delegates to the
-- corresponding function of the underlying 'RWST'. 'local' takes a
-- 'RewriteMonad' action as argument, so it unwraps it with 'unR' first and
-- re-wraps the result with 'R'.
instance MonadReader RewriteEnv (RewriteMonad extra) where
   ask = R RWS.ask
   {-# INLINE ask #-}
   local f = R . RWS.local f . unR
   {-# INLINE local #-}
   reader = R . RWS.reader
   {-# INLINE reader #-}
#endif

-- | Unique generation for 'RewriteMonad', backed by the 'uniqSupply' field of
-- the 'RewriteState'.
instance MonadUnique (RewriteMonad extra) where
  getUniqueM = do
    -- Draw a fresh number from the current supply and store the remaining
    -- supply back in the state.
    sup <- use uniqSupply
    let (a,sup') = freshId sup
    uniqSupply .= sup'
    -- Force the unique to WHNF before handing it to the caller.
    a `seq` return a

-- | Run a 'RewriteMonad' action with the function @f@ applied to its 'Any'
-- writer output. Unwraps the action, delegates to 'RWS.censor' on the
-- underlying 'RWST', and wraps the result again.
censor :: (Any -> Any) -> RewriteMonad extra a -> RewriteMonad extra a
censor f = R . RWS.censor f . unR
{-# INLINE censor #-}

-- | Context that is passed to a 'Transform' alongside the term to transform.
data TransformContext
  = TransformContext
  { tfInScope :: !InScopeSet
  -- ^ The 'InScopeSet' for the transformation (strict field)
  , tfContext :: Context
  -- ^ The 'Context' for the transformation
  -- NOTE(review): presumably the position of the term within its enclosing
  -- term -- confirm against "Clash.Core.Term"
  }

-- | Monadic action that transforms a term given a certain context: it takes
-- the 'TransformContext' and the 'Term' to transform, and yields the
-- resulting 'Term' in the monad @m@.
type Transform m = TransformContext -> Term -> m Term

-- | A 'Transform' action in the context of the 'RewriteMonad', i.e.
-- 'Transform' with its monad fixed to @'RewriteMonad' extra@.
type Rewrite extra = Transform (RewriteMonad extra)

-- The definition of 'isWorkFree' was moved into Clash.Rewrite.WorkFree (it is
-- imported from there at the top of this module). The pragma below requests a
-- specialisation of it at 'RewriteMonad', which is declared in this module.
{-# SPECIALIZE isWorkFree
      :: Lens' (RewriteState extra) (VarEnv Bool)
      -> BindingMap
      -> Term
      -> RewriteMonad extra Bool
  #-}