{-# LANGUAGE LambdaCase #-}
module Clash.Core.Termination
( RecInfo
, mkRecInfo
, isRecursive
, recursiveGroup
) where
import Control.Lens.Fold
import Data.Graph (SCC(..))
import qualified Data.Graph as Graph
import qualified Data.List as List
import Clash.Core.FreeVars
import Clash.Core.Var
import Clash.Core.VarEnv
import Clash.Driver.Types
-- | Recursion information for the bindings of a 'BindingMap', as produced by
-- 'mkRecInfo' from the strongly connected components of their dependency
-- graph.
data RecInfo = RecInfo
  { recBindings :: [VarSet]
    -- ^ Groups of mutually recursive bindings, one 'VarSet' per cyclic
    -- strongly connected component.
  , nonRecBindings :: VarSet
    -- ^ Bindings that are not part of any cycle.
  }
-- | Human-readable dump, intended for debugging: the members of each
-- recursive group, followed by the non-recursive bindings.
instance Show RecInfo where
  show (RecInfo rs ns) =
    "recursive groups:\n" <> show (fmap eltsVarSet rs)
      <> "\n\nnon-recursive:\n" <> show (eltsVarSet ns)
-- | Component-wise combination: the lists of recursive groups are appended
-- (left operand's groups first) and the non-recursive sets are combined
-- with 'VarSet''s own '<>' (presumably set union — confirm in VarEnv).
instance Semigroup RecInfo where
  {-# INLINE (<>) #-}
  RecInfo rX nX <> RecInfo rY nY =
    RecInfo (rX <> rY) (nX <> nY)
-- | The identity 'RecInfo': no recursive groups and the 'mempty' 'VarSet'
-- of non-recursive bindings.
instance Monoid RecInfo where
  {-# INLINE mempty #-}
  mempty = RecInfo mempty mempty

  -- Canonical definition; kept identical to the Semigroup operation.
  {-# INLINE mappend #-}
  mappend = (<>)
-- | Classify every binding in a 'BindingMap' as recursive or non-recursive.
--
-- A dependency graph is built with one node per binding, keyed on the
-- binding's own 'Id', with an edge to every free 'Id' of the binding's body.
-- Each strongly connected component of that graph then contributes:
--
--   * an acyclic component: its single binding goes into 'nonRecBindings';
--   * a cyclic component (a binding that refers to itself counts as a
--     cycle): all of its bindings form one group in 'recBindings'.
--
-- Free 'Id's that are not keys of the map do not become graph nodes, so
-- edges towards them are ignored by 'Graph.stronglyConnComp'.
mkRecInfo :: BindingMap -> RecInfo
mkRecInfo = foldMap asInfo . dependencies
 where
  -- Translate one strongly connected component into its 'RecInfo' share.
  asInfo :: SCC (Binding a) -> RecInfo
  asInfo = \case
    AcyclicSCC x -> RecInfo [] (unitVarSet (bindingId x))
    CyclicSCC xs -> RecInfo [mkVarSet (fmap bindingId xs)] emptyVarSet

  -- Strongly connected components of the binding dependency graph.
  dependencies =
    Graph.stronglyConnComp . eltsVarEnv . fmap toNode

  -- Graph node for a binding: (payload, key, keys it depends on).
  toNode x = (x, bindingId x, bindingTerm x ^.. freeIds)
-- | Whether the given global binding is recursive.
--
-- Any global 'Id' that is not recorded in 'nonRecBindings' is reported as
-- recursive. This includes globals that had no binding at all in the
-- 'BindingMap' the 'RecInfo' was built from.
--
-- Calls 'error' (once applied to a 'RecInfo') when the 'Id' is not global,
-- since only global bindings are tracked.
isRecursive :: Id -> RecInfo -> Bool
isRecursive i
  | isGlobalId i = not . elemVarSet i . nonRecBindings
  | otherwise = error ("isRecursive: " <> show i <> " is not a global Id")
-- | The group of mutually recursive bindings that contains the given 'Id'.
--
-- Returns 'Nothing' when the 'Id' is in none of the 'recBindings' groups,
-- i.e. it is non-recursive or was not part of the original 'BindingMap'.
-- This is a linear search over the groups.
recursiveGroup :: Id -> RecInfo -> Maybe VarSet
recursiveGroup i = List.find (elemVarSet i) . recBindings