{-# LANGUAGE CPP #-}
module Transformations.CaseCompletion (completeCase) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>), (<*>))
#endif
import qualified Control.Monad.State as S (State, evalState, gets, modify)
import Data.List (find)
import Data.Maybe (fromMaybe, listToMaybe)
import Curry.Base.Ident
import qualified Curry.Syntax as CS
import Base.CurryTypes (toType)
import Base.Expr
import Base.Messages (internalError)
import Base.Types ( boolType, charType, floatType
, intType, listType
)
import Base.Subst
import Env.Interface (InterfaceEnv, lookupInterface)
import Transformations.CurryToIL (transType)
import Transformations.Dictionary (qImplMethodId)
import IL
-- | Complete the rigid case expressions of an IL module.
--
-- Flexible cases are kept as they are.  Rigid cases over constructors get
-- alternatives for missing constructors (taken from a default alternative),
-- rigid cases over literals become chains of equality tests, and rigid
-- cases with only a variable alternative become bindings.  The interface
-- environment is used to find the constructors of imported data types.
completeCase :: InterfaceEnv -> Module -> Module
completeCase iEnv mdl@(Module mid is ds) = Module mid is completedDecls
  where
    completedDecls = S.evalState (mapM ccDecl ds) initialState
    initialState   = CCState { modul = mdl, interfaceEnv = iEnv, nextId = 0 }
-- | State threaded through the case completion.
data CCState = CCState
  { modul        :: Module        -- ^ the module being transformed
  , interfaceEnv :: InterfaceEnv  -- ^ interfaces used to look up the
                                  --   constructors of imported types
  , nextId       :: Int           -- ^ counter for fresh identifiers
  }
-- | The case completion monad: a state monad over 'CCState'.
type CCM = S.State CCState
-- | Fetch the module that is being transformed.
getModule :: CCM Module
getModule = S.gets (\CCState { modul = m } -> m)
-- | Fetch the interface environment used for imported data types.
getInterfaceEnv :: CCM InterfaceEnv
getInterfaceEnv = S.gets (\CCState { interfaceEnv = env } -> env)
-- | Produce the identifier @_#compN@ for the current counter value @N@
-- and advance the counter.
freshIdent :: CCM Ident
freshIdent = do
  counter <- S.gets nextId
  S.modify (bump counter)
  return (mkIdent (prefix ++ show counter))
  where
    prefix = "_#comp"
    bump n st = st { nextId = succ n }
-- | Complete the cases inside a declaration.  Only function bodies contain
-- expressions; every other declaration is returned unchanged.
ccDecl :: Decl -> CCM Decl
ccDecl decl = case decl of
  FunctionDecl qid vs ty body -> FunctionDecl qid vs ty <$> ccExpr body
  DataDecl _ _ _              -> return decl
  ExternalDataDecl _ _        -> return decl
  ExternalDecl _ _            -> return decl
-- | Traverse an expression and complete every case in it.  A case is
-- completed after its scrutinee and its alternatives have been processed.
ccExpr :: Expression -> CCM Expression
ccExpr expr = case expr of
  Literal _ _        -> return expr
  Variable _ _       -> return expr
  Function _ _ _     -> return expr
  Constructor _ _ _  -> return expr
  Apply fun arg      -> Apply <$> ccExpr fun <*> ccExpr arg
  Case ev scrut alts -> do
    scrut' <- ccExpr scrut
    alts'  <- mapM ccAlt alts
    ccCase ev scrut' alts'
  Or lhs rhs         -> Or <$> ccExpr lhs <*> ccExpr rhs
  Exist v ty body    -> Exist v ty <$> ccExpr body
  Let bind body      -> Let <$> ccBinding bind <*> ccExpr body
  Letrec binds body  -> Letrec <$> mapM ccBinding binds <*> ccExpr body
  Typed body ty      -> (\body' -> Typed body' ty) <$> ccExpr body
