{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}

module Language.Haskell.Liquid.Termination.Structural (terminationVars) where

import Language.Haskell.Liquid.Types hiding (isDecreasing)
import Language.Haskell.Liquid.GHC.Misc (showPpr)

import CoreSyn
import Var
import Name (getSrcSpan)
import VarSet
import CoreSubst (deShadowBinds)

import Text.PrettyPrint.HughesPJ hiding ((<>))

import qualified Data.HashSet as HS
import Data.HashSet (HashSet)
import qualified Data.Map.Strict as M
import Data.Map.Strict (Map)
import qualified Data.List as L

import Control.Monad (liftM, ap)
import Data.Foldable (fold)

-- | Every variable bound by a binding that fails the structural termination
-- check: the binding's own binder(s) plus the binders of all @let@s nested
-- inside it.
terminationVars :: TargetInfo -> [Var]
terminationVars info = concatMap allBoundVars (failingBinds info)

-- | The bindings selected for structural checking on which the check reports
-- at least one error.
--
-- If 'structuralTerm' is set for the target, every binding of the program is
-- checked.  Otherwise only the bindings that define a function in the spec's
-- 'gsStTerm' set are considered (see 'findStructBinds').
failingBinds :: TargetInfo -> [CoreBind]
failingBinds info = [ b | b <- candidates, hasErrors (checkBind b) ]
  where
    candidates :: [CoreBind]
    candidates
      | structuralTerm info = coreProgram
      | otherwise           = findStructBinds requested coreProgram

    coreProgram :: [CoreBind]
    coreProgram = giCbs (giSrc info)

    requested :: HashSet Var
    requested = gsStTerm (gsTerm (giSpec info))

-- | Run the structural termination check on a single binding.
--
-- The binding is de-shadowed first.  Then every recursive call is collected
-- per function, each call is summarised by argument index, and each
-- function's calls are checked for a lexicographically decreasing argument.
-- All errors found along the way are kept in the result.
checkBind :: CoreBind -> Result ()
checkBind bind =
  getCallInfoBind emptyEnv (deShadowBind bind) >>= \srcCallInfo ->
    let structCallInfo = fmap (map toStructCall) srcCallInfo
    in fold (mapWithFun structDecreasing structCallInfo)

