{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}

module Language.Haskell.Liquid.Termination.Structural (terminationVars) where

import Language.Haskell.Liquid.Types hiding (isDecreasing)
import Language.Haskell.Liquid.GHC.Misc (showPpr)
import Liquid.GHC.API  as GHC hiding ( showPpr
                                                      , Env
                                                      , text
                                                      )

import Text.PrettyPrint.HughesPJ hiding ((<>))

import qualified Data.HashSet as HS
import Data.HashSet (HashSet)
import qualified Data.Map.Strict as M
import Data.Map.Strict (Map)
import qualified Data.List as L

import Control.Monad (liftM, ap)
import Data.Foldable (fold)

-- | Every variable bound by a binding that fails the structural termination
-- check, including the variables bound by @let@s nested inside it.
terminationVars :: TargetInfo -> [Var]
terminationVars info = concatMap allBoundVars (failingBinds info)

-- | The bindings selected for structural checking for which 'checkBind'
-- reports at least one error.
--
-- When the 'structuralTerm' configuration flag is set, every binding of the
-- program is a candidate. Otherwise only the bindings that define a function
-- in the spec's 'gsStTerm' set are candidates (see 'findStructBinds').
failingBinds :: TargetInfo -> [CoreBind]
failingBinds info = [ bind | bind <- candidates, hasErrors (checkBind bind) ]
  where
    program :: [CoreBind]
    program = giCbs (giSrc info)

    structFuns :: HashSet Var
    structFuns = gsStTerm (gsTerm (giSpec info))

    candidates :: [CoreBind]
    candidates
      | structuralTerm info = program
      | otherwise           = findStructBinds structFuns program

-- | Run the structural termination check on one binding.
--
-- The binding is de-shadowed first. Its recursive calls are then collected
-- per function, summarised as 'StructCall's, and checked with
-- 'structDecreasing'. The errors of all functions are combined into the
-- returned 'Result'.
checkBind :: CoreBind -> Result ()
checkBind bind =
  getCallInfoBind emptyEnv (deShadowBind bind) >>= \srcCallInfo ->
    let structCallInfo = map toStructCall <$> srcCallInfo
    in fold (mapWithFun structDecreasing structCallInfo)

