-- | Check pattern match exhaustiveness, since we don't really handle that in
-- source code.
--
-- This is pretty easy because of how patterns work in Kempe.
--
-- Some of this code is from Dickinson, but we don't need the Maranget approach
-- because the pattern matching is simpler in Kempe.
module Kempe.Check.Pattern ( checkModuleExhaustive
                           ) where

import           Control.Monad              (forM_)
import           Control.Monad.State.Strict (State, execState)
import           Data.Coerce                (coerce)
import           Data.Foldable              (toList, traverse_)
import           Data.Foldable.Ext
import qualified Data.IntMap.Strict         as IM
import qualified Data.IntSet                as IS
import           Data.List.NonEmpty         (NonEmpty (..))
import           Kempe.AST
import           Kempe.Error
import           Kempe.Name
import           Kempe.Unique
import           Lens.Micro                 (Lens')
import           Lens.Micro.Mtl             (modifying)

-- | Check a single atom for inexhaustive @case@ expressions.
--
-- The alternatives of a @case@ are checked first. If they are exhaustive,
-- the body of every arm is checked as well, so that a @case@ nested inside
-- another @case@ arm is not silently accepted. Returns 'Nothing' when no
-- problem is found.
--
-- NOTE(review): only 'Case' is descended into. Any other atom forms that
-- carry nested atom lists (not visible in this module) are still not
-- inspected; confirm against "Kempe.AST".
checkAtom :: PatternEnv -> Atom c b -> Maybe (Error b)
checkAtom env (Case l ls)
    | isExhaustive env (fmap fst ls) =
        -- the outer match is fine; look for problems inside each arm's body
        foldMapAlternative (foldMapAlternative (checkAtom env) . snd) ls
    | otherwise = Just (InexhaustiveMatch l)
checkAtom _ _ = Nothing

-- | Check one declaration. Only function declarations are inspected
-- (every atom of the body is checked); every other declaration is accepted
-- as-is.
checkDecl :: PatternEnv -> KempeDecl a c b -> Maybe (Error b)
checkDecl env decl = case decl of
    FunDecl _ _ _ _ body -> foldMapAlternative (checkAtom env) body
    _                    -> Nothing

-- | Check every declaration of a module against a prebuilt 'PatternEnv'.
checkModule :: PatternEnv -> Declarations a c b -> Maybe (Error b)
checkModule env = foldMapAlternative checkOne
    where checkOne = checkDecl env

-- | Entry point: build the constructor environment from the module's
-- declarations, then check the same declarations for inexhaustive
-- @case@ expressions.
checkModuleExhaustive :: Declarations a c b -> Maybe (Error b)
checkModuleExhaustive m = checkModule env m
    where
        env :: PatternEnv
        env = runPatternM (patternEnvDecls m)

-- | Lookup tables for the exhaustiveness checker. Keys and elements are the
-- integers inside each name's 'Unique'.
data PatternEnv = PatternEnv
    { allCons :: IM.IntMap IS.IntSet -- ^ all constructors, indexed by type
    , types   :: IM.IntMap Int       -- ^ owning type, indexed by constructor
    }

-- | Van Laarhoven lens focusing on the 'allCons' field of a 'PatternEnv'.
allConsLens :: Lens' PatternEnv (IM.IntMap IS.IntSet)
allConsLens f env = (\cs -> env { allCons = cs }) <$> f (allCons env)

-- | Van Laarhoven lens focusing on the 'types' field of a 'PatternEnv'.
typesLens :: Lens' PatternEnv (IM.IntMap Int)
typesLens f env = (\ts -> env { types = ts }) <$> f (types env)

-- | Strict state monad used while building the 'PatternEnv' from declarations.
type PatternM = State PatternEnv

-- | Feed every declaration of a module to 'declAdd', in order.
patternEnvDecls :: Declarations a c b -> PatternM ()
patternEnvDecls decls = forM_ decls declAdd

-- | Add one declaration to the environment. Only type declarations
-- contribute: each constructor is mapped to its type in 'types', and the
-- type is mapped to the set of all its constructors in 'allCons'.
declAdd :: KempeDecl a c b -> PatternM ()
declAdd (TyDecl _ (Name _ (Unique tyIx) _) _ ctors) = do
    let ctorIxs = unUnique . unique . fst <$> ctors
    traverse_ (\j -> modifying typesLens (IM.insert j tyIx)) ctorIxs
    modifying allConsLens (IM.insert tyIx (IS.fromList ctorIxs))
declAdd FunDecl{}   = pure ()
declAdd ExtFnDecl{} = pure ()
declAdd Export{}    = pure ()

-- | Run an environment-building action starting from an empty environment
-- and return the final environment; the action's own result is discarded.
runPatternM :: PatternM a -> PatternEnv
runPatternM act = execState act emptyEnv
    where emptyEnv = PatternEnv IM.empty IM.empty

-- | Bottom value used by 'assocUniques' when a lookup in the 'PatternEnv'
-- fails.
internalError :: a
internalError = error msg
    where msg = "Internal error: lookup in a PatternEnv failed"

-- | Given a constructor name, return the set of all constructors (as
-- 'Unique' integers) that belong to the same type. Calls 'internalError'
-- if either lookup fails.
assocUniques :: PatternEnv -> Name a -> IS.IntSet
assocUniques env (Name _ (Unique conIx) _) =
    case IM.lookup conIx (types env) of
        Nothing   -> internalError
        Just tyIx -> IM.findWithDefault internalError tyIx (allCons env)

-- | Is any of the given patterns a wildcard? Stops at the first one found.
hasWildcard :: Foldable t => t (Pattern c b) -> Bool
hasWildcard = foldr step False
    where
        step :: Pattern c b -> Bool -> Bool
        step PatternWildcard{} _ = True
        step _ rest              = rest

-- | Decide whether a non-empty list of @case@ alternatives covers every
-- value of the scrutinee's type.
--
-- Only works on well-typed input: all non-wildcard patterns are assumed to
-- be of the same kind as the first one.
--
-- * A wildcard in any position makes the match exhaustive.
-- * Integer literals are never exhaustive on their own.
-- * Boolean literals are exhaustive once both @True@ and @False@ appear,
--   in any order and with any duplicates.
-- * Constructor patterns are exhaustive when every constructor of the type
--   appears (see 'isCompleteSet').
isExhaustive :: PatternEnv -> NonEmpty (Pattern c b) -> Bool
isExhaustive env ps
    | hasWildcard ps = True
    | otherwise = case ps of
        PatternWildcard{} :| _ -> True -- unreachable: handled by 'hasWildcard'
        PatternInt{} :| _ -> False
        PatternBool{} :| _ ->
            let seen = [ v | PatternBool _ v <- toList ps ]
            in True `elem` seen && False `elem` seen
        p@PatternCons{} :| rest ->
            -- only take 'patternName' of constructor patterns; the selector
            -- has no field for other pattern forms
            isCompleteSet env
                (patternName p :| [ patternName q | q@PatternCons{} <- rest ])

-- | Do the given constructor names mention every constructor of their type?
-- The type is determined from the first name.
isCompleteSet :: PatternEnv -> NonEmpty (TyName a) -> Bool
isCompleteSet env names@(firstName :| _) =
    IS.null (everyCon `IS.difference` present)
    where
        everyCon = assocUniques env firstName
        present  = IS.fromList [ unUnique (unique n) | n <- toList names ]