-- | De-shadow a single binding, so that no binder inside it shadows another.
--
-- It is wrapped as a one-binding program for 'deShadowBinds'.  That function
-- presumably returns one binding per input binding; NOTE(review): this comes
-- from GHC's CoreSubst, so confirm it on GHC upgrades.  The explicit match
-- replaces the partial 'head', so a violated expectation fails with a
-- descriptive message instead of "Prelude.head: empty list".
deShadowBind :: CoreBind -> CoreBind
deShadowBind bind = case deShadowBinds [bind] of
  (bind' : _) -> bind'
  []          -> error "Structural.deShadowBind: deShadowBinds returned no bindings"

-- | Keep the bindings that define at least one function from the given set.
-- A recursive group is kept if any one of its binders is in the set.
findStructBinds :: HashSet Var -> CoreProgram -> [CoreBind]
findStructBinds structFuns program = filter definesStructFun program
  where
    definesStructFun :: CoreBind -> Bool
    definesStructFun bind = any (`HS.member` structFuns) (bindingNames bind)

    bindingNames :: CoreBind -> [Var]
    bindingNames (NonRec f _) = [f]
    bindingNames (Rec pairs)  = map fst pairs

-- | Every variable bound by a binding.  The binding's own binder(s) come
-- first, followed by the binders of all @let@-bindings nested (at any depth)
-- in its right-hand side(s).
allBoundVars :: CoreBind -> [Var]
allBoundVars bind = case bind of
  NonRec v rhs -> v : nestedVars rhs
  Rec pairs    -> map fst pairs ++ concatMap (nestedVars . snd) pairs
  where
    nestedVars :: CoreExpr -> [Var]
    nestedVars rhs = concatMap allBoundVars (nextBinds rhs)

-- | The @let@-bindings reachable in an expression without entering the
-- right-hand side of another @let@-binding.  The right-hand sides of the
-- returned bindings are not searched here.
nextBinds :: CoreExpr -> [CoreBind]
nextBinds expr = case expr of
  App fun arg         -> nextBinds fun ++ nextBinds arg
  Lam _ body          -> nextBinds body
  Let bind body       -> bind : nextBinds body
  Case scrut _ _ alts -> nextBinds scrut ++ concatMap altBinds alts
  Cast inner _        -> nextBinds inner
  Tick _ inner        -> nextBinds inner
  Var{}               -> []
  Lit{}               -> []
  Coercion{}          -> []
  Type{}              -> []
  where
    altBinds (_, _, rhs) = nextBinds rhs

------------------------------------------------------------------------------------------

-- | A value together with every error reported while computing it.
--
-- This is deliberately *not* the Either/Maybe monad: it does not stop at the
-- first failure, because it is important that we collect all errors.
data Result a = Result
  { resultVal    :: a            -- ^ the computed value
  , resultErrors :: [TermError]  -- ^ every error reported so far
  }
  deriving Show

-- | A structural-termination error, attributed to the function it concerns.
data TermError = TE
  { teVar   :: Var        -- ^ the function the error is about
  , teError :: UserError  -- ^ the error as reported to the user
  }
  deriving Show

-- | Whether a result carries at least one error.
hasErrors :: Result a -> Bool
hasErrors r = case resultErrors r of
  [] -> False
  _  -> True

-- | Prepend an error to a result's error list.  The error is attributed to
-- the given function and explained by the given document; the value is left
-- unchanged.
addError :: Var -> Doc -> Result a -> Result a
addError fun expl res =
  res { resultErrors = mkTermError fun expl : resultErrors res }

-- | Build a structural-termination error for a function.  The error is
-- located at the function's source span, names the function, and carries the
-- given explanation.
mkTermError :: Var -> Doc -> TermError
mkTermError fun expl = TE fun userErr
  where
    userErr = ErrStTerm loc funDoc expl
    loc     = getSrcSpan fun
    funDoc  = text (showPpr fun)

-- | The empty result: the empty value and no errors.
instance Monoid a => Monoid (Result a) where
  mempty = Result { resultVal = mempty, resultErrors = [] }

-- | Combine the values with '<>' and concatenate the errors, with the left
-- operand's errors first.
instance Semigroup a => Semigroup (Result a) where
  (<>) (Result lhsVal lhsErrs) (Result rhsVal rhsErrs) =
    Result (lhsVal <> rhsVal) (lhsErrs ++ rhsErrs)

-- | Sequencing passes the value on and keeps the errors of both steps, with
-- the first step's errors before the second's.  This is the order the
-- 'Semigroup' instance uses, so errors appear in program order however
-- results are combined.  The monad laws hold because '++' is associative and
-- '[]' is its identity.
instance Monad Result where
  Result x e1 >>= f =
    let Result y e2 = f x
    in Result y (e1 ++ e2)

-- | 'pure' reports no errors.  '<*>' is derived from the 'Monad' instance,
-- so it accumulates errors exactly as '>>=' does.
instance Applicative Result where
  pure v = Result { resultVal = v, resultErrors = [] }
  mf <*> mx = do
    g <- mf
    x <- mx
    pure (g x)

-- | Map over the value and keep the errors (defined via the 'Monad' instance).
instance Functor Result where
  fmap g m = m >>= pure . g

--------------------------------------------------------------------------------

-- | Context threaded through the traversal that collects recursive calls.
data Env = Env
  { envCurrentFun  :: Maybe Var  -- ^ function whose value lambdas are being recorded as its parameters, if any
  , envCurrentArgs :: [CoreArg]  -- ^ value arguments applied to the expression under analysis, in application order
  , envCheckedFuns :: [Fun]      -- ^ the recursive functions being checked
  }

-- | A function under check, together with the parameters collected for it.
data Fun = Fun
  { funName   :: Var
  , funParams :: [Param]  -- ^ most recently added parameter first
  }

-- | A function parameter: the names it is known by (itself plus any synonyms
-- introduced by @case@ binders) and the variables known to be its subterms.
data Param = Param
  { paramNames    :: VarSet
  , paramSubterms :: VarSet
  }
  deriving (Eq)

-- | An environment with no current function, no pending arguments and no
-- functions under check.
emptyEnv :: Env
emptyEnv = Env Nothing [] []

-- | A checked function with the given name and no parameters yet.
mkFun :: Var -> Fun
mkFun name = Fun name []

-- | A parameter known only by its own name, with no known subterms.
mkParam :: Var -> Param
mkParam name = Param (unitVarSet name) emptyVarSet

-- | The first checked function with the given name, if any.
lookupFun :: Env -> Var -> Maybe Fun
lookupFun env name =
  case filter ((== name) . funName) (envCheckedFuns env) of
    (fun : _) -> Just fun
    []        -> Nothing

-- | Forget all pending arguments.
clearCurrentArgs :: Env -> Env
clearCurrentArgs (Env fun _ funs) = Env fun [] funs

-- | Make the given function the one whose lambdas are recorded as parameters.
setCurrentFun :: Var -> Env -> Env
setCurrentFun fun (Env _ args funs) = Env (Just fun) args funs

-- | Stop recording lambdas as parameters of any function.
clearCurrentFun :: Env -> Env
clearCurrentFun (Env _ args funs) = Env Nothing args funs

-- | Push an argument onto the front of the pending-argument list.
addArg :: CoreArg -> Env -> Env
addArg arg (Env fun args funs) = Env fun (arg : args) funs

-- | Record a parameter for the current function, if there is one, by
-- prepending it to that function's parameter list.  With no current function
-- the environment is returned unchanged.
addParam :: Var -> Env -> Env
addParam param env = maybe env extend (envCurrentFun env)
  where
    extend name =
      env { envCheckedFuns = map (prependParam name) (envCheckedFuns env) }

    prependParam name fun
      | funName fun /= name = fun
      | otherwise           = fun { funParams = mkParam param : funParams fun }

-- | Make @newName@ an alias of @oldName@ in every checked function.
-- Where @oldName@ is a name of a parameter, @newName@ becomes another name of
-- that parameter.  Otherwise, where @oldName@ is a subterm of a parameter,
-- @newName@ becomes a subterm too.
addSynonym :: Var -> Var -> Env -> Env
addSynonym oldName newName env =
  env { envCheckedFuns = map renameIn (envCheckedFuns env) }
  where
    renameIn fun = fun { funParams = map extendParam (funParams fun) }

    extendParam p
      | oldName `elemVarSet` paramNames p =
          p { paramNames = extendVarSet (paramNames p) newName }
      | oldName `elemVarSet` paramSubterms p =
          p { paramSubterms = extendVarSet (paramSubterms p) newName }
      | otherwise = p

-- | Add @subterms@ to the subterms of every parameter that @var@ names or is
-- already a subterm of, in every checked function.
addSubterms :: Var -> [Var] -> Env -> Env
addSubterms var subterms env =
  env { envCheckedFuns =
          [ fun { funParams = map grow (funParams fun) } | fun <- envCheckedFuns env ] }
  where
    grow p
      | relatedTo p = p { paramSubterms = paramSubterms p `extendVarSetList` subterms }
      | otherwise   = p

    relatedTo p = var `elemVarSet` paramNames p || var `elemVarSet` paramSubterms p

-- | Start checking a new function, with no parameters yet.  It is placed in
-- front of the previously checked functions.
addCheckedFun :: Var -> Env -> Env
addCheckedFun name (Env fun args funs) = Env fun args (mkFun name : funs)

-- | Whether a variable is one of the names of a parameter.
isParam :: Var -> Param -> Bool
isParam var = elemVarSet var . paramNames

-- | Whether a variable is known to be a subterm of a parameter.
isParamSubterm :: Var -> Param -> Bool
isParamSubterm var = elemVarSet var . paramSubterms

--------------------------------------------------------------------------------

-- | Information collected per function, keyed by the function's 'Var'.
newtype FunInfo a = FunInfo (Map Var a)

-- | A recursive call as found in the source: the function called, and each
-- of its parameters (in 'funParams' order) paired with the argument passed
-- for it.
data SrcCall = SrcCall
  { srcCallFun  :: Var
  , srcCallArgs :: [(Param, CoreArg)]
  }

-- | Union of the per-function maps.  When both sides have an entry for the
-- same function, the two entries are combined with '<>'.
instance Semigroup a => Semigroup (FunInfo a) where
  (<>) (FunInfo lhs) (FunInfo rhs) = FunInfo (M.unionWith (<>) lhs rhs)

-- | No information about any function.
instance Semigroup a => Monoid (FunInfo a) where
  mempty = FunInfo mempty

-- | Transform every function's entry, keeping the keys.
instance Functor FunInfo where
  fmap g (FunInfo entries) = FunInfo (M.map g entries)

-- | Fold over the per-function entries, ignoring which function each entry
-- belongs to.
instance Foldable FunInfo where
  foldMap g (FunInfo entries) = M.foldMapWithKey (\_ v -> g v) entries

-- | Like 'fmap', but the transformation also receives the function the entry
-- belongs to.
mapWithFun :: (Var -> a -> b) -> FunInfo a -> FunInfo b
mapWithFun f = \(FunInfo entries) -> FunInfo $ M.mapWithKey f entries

-- | Information about exactly one function.
mkFunInfo :: Var -> a -> FunInfo a
mkFunInfo fun x = FunInfo (M.fromList [(fun, x)])

-- | A source call to @fun@ with the given parameter/argument pairs.
mkSrcCall :: Var -> [(Param, CoreArg)] -> SrcCall
mkSrcCall = SrcCall

-- | The variable an expression refers to, looking through casts and ticks.
-- Any other expression yields 'Nothing'.
toVar :: CoreExpr -> Maybe Var
toVar expr = case expr of
  Var x    -> Just x
  Cast e _ -> toVar e
  Tick _ e -> toVar e
  _        -> Nothing

-- | Pair up two lists element by element, but only if they have the same
-- length; otherwise 'Nothing'.
zipExact :: [a] -> [b] -> Maybe [(a, b)]
zipExact xs ys = case (xs, ys) of
  ([], [])           -> Just []
  (x : xs', y : ys') -> fmap ((x, y) :) (zipExact xs' ys')
  _                  -> Nothing

-- Collect information about all of the recursive calls in a function
-- definition which will be needed to check for structural termination.
--
-- The environment holds three things:
--   * the function whose value lambdas are being recorded as parameters, if any;
--   * the value arguments applied to the expression under analysis, in
--     application order;
--   * the functions being checked.
--
-- Pending arguments apply to the *value* of the expression, so they are
-- dropped wherever they cannot apply:
--
--   * A @case@ scrutinee is analysed without them: they apply to what the
--     alternatives return, not to the scrutinee.  Previously
--     @(case f x of ...) y@ made @f@ appear to receive @[x, y]@.
--   * A value lambda consumes the argument it binds.  In a beta-redex such as
--     @(\z -> f z) a@ the call is therefore seen as @f z@, not @f z a@.
getCallInfoExpr :: Env -> CoreExpr -> Result (FunInfo [SrcCall])
getCallInfoExpr env = \case
  Var (lookupFun env -> Just fun) ->
    -- 'funParams' lists parameters last-to-first (each lambda prepends one).
    -- The pending arguments are in application order, so reversing them
    -- lines the two up.  A call is recorded only when every parameter gets
    -- exactly one argument.
    case zipExact (funParams fun) (reverse $ envCurrentArgs env) of
      Just args -> pure $ mkFunInfo (funName fun) [mkSrcCall (funName fun) args]
      Nothing -> addError (funName fun) "Unsaturated call to function" mempty

  App e a
    -- Type arguments are ignored (type lambdas are skipped below as well).
    | isTypeArg a -> getCallInfoExpr env e
    | otherwise -> getCallInfoExpr argEnv a <> getCallInfoExpr appEnv e
      where
        -- The argument is analysed on its own: nothing is applied to it.
        argEnv = clearCurrentFun . clearCurrentArgs $ env
        -- The function position additionally receives @a@.
        appEnv = clearCurrentFun . addArg a $ env

  Lam x e
    | isTyVar x -> getCallInfoExpr env e
    -- A value lambda is a parameter of the current function (if any).  It
    -- binds the innermost pending argument, which therefore no longer
    -- applies to the body.
    | otherwise -> getCallInfoExpr (consumeArg . addParam x $ env) e
      where
        consumeArg env' = env' { envCurrentArgs = drop 1 (envCurrentArgs env') }

  Let bind e -> getCallInfoBind env bind <> getCallInfoExpr env e

  -- Casing on a variable.  In each alternative, the case binder is a synonym
  -- of the scrutinee and the pattern variables are its subterms.
  Case (toVar -> Just var) bndr _ alts -> foldMap getCallInfoAlt alts
    where
      getCallInfoAlt (_, subterms, body) = getCallInfoExpr (branchEnv subterms) body
      branchEnv subterms = addSubterms var subterms . addSynonym var bndr $ env

  -- Pending arguments apply to the alternatives' result, not to the
  -- scrutinee, so the scrutinee is analysed in a cleared environment.
  Case scrut _ _ alts -> getCallInfoExpr scrutEnv scrut <> foldMap getCallInfoAlt alts
    where
      scrutEnv = clearCurrentFun . clearCurrentArgs $ env
      getCallInfoAlt (_, _, body) = getCallInfoExpr env body

  Cast e _ -> getCallInfoExpr env e
  Tick _ e -> getCallInfoExpr env e

  Var{} -> pure mempty
  Lit{} -> pure mempty
  Coercion{} -> pure mempty
  Type{} -> pure mempty

-- | Collect the recursive calls of a binding.
--
--   * A non-recursive binding introduces no function to check.  Its
--     right-hand side is only searched for calls to functions already being
--     checked.
--   * A self-recursive binding @f = e@ makes @f@ a checked function.  Its
--     parameters are the value lambdas met while analysing @e@.
--   * Mutually-recursive groups are not supported: every binder gets an
--     error, but the right-hand sides are still searched for calls.
--
-- Arguments applied to an enclosing expression never apply to a binding's
-- right-hand side: in @(let ... in e) a@, @a@ is applied to @e@.  So pending
-- arguments are dropped before the binding is analysed.  Previously they
-- leaked in, and e.g. @(let z = f t in g z) y@ made @f@ see @[t, y]@.
getCallInfoBind :: Env -> CoreBind -> Result (FunInfo [SrcCall])
getCallInfoBind env0 = \case
  NonRec _ e -> getCallInfoExpr (clearCurrentFun env) e
  Rec [] -> pure mempty
  Rec [(f, e)] -> getCallInfoExpr (addCheckedFun f . setCurrentFun f $ env) e
  Rec binds -> foldMap failBind binds
    where
      failBind (f, e) =
        addError f "Structural checking of mutually-recursive functions is not supported" $
          getCallInfoExpr (clearCurrentFun env) e
  where
    env = clearCurrentArgs env0

--------------------------------------------------------------------------------

-- | How one argument of a recursive call relates to the parameter it is
-- passed for.  'Unchanged' means it is one of that parameter's names;
-- 'Decreasing' means it is a known subterm of it.  The 'Int' is the
-- argument's index in the call.
data StructInfo = Unchanged Int | Decreasing Int

-- | The argument index carried by a 'StructInfo'.
unStructInfo :: StructInfo -> Int
unStructInfo si = case si of
  Unchanged p  -> p
  Decreasing p -> p

-- | Whether the argument is a subterm of its parameter.
isDecreasing :: StructInfo -> Bool
isDecreasing si = case si of
  Decreasing _ -> True
  Unchanged _  -> False

-- | A recursive call summarised by argument index.
data StructCall = StructCall
  { structCallFun     :: Var
  , structCallArgs    :: [Int]  -- ^ indices of arguments that are unchanged or decreasing
  , structCallDecArgs :: [Int]  -- ^ indices of arguments that are decreasing
  }

-- | Summarise a call from its per-argument classification.  Every classified
-- index goes into 'structCallArgs'; the decreasing ones also go into
-- 'structCallDecArgs'.
mkStructCall :: Var -> [StructInfo] -> StructCall
mkStructCall fun sis = StructCall fun allArgs decArgs
  where
    allArgs = map unStructInfo sis
    decArgs = [ unStructInfo si | si <- sis, isDecreasing si ]

-- This is where we check a function call.  We go through the list of
-- arguments and classify each one by index: unchanged when it is a name of
-- its parameter, decreasing when it is a subterm of it, and dropped
-- otherwise.  Note that this approach is only guaranteed to work when the
-- arguments to the function are named.  For example,
--   foo (x:xs) (y:ys) = foo xs (y:ys)
-- won't necessarily work, but
--   foo (x:xs) yys@(y:ys) = foo xs yys
-- will.
toStructCall :: SrcCall -> StructCall
toStructCall srcCall =
  mkStructCall (srcCallFun srcCall) (concat (zipWith classify [0 ..] (srcCallArgs srcCall)))
  where
    classify :: Int -> (Param, CoreArg) -> [StructInfo]
    classify index (param, arg) = case toVar arg of
      Just v
        | v `isParam` param        -> [Unchanged index]
        | v `isParamSubterm` param -> [Decreasing index]
      _ -> []

-- Check whether the arguments can be ordered lexicographically so that every
-- call is structurally decreasing.  For that we must find at least one
-- argument position that is, in every call, either unchanged or decreasing.
-- Calls in which such a position decreases are then settled.  The remaining
-- calls are checked again with those positions removed.  No shared position
-- means non-structural recursion.
structDecreasing :: Var -> [StructCall] -> Result ()
structDecreasing _ [] = mempty
structDecreasing funName calls@(firstCall : otherCalls)
  | null shared = addError funName "Non-structural recursion" mempty
  | otherwise   = structDecreasing funName [ dropShared c | c <- calls, unsettled c ]
  where
    -- Argument positions that are unchanged or decreasing in every call.
    shared :: [Int]
    shared = foldl L.intersect (structCallArgs firstCall) (map structCallArgs otherCalls)

    -- A call is still unsettled if none of the shared positions decreases in it.
    unsettled :: StructCall -> Bool
    unsettled c = null (structCallDecArgs c `L.intersect` shared)

    dropShared :: StructCall -> StructCall
    dropShared c = c { structCallArgs = structCallArgs c L.\\ shared }