-- | Complete the cases in the right-hand side of an alternative.
ccAlt :: Alt -> CCM Alt
ccAlt (Alt pat rhs) = do
  rhs' <- ccExpr rhs
  return (Alt pat rhs')
-- | Complete the cases in the right-hand side of a binding.
ccBinding :: Binding -> CCM Binding
ccBinding (Binding var rhs) = do
  rhs' <- ccExpr rhs
  return (Binding var rhs')
-- | Complete a single case expression.  Flexible cases are kept; for a
-- rigid case, the kind of the first pattern decides which completion
-- strategy is applied to the whole alternative list.
ccCase :: Eval -> Expression -> [Alt] -> CCM Expression
ccCase ev scrut alts = case (ev, alts) of
  (Flex, _)              -> return (Case Flex scrut alts)
  (Rigid, [])            -> internalError $ "CaseCompletion.ccCase: "
                              ++ "empty alternative list"
  (Rigid, Alt pat _ : _) -> dispatch pat
  where
    dispatch (ConstructorPattern _ _ _) = completeConsAlts Rigid scrut alts
    dispatch (LiteralPattern _ _)       = completeLitAlts Rigid scrut alts
    dispatch (VariablePattern _ _)      = completeVarAlts scrut alts
-- | Complete a case whose alternatives match on data constructors.
--
-- The constructors of the scrutinee type that have no explicit alternative
-- are computed with @getComplConstrs@.  If there are such constructors and
-- the case has a variable (default) alternative, one alternative per
-- missing constructor is added, each using the default right-hand side.
-- Otherwise the result is a case over the constructor alternatives only
-- (any variable alternative is dropped).
completeConsAlts :: Eval -> Expression -> [Alt] -> CCM Expression
completeConsAlts ea ce alts = do
  mdl <- getModule
  menv <- getInterfaceEnv
  -- one freshly named pattern per constructor without an explicit alternative
  complPats <- mapM genPat $ getComplConstrs mdl menv
    [ c | (Alt (ConstructorPattern _ c _) _) <- consAlts ]
  -- v: name for the scrutinee inside the default right-hand side
  -- w: name under which the default right-hand side is shared
  v <- freshIdent
  w <- freshIdent
  return $ case (complPats, defaultAlt v) of
    (_:_, Just e') -> bindDefVar v ce w e' complPats
    _ -> Case ea ce consAlts
  where
    -- alternatives with a constructor pattern, in their original order
    consAlts = [ a | a@(Alt (ConstructorPattern _ _ _) _) <- alts ]
    -- the data type with its parameters replaced by type variables 0..n-1
    dataTy = let TypeConstructor qid tys = patTy
             in TypeConstructor qid $ map TypeVariable [0 .. length tys - 1]
    -- scrutinee type, read off the first constructor pattern
    patTy = let Alt pat _ = head consAlts in typeOf pat
    -- maps the type variables of dataTy to the arguments of patTy
    tySubst = matchType dataTy patTy idSubst
    -- pattern for a missing constructor, with fresh argument variables whose
    -- types are instantiated by tySubst
    genPat (qid, tys) = ConstructorPattern patTy qid <$>
      mapM (\ty' -> freshIdent >>= \v -> return (ty', v)) (subst tySubst tys)
    -- right-hand side of the first variable alternative, with its pattern
    -- variable replaced by v
    defaultAlt v = listToMaybe [ replaceVar x (Variable ty v) e
                               | Alt (VariablePattern ty x) e <- alts ]
    -- bind the scrutinee to v only when the default right-hand side uses v
    bindDefVar v e w e' ps
      | v `elem` fv e' = mkBinding v e $ mkCase (Variable (typeOf e) v) w e' ps
      | otherwise = mkCase e w e' ps
    -- a single missing constructor receives the default right-hand side
    -- directly; several missing constructors share it through the binding w
    mkCase e w e' ps = case ps of
      [p] -> Case ea e (consAlts ++ [Alt p e'])
      _ -> mkBinding w e'
        $ Case ea e (consAlts ++ [Alt p (Variable (typeOf e') w) | p <- ps])
-- | Turn a rigid case over literal patterns into nested Boolean cases.
--
-- The scrutinee is bound to a fresh variable.  Each literal alternative
-- becomes an equality test whose False branch continues with the remaining
-- alternatives.  A variable alternative ends the chain, with its pattern
-- variable replaced by the scrutinee variable.  If the alternatives run out,
-- the chain ends in a call to @failed@ at the type of the alternatives.
completeLitAlts :: Eval -> Expression -> [Alt] -> CCM Expression
completeLitAlts ea ce alts = do
  scrutVar <- freshIdent
  let fallback = failedExpr (typeOf (head alts))
      chain    = foldr (testAlt scrutVar) fallback alts
  return (mkBinding scrutVar ce chain)
  where
    testAlt x (Alt pat rhs) rest = case pat of
      LiteralPattern ty l  ->
        Case ea (Variable ty x `eqExpr` Literal ty l)
          [ Alt truePatt rhs
          , Alt falsePatt rest
          ]
      VariablePattern ty v -> replaceVar v (Variable ty x) rhs
      _ -> internalError "CaseCompletion.completeLitAlts: illegal alternative"
-- | A rigid case whose first alternative is a variable pattern always takes
-- that alternative: bind the scrutinee to the pattern variable.
completeVarAlts :: Expression -> [Alt] -> CCM Expression
completeVarAlts ce alts = case alts of
  [] -> internalError $
    "CaseCompletion.completeVarAlts: empty alternative list"
  Alt (VariablePattern _ x) rhs : _ -> return (mkBinding x ce rhs)
  _ : _ -> internalError $
    "CaseCompletion.completeVarAlts: variable pattern expected"
-- | @mkBinding v e body@ makes @v@ stand for @e@ in @body@.  A variable is
-- substituted directly; any other expression is bound with a @Let@.
mkBinding :: Ident -> Expression -> Expression -> Expression
mkBinding v e body
  | isVariable e = replaceVar v e body
  | otherwise    = Let (Binding v e) body
  where
    isVariable (Variable _ _) = True
    isVariable _              = False
-- | @replaceVar v e x@ replaces the occurrences of the variable @v@ in @x@
-- by @e@.  The traversal stops below binders that rebind @v@ (Exist, Let,
-- Letrec, and case patterns via 'replaceVarInAlt').  Type annotations are
-- traversed as well, so variables under a 'Typed' node are also replaced.
--
-- NOTE(review): if @Let@ is non-recursive, occurrences of @v@ inside the
-- bound expression of a @Let@ that rebinds @v@ still refer to the outer
-- @v@ but are left untouched here; confirm this is intended.
replaceVar :: Ident -> Expression -> Expression -> Expression
replaceVar v e x@(Variable _ w)
  | v == w    = e
  | otherwise = x
replaceVar v e (Apply e1 e2)
  = Apply (replaceVar v e e1) (replaceVar v e e2)
replaceVar v e (Case ev e' bs)
  = Case ev (replaceVar v e e') (map (replaceVarInAlt v e) bs)
replaceVar v e (Or e1 e2)
  = Or (replaceVar v e e1) (replaceVar v e e2)
replaceVar v e (Exist w ty e')
  | v == w    = Exist w ty e'
  | otherwise = Exist w ty (replaceVar v e e')
replaceVar v e (Let b e')
  | v `occursInBinding` b = Let b e'
  | otherwise             = Let (replaceVarInBinding v e b)
                                (replaceVar v e e')
replaceVar v e (Letrec bs e')
  | any (occursInBinding v) bs = Letrec bs e'
  | otherwise                  = Letrec (map (replaceVarInBinding v e) bs)
                                        (replaceVar v e e')
-- previously this fell through to the catch-all and left variables under a
-- type annotation unreplaced
replaceVar v e (Typed e' ty)
  = Typed (replaceVar v e e') ty
-- literals, functions and constructors contain no variables
replaceVar _ _ e' = e'
-- | Substitute inside an alternative unless its pattern rebinds the variable.
replaceVarInAlt :: Ident -> Expression -> Alt -> Alt
replaceVarInAlt v e (Alt p rhs) = Alt p rhs'
  where
    rhs' | v `occursInPattern` p = rhs
         | otherwise             = replaceVar v e rhs
-- | Substitute inside the right-hand side of a binding unless the binding
-- itself defines the variable.
replaceVarInBinding :: Ident -> Expression -> Binding -> Binding
replaceVarInBinding v e (Binding w rhs) =
  Binding w (if v == w then rhs else replaceVar v e rhs)
-- | Does the pattern bind the given variable?  Literal patterns bind nothing.
occursInPattern :: Ident -> ConstrTerm -> Bool
occursInPattern v pat = case pat of
  VariablePattern _ w        -> v == w
  ConstructorPattern _ _ vs  -> any ((== v) . snd) vs
  _                          -> False
-- | Does the binding define the given variable?
occursInBinding :: Ident -> Binding -> Bool
occursInBinding v b = case b of
  Binding w _ -> v == w
-- | A call of @Prelude.failed@ (arity 0) at the given type.
failedExpr :: Type -> Expression
failedExpr ty = Function ty failedName 0
  where
    failedName = qualifyWith preludeMIdent (mkIdent "failed")
-- | @eqExpr e1 e2@ applies the @Eq@ instance method @==@ of the literal type
-- to both expressions.  The instance is chosen from the literal @e2@
-- (Char, Int or Float); any other second argument is an internal error.
eqExpr :: Expression -> Expression -> Expression
eqExpr e1 e2 = Apply (Apply eqFun e1) e2
  where
    eqFun   = Function (TypeArrow argTy (TypeArrow argTy boolType')) eqName 2
    eqName  = qImplMethodId preludeMIdent qEqId litTy (mkIdent "==")
    argTy   = transType litTy
    litTy   = case e2 of
      Literal _ lit -> literalType lit
      _             -> internalError "CaseCompletion.eqExpr: no literal"
    literalType (Char _)  = charType
    literalType (Int _)   = intType
    literalType (Float _) = floatType
-- | Pattern for the constructor @True@.
truePatt :: ConstrTerm
truePatt = boolConstrPatt qTrueId

-- | Pattern for the constructor @False@.
falsePatt :: ConstrTerm
falsePatt = boolConstrPatt qFalseId

-- | Nullary constructor pattern of the Boolean type.
boolConstrPatt :: QualIdent -> ConstrTerm
boolConstrPatt c = ConstructorPattern boolType' c []

-- | The Boolean type translated to IL.
boolType' :: Type
boolType' = transType boolType
-- | Constructors of the data type of the first constructor in the list
-- that do NOT occur in the list, paired with their argument types.  The
-- argument types are expressed over the type variables 0, 1, ... of the
-- data type; the caller instantiates them for the actual scrutinee type.
--
-- Built-in lists are handled directly.  Constructors of the current module
-- are looked up in its declarations, and imported ones in the interface of
-- their module.  If that interface is missing, the result is empty.
getComplConstrs :: Module -> InterfaceEnv -> [QualIdent] -> [(QualIdent, [Type])]
getComplConstrs _ _ []
  = internalError "CaseCompletion.getComplConstrs: empty constructor list"
getComplConstrs (Module mid _ ds) menv cs@(c:_)
  -- built-in lists: the tail of (:) is a list over the same element type
  -- variable as the head (this was wrongly the fixed type [Bool] before)
  | c `elem` [qNilId, qConsId] = complementary cs
      [ (qNilId, [])
      , (qConsId, [TypeVariable 0, TypeConstructor qListId [TypeVariable 0]])
      ]
  | mid' == mid = getCCFromDecls cs ds
  | otherwise   = maybe [] (getCCFromIDecls mid' cs)
                           (lookupInterface mid' menv)
  where mid' = fromMaybe mid (qidModule c)
-- | Complementary constructors for a data type of the current module.  The
-- data declaration that declares the first constructor of @cs@ is searched
-- in @ds@.  If no such declaration exists, the result is empty.
getCCFromDecls :: [QualIdent] -> [Decl] -> [(QualIdent, [Type])]
getCCFromDecls cs ds = complementary cs
  [ (cid, tys)
  | Just (DataDecl _ _ cds) <- [find declaresFirst ds]
  , ConstrDecl cid tys <- cds
  ]
  where
    declaresFirst (DataDecl _ _ cds) =
      any (\(ConstrDecl cid _) -> cid == head cs) cds
    declaresFirst _ = False
-- | Complementary constructors for a data type declared in the interface of
-- module @mid@.  The declaration containing the first constructor of @cs@
-- is searched among the interface declarations.  Its constructors are
-- qualified with @mid@, and their argument types are translated to IL
-- relative to the type variables of that declaration.  A matching newtype
-- declaration, or no match at all, yields no constructors.
getCCFromIDecls :: ModuleIdent -> [QualIdent] -> CS.Interface
                -> [(QualIdent, [Type])]
getCCFromIDecls mid cs (CS.Interface _ _ ds) = complementary cs
  [ constrInfo tvs cd
  | Just (CS.IDataDecl _ _ _ tvs cds _) <- [find definesTarget ds]
  , cd <- cds
  ]
  where
    target = unqualify (head cs)
    definesTarget decl = case decl of
      CS.IDataDecl _ _ _ _ cds _   -> any ((target ==) . constrName) cds
      CS.INewtypeDecl _ _ _ _ nc _ -> target == newConstrName nc
      _                            -> False
    constrName (CS.ConstrDecl _ cid _)  = cid
    constrName (CS.ConOpDecl _ _ oid _) = oid
    constrName (CS.RecordDecl _ cid _)  = cid
    newConstrName (CS.NewConstrDecl _ cid _) = cid
    newConstrName (CS.NewRecordDecl _ cid _) = cid
    constrInfo tvs cd =
      let toIL = transType . toType tvs
      in case cd of
           CS.ConstrDecl _ cid tys    -> (qualifyWith mid cid, map toIL tys)
           CS.ConOpDecl _ ty1 oid ty2 -> (qualifyWith mid oid, map toIL [ty1, ty2])
           CS.RecordDecl _ cid fs     ->
             ( qualifyWith mid cid
             , [toIL ty | CS.FieldDecl _ ls ty <- fs, _ <- ls]
             )
-- | Keep the constructor entries whose name is not among the known ones.
complementary :: [QualIdent] -> [(QualIdent, [Type])] -> [(QualIdent, [Type])]
complementary known others =
  [ entry | entry@(cid, _) <- others, cid `notElem` known ]
-- | Substitution from type variable indices to IL types.
type TypeSubst = Subst Int Type
-- | Things to which a type substitution can be applied.
class SubstType a where
  subst :: TypeSubst -> a -> a
-- | Apply a substitution to every element of a list.
instance SubstType a => SubstType [a] where
  subst sigma xs = [ subst sigma x | x <- xs ]
-- | Apply a substitution to an IL type.  Type variables are resolved via
-- @substVar'@; quantified types are not expected here.
instance SubstType Type where
  subst sigma ty = case ty of
    TypeConstructor q tys -> TypeConstructor q (subst sigma tys)
    TypeArrow ty1 ty2     -> TypeArrow (subst sigma ty1) (subst sigma ty2)
    TypeVariable tv       -> substVar' TypeVariable subst sigma tv
    TypeForall _ _        ->
      internalError "Transformations.CaseCompletion.SubstType.Type.subst"
-- | @matchType ty1 ty2@ extends a substitution by bindings for the type
-- variables of @ty1@ so that @ty1@ matches @ty2@.  If the types do not
-- match, applying the result raises an internal error.
matchType :: Type -> Type -> TypeSubst -> TypeSubst
matchType ty1 ty2 = case matchType' ty1 ty2 of
  Just extend -> extend
  Nothing     -> internalError $ "Transformations.CaseCompletion.matchType: " ++
                   showsPrec 11 ty1 " " ++ showsPrec 11 ty2 ""
-- | Structural matching step.  Returns the substitution extension for
-- matching types, or @Nothing@ when the shapes or constructors differ.
-- Constructor arguments are matched pairwise, extra arguments on either
-- side are ignored.
matchType' :: Type -> Type -> Maybe (TypeSubst -> TypeSubst)
matchType' ty1 ty2 = case (ty1, ty2) of
  (TypeVariable tv, _)
    | ty2 == TypeVariable tv -> Just id
    | otherwise              -> Just (bindSubst tv ty2)
  (TypeConstructor tc1 tys1, TypeConstructor tc2 tys2)
    | tc1 == tc2 -> Just (foldr (.) id (zipWith matchType tys1 tys2))
  (TypeArrow arg1 res1, TypeArrow arg2 res2) ->
    Just (matchType arg1 arg2 . matchType res1 res2)
  _ -> Nothing