module Checks.InstanceCheck (instanceCheck) where
import Control.Monad.Extra (concatMapM, whileM)
import qualified Control.Monad.State as S (State, execState, gets, modify)
import Data.List (nub, partition, sortBy)
import qualified Data.Map as Map
import qualified Data.Set.Extra as Set
import Curry.Base.Ident
import Curry.Base.Position
import Curry.Base.Pretty
import Curry.Base.SpanInfo
import Curry.Syntax hiding (impls)
import Curry.Syntax.Pretty
import Base.CurryTypes
import Base.Messages (Message, posMessage, message, internalError)
import Base.SCC (scc)
import Base.TypeExpansion
import Base.Types
import Base.TypeSubst
import Base.Utils (fst3, snd3, findMultiples)
import Env.Class
import Env.Instance
import Env.TypeConstructor
-- | Checks the explicit instance declarations, derived instances and default
-- declarations of a module. Returns the updated instance environment together
-- with all errors found.
--
-- If instances introduced by the module clash with each other or with
-- instances already recorded in the incoming environment (same class and type
-- constructor), nothing else is checked. In that case the incoming environment
-- is returned unchanged, with one error per group of clashing instances.
instanceCheck :: ModuleIdent -> TCEnv -> ClassEnv -> InstEnv -> [Decl a]
              -> (InstEnv, [Message])
instanceCheck m tcEnv clsEnv inEnv ds
  | null clashes = execINCM (checkDecls tcEnv clsEnv ds) initState
  | otherwise    = (inEnv, map (errMultipleInstances tcEnv) clashes)
  where
    clashes      = findMultiples (localSrcs ++ knownSrcs)
    -- instances introduced by this module, tagged with this module
    localSrcs    = [ InstSource i m | i <- concatMap (genInstIdents m tcEnv) ds ]
    -- instances already in the environment, tagged with their defining module
    knownSrcs    = [ InstSource i (fst3 info) | (i, info) <- Map.toList inEnv ]
    initState    = INCState m inEnv []
-- | An instance identifier paired with the module that defines it. Used to
-- detect duplicate instances and to report where each of them comes from.
data InstSource = InstSource InstIdent ModuleIdent

-- | Two sources are equal when they describe the same instance (class and
-- type constructor). The defining module is deliberately ignored so that
-- 'findMultiples' finds clashing definitions coming from different modules.
instance Eq InstSource where
  (==) (InstSource lhs _) (InstSource rhs _) = lhs == rhs
-- | The instance check monad: a state monad over 'INCState'.
type INCM = S.State INCState

-- | State threaded through the instance check.
data INCState = INCState
  { moduleIdent :: ModuleIdent -- ^ module being checked
  , instEnv :: InstEnv -- ^ instance environment, extended as instances are bound
  , errors :: [Message] -- ^ reported errors, most recently reported first
  }
-- | Runs an instance check computation from the given initial state. Returns
-- the final instance environment and the reported errors, oldest first.
-- Repeated messages are kept only once. Repetitions can occur, for instance,
-- because derived instance contexts are re-inferred until a fixpoint is
-- reached.
execINCM :: INCM a -> INCState -> (InstEnv, [Message])
execINCM incm s = (instEnv final, reverse (nub (errors final)))
  where final = S.execState incm s
-- | The module currently being checked.
getModuleIdent :: INCM ModuleIdent
getModuleIdent = S.gets moduleIdent

-- | The current instance environment.
getInstEnv :: INCM InstEnv
getInstEnv = S.gets instEnv

-- | Applies a transformation to the instance environment.
modifyInstEnv :: (InstEnv -> InstEnv) -> INCM ()
modifyInstEnv f = S.modify updateEnv
  where updateEnv st = st { instEnv = f (instEnv st) }

-- | Records an error. Errors are stored newest first.
report :: Message -> INCM ()
report msg = S.modify addError
  where addError st = st { errors = msg : errors st }

-- | The no-op action, used for declarations that need no processing.
ok :: INCM ()
ok = pure ()
-- | Checks all instance-related declarations of a module in four phases:
--
--   1. bind every explicit instance declaration in the instance environment,
--   2. infer and bind the contexts of derived instances, processing the
--      data/newtype declarations in strongly connected groups,
--   3. check the explicit instance declarations against the superclasses of
--      their class,
--   4. check the default declarations.
--
-- Binding happens before checking, so the later phases see every explicit
-- instance of the module regardless of declaration order.
checkDecls :: TCEnv -> ClassEnv -> [Decl a] -> INCM ()
checkDecls tcEnv clsEnv ds = do
  mapM_ (bindInstance tcEnv clsEnv) instDecls
  infos <- mapM (declDeriveInfo tcEnv clsEnv) derivingDecls
  mapM_ (bindDerivedInstances clsEnv) (groupDeriveInfos infos)
  mapM_ (checkInstance tcEnv clsEnv) instDecls
  mapM_ (checkDefault tcEnv clsEnv) defaultDecls
  where
    (typeDecls, otherDecls) = partition isTypeDecl ds
    derivingDecls           = filter hasDerivedInstances typeDecls
    instDecls               = filter isInstanceDecl otherDecls
    defaultDecls            = filter isDefaultDecl otherDecls