-- | Rename the binders of a single binding so that none shadows another, by
-- running GHC's 'deShadowBinds' on a one-element program.
deShadowBind :: CoreBind -> CoreBind
deShadowBind bind = case deShadowBinds [bind] of
  (bind' : _) -> bind'
  -- NOTE(review): a one-binding input is expected to produce at least one
  -- binding. Fail with a descriptive message rather than the anonymous
  -- exception a partial 'head' would raise.
  []          -> error "Structural.deShadowBind: deShadowBinds returned no bindings"

-- | Keep only the bindings that define at least one function from the given
-- set. A recursive group is kept when any of its binders is in the set; an
-- empty group is never kept.
findStructBinds :: HashSet Var -> CoreProgram -> [CoreBind]
findStructBinds structFuns program = filter isStructBind program
  where
    isStructBind :: CoreBind -> Bool
    isStructBind (NonRec f _) = isStructFun f
    isStructBind (Rec binds)  = any (isStructFun . fst) binds

    isStructFun :: Var -> Bool
    isStructFun f = f `HS.member` structFuns

-- | All variables bound by a binding. The binding's own binders come first,
-- followed (recursively) by the binders of every @let@ reachable from its
-- right-hand sides, as found by 'nextBinds'.
allBoundVars :: CoreBind -> [Var]
allBoundVars bind = case bind of
  NonRec v rhs -> v : descend [rhs]
  Rec pairs    -> map fst pairs ++ descend (map snd pairs)
  where
    descend :: [CoreExpr] -> [Var]
    descend rhss = concatMap allBoundVars (concatMap nextBinds rhss)

-- | The @let@ bindings met while walking an expression.
--
-- The walk continues into a @let@'s body but not into the bound right-hand
-- sides; those are reached when 'allBoundVars' processes the returned binding.
nextBinds :: CoreExpr -> [CoreBind]
nextBinds expr = case expr of
  Let b body          -> b : nextBinds body
  App fun arg         -> nextBinds fun ++ nextBinds arg
  Lam _ body          -> nextBinds body
  Case scrut _ _ alts -> nextBinds scrut ++ concat [ nextBinds rhs | Alt _ _ rhs <- alts ]
  Cast inner _        -> nextBinds inner
  Tick _ inner        -> nextBinds inner
  Var{}               -> []
  Lit{}               -> []
  Coercion{}          -> []
  Type{}              -> []

------------------------------------------------------------------------------------------

-- | A value paired with every termination error produced while computing it.
--
-- Note that this is /not/ the Either/Maybe monad: it is important that we
-- collect all errors, not just the first one.
data Result a = Result
  { resultVal    :: a            -- ^ the computed value
  , resultErrors :: [TermError]  -- ^ every error collected so far
  } deriving Show

-- | A structural-termination failure, attributed to a binder.
data TermError = TE
  { teVar   :: Var        -- ^ the function the error is reported against
  , teError :: UserError  -- ^ the error shown to the user
  } deriving Show

-- | Whether at least one error was recorded.
hasErrors :: Result a -> Bool
hasErrors result = not (null (resultErrors result))

-- | Record one more error for @fun@ with explanation @expl@. The value is
-- kept, and the new error is placed in front of the existing ones.
addError :: Var -> Doc -> Result a -> Result a
addError fun expl result =
  result { resultErrors = mkTermError fun expl : resultErrors result }

-- | Build the 'TermError' for @fun@: an 'ErrStTerm' located at @fun@'s source
-- span, naming @fun@, and carrying the explanation @expl@.
mkTermError :: Var -> Doc -> TermError
mkTermError fun expl = TE { teError = err, teVar = fun }
  where
    err  = ErrStTerm loc name expl
    loc  = getSrcSpan fun
    name = text (showPpr fun)

-- | The neutral result: the neutral value and no errors.
instance Monoid a => Monoid (Result a) where
  mempty = pure mempty

-- | Combine two results: the values with '<>' and the error lists by
-- concatenation, with the left operand's errors first.
instance Semigroup a => Semigroup (Result a) where
  (<>) (Result v1 errs1) (Result v2 errs2) =
    Result (v1 <> v2) (errs1 ++ errs2)

-- | Sequencing keeps every error. The errors of the first computation come
-- before those produced by the continuation, the same order '<>' uses.
-- (Previously the continuation's errors were placed first, which made '>>='
-- and '<*>' disagree with '<>' on error order.)
instance Monad Result where
  Result x errsBefore >>= k =
    -- Lazy pattern binding, as before: the continuation's result is only
    -- taken apart when its components are demanded.
    let Result y errsAfter = k x
    in Result y (errsBefore ++ errsAfter)

-- | 'pure' records no errors. '<*>' is the one induced by the 'Monad'
-- instance.
instance Applicative Result where
  pure x = Result { resultVal = x, resultErrors = [] }
  mf <*> mx = ap mf mx

-- | Map over the value. This is defined through the 'Monad' instance, so the
-- collected errors are left as they are.
instance Functor Result where
  fmap f m = m >>= (pure . f)

--------------------------------------------------------------------------------

-- | State threaded through the traversal that collects recursive calls.
data Env = Env
  { envCurrentFun  :: Maybe Var  -- ^ function that value lambdas currently add parameters to (see 'addParam')
  , envCurrentArgs :: [CoreArg]  -- ^ pending application arguments; the most recently added is at the head
  , envCheckedFuns :: [Fun]      -- ^ recursive functions in scope whose calls are recorded
  }

-- | A function under check, with the parameters discovered for it so far.
data Fun = Fun
  { funName   :: Var
  , funParams :: [Param]  -- ^ the most recently added parameter is at the head
  }

-- | One formal parameter of a checked function.
data Param = Param
  { paramNames    :: VarSet  -- ^ the parameter itself plus names recorded as synonyms of it
  , paramSubterms :: VarSet  -- ^ variables recorded as subterms of it
  } deriving Eq

-- | Initial environment: no current function, no pending arguments, and no
-- checked functions.
emptyEnv :: Env
emptyEnv = Env Nothing [] []

-- | A checked function with no parameters recorded yet.
mkFun :: Var -> Fun
mkFun name = Fun name []

-- | A parameter known only under its own name, with no subterms yet.
mkParam :: Var -> Param
mkParam name = Param (unitVarSet name) emptyVarSet

-- | The first checked function in scope with the given name, if any.
lookupFun :: Env -> Var -> Maybe Fun
lookupFun env name = L.find ((== name) . funName) (envCheckedFuns env)

-- | Drop all pending application arguments.
clearCurrentArgs :: Env -> Env
clearCurrentArgs env = env { envCurrentArgs = mempty }

-- | Make the given function the one that value lambdas add parameters to.
setCurrentFun :: Var -> Env -> Env
setCurrentFun fun env = env { envCurrentFun = pure fun }

-- | Stop attributing lambdas to any function.
clearCurrentFun :: Env -> Env
clearCurrentFun = \env -> env { envCurrentFun = Nothing }

-- | Push an argument onto the front of the pending-argument list.
addArg :: CoreArg -> Env -> Env
addArg arg env = env { envCurrentArgs = pending }
  where pending = arg : envCurrentArgs env

-- | Prepend @param@ to the parameter list of the current function, if there
-- is one. Every checked function bearing that name is updated. Without a
-- current function, the environment is returned unchanged.
addParam :: Var -> Env -> Env
addParam param env = maybe env withParam (envCurrentFun env)
  where
    withParam :: Var -> Env
    withParam name = env { envCheckedFuns = map (extend name) (envCheckedFuns env) }

    extend :: Var -> Fun -> Fun
    extend name fun
      | funName fun /= name = fun
      | otherwise           = fun { funParams = mkParam param : funParams fun }

-- | Record @newName'@ as standing for the same thing as @oldName@, across the
-- parameters of all checked functions.
--
-- If @oldName@ is one of a parameter's names, @newName'@ joins the names.
-- Otherwise, if @oldName@ is one of its subterms, @newName'@ joins the
-- subterms. Other parameters are untouched.
addSynonym :: Var -> Var -> Env -> Env
addSynonym oldName newName' env =
  env { envCheckedFuns = [ fun { funParams = map alias (funParams fun) }
                         | fun <- envCheckedFuns env ] }
  where
    alias :: Param -> Param
    alias param
      | inNames    = param { paramNames    = extendVarSet (paramNames param) newName' }
      | inSubterms = param { paramSubterms = extendVarSet (paramSubterms param) newName' }
      | otherwise  = param
      where
        inNames    = oldName `elemVarSet` paramNames param
        inSubterms = oldName `elemVarSet` paramSubterms param

-- | Record @subterms@ as subterms of every parameter (of every checked
-- function) that has @var@ among its names or among its subterms.
addSubterms :: Var -> [Var] -> Env -> Env
addSubterms var subterms env =
  env { envCheckedFuns = fmap updateFun (envCheckedFuns env) }
  where
    updateFun :: Fun -> Fun
    updateFun fun = fun { funParams = fmap updateParam (funParams fun) }

    updateParam :: Param -> Param
    updateParam param
      | relatesTo param = param { paramSubterms = paramSubterms param `extendVarSetList` subterms }
      | otherwise       = param

    relatesTo :: Param -> Bool
    relatesTo param = any (var `elemVarSet`) [paramNames param, paramSubterms param]

-- | Bring a new function into scope for checking, with no parameters yet. It
-- is placed before the functions already in scope.
addCheckedFun :: Var -> Env -> Env
addCheckedFun name env = env { envCheckedFuns = checked }
  where checked = mkFun name : envCheckedFuns env

isParam :: Var -> Param -> Bool
Var
var isParam :: Var -> Param -> Bool
`isParam` Param
param = Var
var Var -> VarSet -> Bool
`elemVarSet` Param -> VarSet
paramNames Param
param

-- | Whether the variable is recorded as a subterm of the parameter.
isParamSubterm :: Var -> Param -> Bool
isParamSubterm var = elemVarSet var . paramSubterms

--------------------------------------------------------------------------------

-- | Information collected per checked function, keyed by the function's 'Var'.
newtype FunInfo a = FunInfo (Map Var a)

-- | A recursive call as found in the source: the callee, and each of its
-- parameters paired with the argument passed for it.
data SrcCall = SrcCall
  { srcCallFun  :: Var                  -- ^ the function being called
  , srcCallArgs :: [(Param, CoreArg)]   -- ^ parameter / argument pairs
  }

-- | Union of the two maps. Entries for the same function are combined with
-- '<>'.
instance Semigroup a => Semigroup (FunInfo a) where
  (<>) (FunInfo lhs) (FunInfo rhs) = FunInfo $ M.unionWith (<>) lhs rhs

-- | No information about any function.
instance Semigroup a => Monoid (FunInfo a) where
  mempty = FunInfo mempty

-- | Map over the per-function information (lazily, via the 'Map' functor).
instance Functor FunInfo where
  fmap f (FunInfo byFun) = FunInfo (f <$> byFun)

-- | Fold over the per-function information, in the order of the underlying
-- 'Map'.
instance Foldable FunInfo where
  foldMap f (FunInfo byFun) = foldMap f byFun

-- | Like 'fmap', but the function also receives the 'Var' the entry belongs
-- to.
mapWithFun :: (Var -> a -> b) -> FunInfo a -> FunInfo b
mapWithFun f (FunInfo byFun) = FunInfo $ M.mapWithKey f byFun

-- | Information about a single function.
mkFunInfo :: Var -> a -> FunInfo a
mkFunInfo fun = FunInfo . M.singleton fun

-- | Build a 'SrcCall' from the callee and its parameter / argument pairs.
mkSrcCall :: Var -> [(Param, CoreArg)] -> SrcCall
mkSrcCall = SrcCall

-- | The variable an expression consists of, looking through casts and ticks.
-- Any other expression gives 'Nothing'.
toVar :: CoreExpr -> Maybe Var
toVar = \case
  Var x    -> Just x
  Cast e _ -> toVar e
  Tick _ e -> toVar e
  _        -> Nothing

-- | Pair up two lists element by element. Succeeds only when both lists have
-- the same length; otherwise returns 'Nothing'.
zipExact :: [a] -> [b] -> Maybe [(a, b)]
zipExact = go
  where
    go (x : xs) (y : ys) = fmap ((x, y) :) (go xs ys)
    go []       []       = Just []
    go _        _        = Nothing

-- | Collect information about all of the recursive calls in a function
-- definition which will be needed to check for structural termination.
--
-- While descending, the environment tracks two things:
--
-- * the function whose parameters enclosing value lambdas are adding to
--   ('envCurrentFun'), and
-- * the arguments of the application being descended through
--   ('envCurrentArgs').
--
-- Subexpressions that are not in head position of that application (the
-- argument of an 'App' and the scrutinee of a 'Case') are visited with both
-- cleared. This prevents pending arguments from being attributed to a call
-- that does not receive them.
getCallInfoExpr :: Env -> CoreExpr -> Result (FunInfo [SrcCall])
getCallInfoExpr env = \case
  -- A call to a checked function. Parameters are recorded last-first (see
  -- 'addParam') while the pending arguments are in call order, so the
  -- arguments are reversed before pairing them with the parameters.
  Var (lookupFun env -> Just fun) ->
    case zipExact (funParams fun) (reverse $ envCurrentArgs env) of
      Just args -> pure $ mkFunInfo (funName fun) [mkSrcCall (funName fun) args]
      Nothing   -> addError (funName fun) "Unsaturated call to function" mempty

  -- Type arguments are transparent. A value argument is searched in a fresh
  -- context and pushed onto the pending arguments of the head.
  App e a
    | isTypeArg a -> getCallInfoExpr env e
    | otherwise   -> getCallInfoExpr argEnv a <> getCallInfoExpr appEnv e
      where
        argEnv :: Env
        argEnv = clearCurrentFun . clearCurrentArgs $ env
        appEnv :: Env
        appEnv = clearCurrentFun . addArg a $ env

  -- A value lambda contributes a parameter to the current function, if any.
  Lam x e
    | isTyVar x -> getCallInfoExpr env e
    | otherwise -> getCallInfoExpr (addParam x env) e

  Let bind e -> getCallInfoBind env bind <> getCallInfoExpr env e

  -- Matching on a variable: in each alternative, the pattern-bound variables
  -- become subterms of that variable, and the case binder a synonym for it.
  Case (toVar -> Just var) bndr _ alts -> foldMap getCallInfoAlt alts
    where
      getCallInfoAlt :: Alt Var -> Result (FunInfo [SrcCall])
      getCallInfoAlt (Alt _ subterms body) = getCallInfoExpr (branchEnv subterms) body
      branchEnv :: [Var] -> Env
      branchEnv subterms = addSubterms var subterms . addSynonym var bndr $ env

  Case scrut _ _ alts ->
    getCallInfoExpr scrutEnv scrut <> foldMap getCallInfoAlt alts
    where
      -- The scrutinee is not in head position of any pending application;
      -- only the alternatives are.
      scrutEnv :: Env
      scrutEnv = clearCurrentFun . clearCurrentArgs $ env
      getCallInfoAlt :: Alt Var -> Result (FunInfo [SrcCall])
      getCallInfoAlt (Alt _ _ body) = getCallInfoExpr env body

  Cast e _ -> getCallInfoExpr env e
  Tick _ e -> getCallInfoExpr env e

  Var{}      -> pure mempty
  Lit{}      -> pure mempty
  Coercion{} -> pure mempty
  Type{}     -> pure mempty

-- | Collect the recursive-call information of a binding.
--
-- * A non-recursive binding introduces no checked function. Its right-hand
--   side is searched for calls to functions already in scope.
-- * A singly-recursive binding @f = e@ makes @f@ a checked function and
--   becomes the current function, so the value lambdas of @e@ add to @f@'s
--   parameters.
-- * Mutually-recursive groups are not supported. Every binder gets an error,
--   although the right-hand sides are still searched.
--
-- A right-hand side is never in head position of an enclosing application.
-- The pending application arguments are therefore cleared before descending
-- into it; otherwise calls inside it would look oversaturated.
getCallInfoBind :: Env -> CoreBind -> Result (FunInfo [SrcCall])
getCallInfoBind env = \case
  NonRec _ e   -> getCallInfoExpr (clearCurrentFun rhsEnv) e
  Rec []       -> pure mempty
  Rec [(f, e)] -> getCallInfoExpr (addCheckedFun f . setCurrentFun f $ rhsEnv) e
  Rec binds    -> foldMap failBind binds
    where
      failBind :: (Var, CoreExpr) -> Result (FunInfo [SrcCall])
      failBind (f, e) =
        addError f "Structural checking of mutually-recursive functions is not supported" $
          getCallInfoExpr (clearCurrentFun rhsEnv) e
  where
    rhsEnv :: Env
    rhsEnv = clearCurrentArgs env

--------------------------------------------------------------------------------

-- | How the argument at a given position of a recursive call relates to the
-- parameter at that position.
data StructInfo
  = Unchanged Int   -- ^ the argument is one of the parameter's names
  | Decreasing Int  -- ^ the argument is one of the parameter's subterms

-- | The argument position carried by a 'StructInfo'.
unStructInfo :: StructInfo -> Int
unStructInfo info = case info of
  Unchanged pos  -> pos
  Decreasing pos -> pos

-- | Whether the position was classified as 'Decreasing'.
isDecreasing :: StructInfo -> Bool
isDecreasing = \case
  Decreasing _ -> True
  Unchanged _  -> False

-- | A recursive call summarised by argument positions.
data StructCall = StructCall
  { structCallFun     :: Var
  , structCallArgs    :: [Int]  -- ^ positions whose argument is unchanged or decreasing
  , structCallDecArgs :: [Int]  -- ^ positions whose argument is decreasing
  }

-- | Summarise a call: every classified position, plus the subset that is
-- decreasing (both in their original order).
mkStructCall :: Var -> [StructInfo] -> StructCall
mkStructCall fun sis = StructCall fun allPositions decreasingPositions
  where
    allPositions        = map unStructInfo sis
    decreasingPositions = [ unStructInfo si | si <- sis, isDecreasing si ]

-- | This is where we check a function call. We go through the list of
-- arguments and find the indices of those which are unchanged or decreasing.
-- Note that this approach is only guaranteed to work when the arguments to
-- the function are named. For example,
--
-- > foo (x:xs) (y:ys) = foo xs (y:ys)
--
-- won't necessarily work, but
--
-- > foo (x:xs) yys@(y:ys) = foo xs yys
--
-- will. Indices count positions in 'srcCallArgs', starting at 0. Arguments
-- that are neither a name nor a subterm of their parameter are left out.
toStructCall :: SrcCall -> StructCall
toStructCall srcCall =
  mkStructCall (srcCallFun srcCall)
    [ si | Just si <- zipWith classify [0 ..] (srcCallArgs srcCall) ]
  where
    classify :: Int -> (Param, CoreArg) -> Maybe StructInfo
    classify index (param, arg) = case toVar arg of
      Just v
        | v `isParam` param        -> Just (Unchanged index)
        | v `isParamSubterm` param -> Just (Decreasing index)
      _ -> Nothing

-- | Check whether the arguments can be ordered lexicographically so that they
-- are structurally decreasing.
--
-- For such an ordering to exist, some argument positions must be unchanged
-- or decreasing in every call ('sharedArgs'). Every call in which one of
-- those positions decreases is then accounted for and removed. The shared
-- positions are dropped from the remaining calls, and we recurse.
--
-- The recursion terminates: each step either empties the call list or
-- removes at least one position from every remaining call. If no position is
-- shared, a "Non-structural recursion" error is reported for the function.
structDecreasing :: Var -> [StructCall] -> Result ()
structDecreasing _ [] = mempty
structDecreasing fun calls@(firstCall : otherCalls)
  | null sharedArgs = addError fun "Non-structural recursion" mempty
  | otherwise       =
      structDecreasing fun (map removeSharedArgs (filter noneDecreasing calls))
  where
    -- Positions present in every call. Seeding the fold with the first call
    -- keeps it total, without relying on the partial 'foldl1'.
    sharedArgs :: [Int]
    sharedArgs = L.foldl' L.intersect (structCallArgs firstCall) (map structCallArgs otherCalls)

    noneDecreasing :: StructCall -> Bool
    noneDecreasing call = null (structCallDecArgs call `L.intersect` sharedArgs)

    removeSharedArgs :: StructCall -> StructCall
    removeSharedArgs call = call { structCallArgs = structCallArgs call L.\\ sharedArgs }