{-# LANGUAGE LambdaCase #-}

module Clash.Core.Termination
  ( RecInfo
  , mkRecInfo
  , isRecursive
  , recursiveGroup
  ) where

import Control.Lens.Fold
import Data.Graph (SCC(..))
import qualified Data.Graph as Graph
import qualified Data.List as List

import Clash.Core.FreeVars
import Clash.Core.Var
import Clash.Core.VarEnv
import Clash.Driver.Types

-- | Quick lookup for whether a binding is recursive or non-recursive. If a
-- binding is non-recursive, we can assume that it terminates and skip
-- analysing it.
data RecInfo = RecInfo
  { recBindings    :: [VarSet]
    -- ^ Recursive bindings, organized into groups of strongly connected
    -- components (one 'VarSet' per group).
  , nonRecBindings :: VarSet
    -- ^ Non-recursive bindings
  }

instance Show RecInfo where
  -- Debug rendering: first the recursive groups (each as the list of its
  -- binders), then the binders of all non-recursive bindings.
  show (RecInfo groups nonRec) =
    concat
      [ "recursive groups:\n"
      , show (map eltsVarSet groups)
      , "\n\nnon-recursive:\n"
      , show (eltsVarSet nonRec)
      ]

instance Semigroup RecInfo where
  -- Combining two analyses: the recursive groups of the left operand come
  -- before those of the right, and the non-recursive sets are unioned via
  -- the 'VarSet' semigroup.
  {-# INLINE (<>) #-}
  (<>) (RecInfo groupsL nonRecL) (RecInfo groupsR nonRecR) =
    RecInfo (groupsL ++ groupsR) (nonRecL <> nonRecR)

instance Monoid RecInfo where
  -- The empty analysis: no recursive groups and an empty set of
  -- non-recursive bindings.
  {-# INLINE mempty #-}
  mempty = RecInfo [] mempty

  -- Canonical definition in terms of the 'Semigroup' instance.
  {-# INLINE mappend #-}
  mappend = (<>)

-- | Given a map of top-level bindings, identify which terms are recursive and
-- organize them into groups of mutually recursive bindings. For example,
-- calling 'mkRecInfo' on a 'BindingMap' with the definitions
--
-- > f []     = []
-- > f (x:xs) = g x : h xs
-- >
-- > g x = x + 1
-- >
-- > h []     = []
-- > h (x:xs) = x : f xs
-- >
-- > i []     = []
-- > i (x:xs) = x * 2 : i xs
--
-- would identify [f, h] and [i] as recursive groups, and g as non-recursive.
-- (f and h call each other; i only calls itself, which still makes it a
-- singleton recursive group.)
mkRecInfo :: BindingMap -> RecInfo
mkRecInfo =
  foldMap asInfo . dependencies
 where
  -- Convert a SCC to RecInfo: an acyclic component is a single
  -- non-recursive binding, while a cyclic component (which includes a
  -- binding that refers to itself) becomes one recursive group.
  asInfo :: SCC (Binding a) -> RecInfo
  asInfo = \case
    AcyclicSCC x -> RecInfo [] (unitVarSet (bindingId x))
    CyclicSCC xs -> RecInfo [mkVarSet (fmap bindingId xs)] emptyVarSet

  -- Get the SCCs of the dependency graph of free variables. Each binding is
  -- a node keyed by its binder, with an edge to every free identifier of
  -- its body; free identifiers that are not keys of the map contribute no
  -- edges.
  dependencies =
    Graph.stronglyConnComp . eltsVarEnv . fmap go
   where
    go x = (x, bindingId x, bindingTerm x ^.. freeIds)

-- | Check if a global binder is recursive. To be conservative, binders which
-- are not included in the 'RecInfo' are assumed to be recursive: only
-- membership in the non-recursive set yields 'False'.
--
-- Calls 'error' when the given identifier is not a global 'Id'.
isRecursive :: Id -> RecInfo -> Bool
isRecursive i
  | not (isGlobalId i) =
      error ("isRecursive: " <> show i <> " is not a global Id")
  | otherwise =
      \info -> not (i `elemVarSet` nonRecBindings info)

-- | Return the recursive group that a global binder belongs to. If the
-- binder is non-recursive or not included in the 'RecInfo', 'Nothing' is
-- returned. The groups are searched in order and the first one containing
-- the binder is returned.
recursiveGroup :: Id -> RecInfo -> Maybe VarSet
recursiveGroup i info =
  List.find (i `elemVarSet`) (recBindings info)