-- | Enters an explicit instance declaration into the instance environment.
-- The recorded information is:
--
--   * the current module,
--   * the expanded instance context,
--   * the methods implemented in the instance body, with their arities.
--
-- Declarations other than instance declarations are ignored.
bindInstance :: TCEnv -> ClassEnv -> Decl a -> INCM ()
bindInstance tcEnv clsEnv (InstanceDecl _ cx qcls inst ds) = do
  m <- getModuleIdent
  let PredType ps _ = expandPolyType m tcEnv clsEnv $
                        QualTypeExpr NoSpanInfo cx inst
      iid           = genInstIdent m tcEnv qcls inst
  modifyInstEnv $ bindInstInfo iid (m, ps, methods [] ds)
  where
    -- Collects (method, arity) pairs from the instance body. The arity is
    -- taken from the first equation. If a method name occurs in several
    -- function declarations, only the first one counts. The result is in
    -- reverse order of first occurrence. Anything other than a function
    -- declaration is an internal error.
    methods acc [] = acc
    methods acc (FunctionDecl _ _ f eqs : rest)
      | any ((== f') . fst) acc = methods acc rest
      | otherwise = methods ((f', eqnArity (head eqs)) : acc) rest
      where f' = unRenameIdent f
    methods _ _ = internalError "InstanceCheck.bindInstance.impls"
bindInstance _ _ _ = ok
-- | Is the declaration a data or newtype declaration with a non-empty
-- @deriving@ clause?
hasDerivedInstances :: Decl a -> Bool
hasDerivedInstances decl = case decl of
    DataDecl    _ _ _ _ clss -> derives clss
    NewtypeDecl _ _ _ _ clss -> derives clss
    _                        -> False
  where
    derives clss = not (null clss)
-- | Everything needed to infer the contexts of the instances derived for one
-- data or newtype declaration, as assembled by 'mkDeriveInfo':
--
--   * the position of the declaration,
--   * the type constructor qualified with the current module,
--   * the predicates together with the instance type,
--   * the argument types of all constructors,
--   * the derived classes by original name, ordered by 'sortClasses'.
data DeriveInfo = DeriveInfo Position QualIdent PredType [Type] [QualIdent]
-- | Extracts the 'DeriveInfo' of a data or newtype declaration.
--
-- For data types, the argument types of all constructors are collected. A
-- record field declaration with several labels contributes its type once per
-- label. Any other declaration is an internal error.
declDeriveInfo :: TCEnv -> ClassEnv -> Decl a -> INCM DeriveInfo
declDeriveInfo tcEnv clsEnv (DataDecl p tc tvs cs clss) =
  mkDeriveInfo tcEnv clsEnv p tc tvs (concatMap argTypes cs) clss
  where
    argTypes (ConstrDecl _ _ tys)      = tys
    argTypes (ConOpDecl _ lty _ rty)   = [lty, rty]
    argTypes (RecordDecl _ _ fields)   = concatMap fieldTypes fields
    fieldTypes (FieldDecl _ labels ty) = map (const ty) labels
declDeriveInfo tcEnv clsEnv (NewtypeDecl p tc tvs nc clss) =
  mkDeriveInfo tcEnv clsEnv p tc tvs [nconstrType nc] clss
declDeriveInfo _ _ _ =
  internalError "InstanceCheck.declDeriveInfo: no data or newtype declaration"
-- | Builds the 'DeriveInfo' for the local type constructor @tc@.
--
-- Inputs: the type variables @tvs@, the constructor argument types @tys@ and
-- the derived classes @clss@. The type constructor is qualified with the
-- current module. The classes are resolved to their original names and
-- ordered with 'sortClasses'.
mkDeriveInfo :: TCEnv -> ClassEnv -> SpanInfo -> Ident -> [Ident] -> [TypeExpr]
             -> [QualIdent] -> INCM DeriveInfo
mkDeriveInfo tcEnv clsEnv spi tc tvs tys clss = do
  m <- getModuleIdent
  let otc             = qualifyWith m tc
      oclss           = [ getOrigName m cls tcEnv | cls <- clss ]
      PredType ps ty  = expandConstrType m tcEnv clsEnv otc tvs tys
      (argTys, resTy) = arrowUnapply ty
  return $ DeriveInfo (spanInfo2Pos spi) otc (PredType ps resTy) argTys
                      (sortClasses clsEnv oclss)
-- | Orders classes by the number of their superclasses (as given by
-- 'allSuperClasses'), fewest first. Classes with equally many superclasses
-- keep their relative order, since 'sortBy' is stable. Assuming
-- 'allSuperClasses' is transitive, a superclass always precedes its
-- subclasses.
sortClasses :: ClassEnv -> [QualIdent] -> [QualIdent]
sortClasses clsEnv clss =
  map snd $ sortBy byDepth [ (depth cls, cls) | cls <- clss ]
  where
    depth cls               = length (allSuperClasses cls clsEnv)
    byDepth (d1, _) (d2, _) = compare d1 d2
-- | Groups derive infos into strongly connected components. Each info defines
-- its type constructor and uses the type constructors occurring in its
-- constructor argument types, so mutually dependent types end up in the same
-- group.
groupDeriveInfos :: [DeriveInfo] -> [[DeriveInfo]]
groupDeriveInfos = scc definedTc usedTcs
  where
    definedTc (DeriveInfo _ tc _ _ _)  = [tc]
    usedTcs   (DeriveInfo _ _ _ tys _) = concatMap typeConstrs tys
-- | Binds the derived instances of one group of mutually dependent types.
--
-- Every instance first receives an initial context. All contexts are then
-- re-inferred from the current instance environment and stored. This repeats
-- as long as at least one of them changes (a fixpoint iteration).
bindDerivedInstances :: ClassEnv -> [DeriveInfo] -> INCM ()
bindDerivedInstances clsEnv dis = do
  mapM_ (enterInitialPredSet clsEnv) dis
  whileM refine
  where
    refine = do
      inferred <- concatMapM (inferPredSets clsEnv) dis
      updatePredSets inferred
-- | Binds an initial instance for every class derived for a type. The
-- constructor argument types are not yet taken into account.
enterInitialPredSet :: ClassEnv -> DeriveInfo -> INCM ()
enterInitialPredSet clsEnv (DeriveInfo p tc pty _ clss) =
  mapM_ bindFor clss
  where bindFor = bindDerivedInstance clsEnv p tc pty []
-- | Infers the context of the derived instance of class @cls@ for the type
-- constructor @tc@ and enters it into the instance environment. The entry
-- also records the fixed list of methods (with arities) that deriving
-- provides for that class. Only Eq, Ord, Enum, Bounded, Read and Show are
-- handled; any other class is an internal error.
bindDerivedInstance :: ClassEnv -> Position -> QualIdent -> PredType -> [Type]
                    -> QualIdent -> INCM ()
bindDerivedInstance clsEnv p tc pty tys cls = do
  m <- getModuleIdent
  (iid, ps) <- inferPredSet clsEnv p tc pty tys cls
  modifyInstEnv $ bindInstInfo iid (m, ps, methods)
  where
    methods = case lookup cls derivable of
      Just ms -> ms
      Nothing -> internalError "InstanceCheck.bindDerivedInstance.impls"
    derivable =
      [ (qEqId,      [(eqOpId, 2)])
      , (qOrdId,     [(leqOpId, 2)])
      , (qEnumId,    [ (succId, 1), (predId, 1), (toEnumId, 1)
                     , (fromEnumId, 1), (enumFromId, 1)
                     , (enumFromThenId, 2)
                     ])
      , (qBoundedId, [(maxBoundId, 0), (minBoundId, 0)])
      , (qReadId,    [(readsPrecId, 2)])
      , (qShowId,    [(showsPrecId, 2)])
      ]
-- | Re-infers the contexts of all instances derived for one type, this time
-- taking the constructor argument types into account.
inferPredSets :: ClassEnv -> DeriveInfo -> INCM [(InstIdent, PredSet)]
inferPredSets clsEnv (DeriveInfo p tc pty tys clss) =
  mapM inferFor clss
  where inferFor = inferPredSet clsEnv p tc pty tys
-- | Infers the context of the derived instance of class @cls@ for the type
-- constructor @tc@. The candidate context consists of:
--
--   * the predicates already attached to the instance type,
--   * @cls@ applied to each argument type in @tys@,
--   * every class returned by 'superClasses' for @cls@, applied to the
--     instance type.
--
-- The candidates are reduced with the current instance environment. After
-- reduction, every predicate that does not constrain a type variable is
-- reported as a missing instance. Returns the instance identifier and the
-- reduced context.
inferPredSet :: ClassEnv -> Position -> QualIdent -> PredType -> [Type]
             -> QualIdent -> INCM (InstIdent, PredSet)
inferPredSet clsEnv p tc (PredType ps inst) tys cls = do
  m <- getModuleIdent
  let what       = "derived instance"
      doc        = ppPred m $ Pred cls inst
      argPreds   = Set.fromList [ Pred cls ty | ty <- tys ]
      superPreds = Set.fromList
                     [ Pred scls inst | scls <- superClasses cls clsEnv ]
      candidates = ps `Set.union` argPreds `Set.union` superPreds
  reduced <- reducePredSet p what doc clsEnv candidates
  mapM_ (reportUndecidable p what doc) (Set.toList reduced)
  return ((cls, tc), reduced)
-- | Stores all freshly inferred contexts. Returns whether at least one of them
-- differs from the previously stored context. Every update is performed,
-- even after a change has already been detected.
updatePredSets :: [(InstIdent, PredSet)] -> INCM Bool
updatePredSets ips = or <$> mapM (uncurry updatePredSet) ips
-- | Replaces the context of an instance that must already be bound; an
-- unbound instance is an internal error. Returns 'True' iff the new context
-- differs from the stored one. Only in that case is the environment modified.
updatePredSet :: InstIdent -> PredSet -> INCM Bool
updatePredSet i ps = do
  inEnv <- getInstEnv
  maybe (internalError "InstanceCheck.updatePredSet") refresh
        (lookupInstInfo i inEnv)
  where
    refresh (m, old, is)
      | ps == old = return False
      | otherwise = do
          modifyInstEnv $ bindInstInfo i (m, ps, is)
          return True
-- | Accepts a predicate of an inferred derived-instance context only when it
-- constrains a plain type variable. Any other predicate is reported as a
-- missing instance.
reportUndecidable :: Position -> String -> Doc -> Pred -> INCM ()
reportUndecidable p what doc predicate@(Pred _ ty)
  | isTyVar ty = return ()
  | otherwise  = do
      m <- getModuleIdent
      report $ errMissingInstance m p what doc predicate
  where
    isTyVar (TypeVariable _) = True
    isTyVar _                = False
-- | Checks an explicit instance declaration against the superclasses of its
-- class.
--
-- The superclass predicates at the instance type are reduced with the
-- instance environment. Whatever remains is checked against 'maxPredSet' of
-- the declared context, and anything not covered there is reported as a
-- missing instance. Other declarations are ignored.
checkInstance :: TCEnv -> ClassEnv -> Decl a -> INCM ()
checkInstance tcEnv clsEnv (InstanceDecl spi cx cls inst _) = do
  m <- getModuleIdent
  let pos            = spanInfo2Pos spi
      what           = "instance declaration"
      PredType ps ty = expandPolyType m tcEnv clsEnv $
                         QualTypeExpr NoSpanInfo cx inst
      ocls           = getOrigName m cls tcEnv
      doc            = ppPred m $ Pred cls ty
      superPreds     = Set.fromList
                         [ Pred scls ty | scls <- superClasses ocls clsEnv ]
  remaining <- reducePredSet pos what doc clsEnv superPreds
  Set.mapM_ (report . errMissingInstance m pos what doc) $
    remaining `Set.difference` maxPredSet clsEnv ps
checkInstance _ _ _ = ok
-- | Checks every type listed in a @default@ declaration. Other declarations
-- are ignored.
checkDefault :: TCEnv -> ClassEnv -> Decl a -> INCM ()
checkDefault tcEnv clsEnv decl = case decl of
  DefaultDecl spi tys ->
    mapM_ (checkDefaultType (spanInfo2Pos spi) tcEnv clsEnv) tys
  _ -> ok
-- | Checks that a type listed in a @default@ declaration is an instance of
-- class Num. Both unsatisfiable predicates and predicates left over after
-- reduction are reported as missing instances.
checkDefaultType :: Position -> TCEnv -> ClassEnv -> TypeExpr -> INCM ()
checkDefaultType p tcEnv clsEnv ty = do
  m <- getModuleIdent
  let PredType _ ty' = expandPolyType m tcEnv clsEnv $
                         QualTypeExpr NoSpanInfo [] ty
      numPred        = Set.singleton (Pred qNumId ty')
  leftover <- reducePredSet p what empty clsEnv numPred
  Set.mapM_ (report . errMissingInstance m p what empty) leftover
  where what = "default declaration"
-- | Reduces a predicate set with the current instance environment.
--
-- Every predicate for which 'instPredSet' finds an instance is replaced,
-- recursively, by that instance's context. The result is minimised with
-- 'minPredSet' and split with 'partitionPredSet'. Predicates in the second
-- part are reported as missing instances; the first part is returned.
reducePredSet :: Position -> String -> Doc -> ClassEnv -> PredSet
              -> INCM PredSet
reducePredSet p what doc clsEnv ps = do
  m <- getModuleIdent
  inEnv <- getInstEnv
  let (keep, missing) = partitionPredSet $ minPredSet clsEnv $
                          Set.concatMap (reduceOne inEnv) ps
  Set.mapM_ (report . errMissingInstance m p what doc) missing
  return keep
  where
    reduceOne inEnv pr = case instPredSet inEnv pr of
      Nothing  -> Set.singleton pr
      Just ctx -> Set.concatMap (reduceOne inEnv) ctx
-- | Looks up the instance matching a predicate whose type is headed by a type
-- constructor. Returns that instance's context, passed through
-- 'expandAliasType' with the type's arguments. Returns 'Nothing' when the
-- type is not headed by a constructor or when no instance is bound.
instPredSet :: InstEnv -> Pred -> Maybe PredSet
instPredSet inEnv (Pred qcls ty)
  | (TypeConstructor tc, tys) <- unapplyType False ty =
      expandAliasType tys . snd3 <$> lookupInstInfo (qcls, tc) inEnv
  | otherwise = Nothing
-- | The identifiers of the instances a declaration introduces:
--
--   * one per derived class for data and newtype declarations,
--   * one for an instance declaration,
--   * none for any other declaration.
genInstIdents :: ModuleIdent -> TCEnv -> Decl a -> [InstIdent]
genInstIdents m tcEnv decl = case decl of
    DataDecl     _ tc _ _ qclss -> derived tc qclss
    NewtypeDecl  _ tc _ _ qclss -> derived tc qclss
    InstanceDecl _ _ qcls ty _  -> [genInstIdent m tcEnv qcls ty]
    _                           -> []
  where
    derived tc qclss =
      [ genInstIdent m tcEnv qcls (ConstructorType NoSpanInfo (qualify tc))
      | qcls <- qclss ]
-- | The instance identifier for class @qcls@ and the type constructor heading
-- @ty@. Both components are resolved to their original names.
genInstIdent :: ModuleIdent -> TCEnv -> QualIdent -> TypeExpr -> InstIdent
genInstIdent m tcEnv qcls ty = qualInstIdent m tcEnv (qcls, typeConstr ty)
-- | Resolves both components of an instance identifier to their original
-- names via 'getOrigName'.
qualInstIdent :: ModuleIdent -> TCEnv -> InstIdent -> InstIdent
qualInstIdent m tcEnv (cls, tc) = (origName cls, origName tc)
  where origName qid = getOrigName m qid tcEnv
-- | Maps both components of an instance identifier, given by original names,
-- back to the first name under which 'reverseLookupByOrigName' finds them in
-- the current scope. Used for error messages.
--
-- If an entity is not visible under any name, its original name is kept.
-- This can happen, for example, for a class or type known only through an
-- imported instance. Previously, @head []@ crashed the compiler while it was
-- building the error message.
unqualInstIdent :: TCEnv -> InstIdent -> InstIdent
unqualInstIdent tcEnv (qcls, tc) = (unqual qcls, unqual tc)
  where
    unqual qid = case reverseLookupByOrigName qid tcEnv of
      qid' : _ -> qid'
      []       -> qid
-- | Error for a group of instances with the same class and type constructor.
-- Each instance is listed together with its defining module.
errMultipleInstances :: TCEnv -> [InstSource] -> Message
errMultipleInstances tcEnv iss = message $
  header $+$ nest 2 (vcat (map describe iss))
  where
    header = text "Multiple instances for the same class and type"
    describe (InstSource i m) =
      ppInstIdent (unqualInstIdent tcEnv i) <+>
      parens (text "defined in" <+> ppMIdent m)
-- | Error for a predicate that no known instance satisfies. @what@ names the
-- kind of declaration being checked, and @doc@ shows the declaration's
-- instance head (it may be empty).
errMissingInstance :: ModuleIdent -> Position -> String -> Doc -> Pred
                   -> Message
errMissingInstance m p what doc predicate =
  posMessage p $ vcat [missingLine, contextLine]
  where
    missingLine = text "Missing instance for" <+> ppPred m predicate
    contextLine = text "in" <+> text what <+> doc