{-# LANGUAGE CPP #-}
module Transformations.CurryToIL (ilTrans, transType) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>), (<*>))
#endif
import Control.Monad.Extra (concatMapM)
import qualified Control.Monad.Reader as R
import Data.List (nub, partition)
import qualified Data.Map as Map (Map, empty, insert, lookup)
import qualified Data.Set as Set (Set, empty, insert, delete, toList)
import Curry.Base.Ident
import Curry.Syntax hiding (caseAlt)
import Base.CurryTypes (toType)
import Base.Expr
import Base.Messages (internalError)
import Base.Types
import Base.Typing
import Base.Utils (foldr2)
import Env.Value (ValueEnv, ValueInfo (..), qualLookupValue)
import qualified IL as IL
-- | Translate a type-annotated Curry module into an IL module.
--
-- All declarations are translated inside a 'R.Reader' over a 'TransEnv'
-- holding the module identifier and the value environment. The IL import
-- list is then derived from the modules referenced by the translated
-- declarations (see 'imports').
ilTrans :: ValueEnv -> Module Type -> IL.Module
ilTrans vEnv (Module _ _ m _ _ ds) = IL.Module m (imports m ilDecls) ilDecls
  where
    ilDecls = R.runReader (concatMapM trDecl ds) env
    env     = TransEnv { moduleIdent = m, valueEnv = vEnv }
-- | Compute the modules referenced by a list of IL declarations.
--
-- The module being compiled (@m@) is removed from the result. The list is
-- produced by 'Set.toList', so it is duplicate-free and in ascending order.
imports :: ModuleIdent -> [IL.Decl] -> [ModuleIdent]
imports m ds = Set.toList (Set.delete m (foldr mdlsDecl Set.empty ds))
-- | Add the modules referenced by one IL declaration to an accumulator set.
--
-- Data declarations contribute the modules of their constructors' argument
-- types. Function declarations contribute those of their type and body.
-- External functions contribute those of their type. External data
-- declarations contribute nothing.
mdlsDecl :: IL.Decl -> Set.Set ModuleIdent -> Set.Set ModuleIdent
mdlsDecl decl acc = case decl of
  IL.DataDecl _ _ cs       ->
    foldr (\(IL.ConstrDecl _ tys) acc' -> foldr mdlsType acc' tys) acc cs
  IL.ExternalDataDecl _ _  -> acc
  IL.FunctionDecl _ _ ty e -> mdlsType ty (mdlsExpr e acc)
  IL.ExternalDecl _ ty     -> mdlsType ty acc
-- | Add the modules of all type constructors occurring in an IL type to an
-- accumulator set. Type variables contribute nothing.
mdlsType :: IL.Type -> Set.Set ModuleIdent -> Set.Set ModuleIdent
mdlsType ty acc = case ty of
  IL.TypeConstructor tc args -> modules tc (foldr mdlsType acc args)
  IL.TypeVariable _          -> acc
  IL.TypeArrow dom cod       -> mdlsType dom (mdlsType cod acc)
  IL.TypeForall _ body       -> mdlsType body acc
-- | Add the modules referenced by an IL expression to an accumulator set.
--
-- Function and constructor references contribute their qualifying module.
-- Constructor patterns in case alternatives do as well. The traversal
-- descends into all sub-expressions and let/letrec bindings, and into the
-- sub-expression and annotation type of 'IL.Typed'. Without the 'IL.Typed'
-- case, entities used only inside a type-annotated expression would be
-- missed. Any other expression form contributes nothing.
mdlsExpr :: IL.Expression -> Set.Set ModuleIdent -> Set.Set ModuleIdent
mdlsExpr (IL.Function _ f _)    ms = modules f ms
mdlsExpr (IL.Constructor _ c _) ms = modules c ms
mdlsExpr (IL.Apply e1 e2)       ms = mdlsExpr e1 (mdlsExpr e2 ms)
mdlsExpr (IL.Case _ e as)       ms = mdlsExpr e (foldr mdlsAlt ms as)
  where
    mdlsAlt (IL.Alt t e') = mdlsPattern t . mdlsExpr e'
    mdlsPattern (IL.ConstructorPattern _ c _) = modules c
    mdlsPattern _                             = id
mdlsExpr (IL.Or e1 e2)          ms = mdlsExpr e1 (mdlsExpr e2 ms)
mdlsExpr (IL.Exist _ _ e)       ms = mdlsExpr e ms
mdlsExpr (IL.Let b e)           ms = mdlsBinding b (mdlsExpr e ms)
mdlsExpr (IL.Letrec bs e)       ms = foldr mdlsBinding (mdlsExpr e ms) bs
-- Type annotations are produced by 'trExpr' as @IL.Typed expr ty@.
mdlsExpr (IL.Typed e ty)        ms = mdlsExpr e (mdlsType ty ms)
mdlsExpr _                      ms = ms
-- | Add the modules referenced by the right-hand side of a binding to an
-- accumulator set.
mdlsBinding :: IL.Binding -> Set.Set ModuleIdent -> Set.Set ModuleIdent
mdlsBinding (IL.Binding _ rhs) acc = mdlsExpr rhs acc
-- | Insert the module qualifying an identifier into a set. Unqualified
-- identifiers leave the set unchanged.
modules :: QualIdent -> Set.Set ModuleIdent -> Set.Set ModuleIdent
modules x acc = case qidModule x of
  Just m  -> Set.insert m acc
  Nothing -> acc
-- | Read-only context of the translation.
data TransEnv = TransEnv
  { moduleIdent :: ModuleIdent -- ^ module being translated; used to qualify local names
  , valueEnv    :: ValueEnv    -- ^ environment used to look up types of functions/constructors
  }

-- | The translation monad: a reader over 'TransEnv'.
type TransM = R.Reader TransEnv
-- | Retrieve the value environment from the translation context.
getValueEnv :: TransM ValueEnv
getValueEnv = valueEnv <$> R.ask
-- | Qualify an identifier with the identifier of the module being translated.
trQualify :: Ident -> TransM QualIdent
trQualify i = R.asks (\env -> qualifyWith (moduleIdent env) i)
-- | Look up the (unquantified) type of a function or record label.
--
-- The lookup must yield exactly one 'Value' or 'Label' entry. Any other
-- result is reported via 'internalError'.
varType :: QualIdent -> TransM Type
varType f = do
  infos <- qualLookupValue f <$> getValueEnv
  case infos of
    [Value _ _ _ (ForAll _ (PredType _ ty))] -> return ty
    [Label _ _ (ForAll _ (PredType _ ty))]   -> return ty
    _ -> internalError $ "CurryToIL.varType: " ++ show f
-- | Look up the (unquantified) type of a data or newtype constructor.
--
-- The lookup must yield exactly one 'DataConstructor' or
-- 'NewtypeConstructor' entry. Any other result is reported via
-- 'internalError'.
constrType :: QualIdent -> TransM Type
constrType c = do
  infos <- qualLookupValue c <$> getValueEnv
  case infos of
    [DataConstructor _ _ _ (ForAll _ (PredType _ ty))]  -> return ty
    [NewtypeConstructor _ _ (ForAll _ (PredType _ ty))] -> return ty
    _ -> internalError $ "CurryToIL.constrType: " ++ show c
-- | Translate a top-level declaration into zero or more IL declarations.
--
-- Data, external data and function declarations each yield one IL
-- declaration. An external declaration yields one per declared variable.
-- All other declarations are dropped.
trDecl :: Decl Type -> TransM [IL.Decl]
trDecl decl = case decl of
  DataDecl _ tc tvs cs _    -> single <$> trData tc tvs cs
  ExternalDataDecl _ tc tvs -> single <$> trExternalData tc tvs
  FunctionDecl _ ty f eqs   -> single <$> trFunction f ty eqs
  ExternalDecl _ vs         -> mapM trExternal vs
  _                         -> return []
  where
    single d = [d]
-- | Translate a data type declaration.
--
-- The IL declaration records the qualified type name, the number of type
-- parameters and the translated constructors.
trData :: Ident -> [Ident] -> [ConstrDecl] -> TransM IL.Decl
trData tc tvs cs =
  IL.DataDecl <$> trQualify tc <*> return (length tvs) <*> mapM trConstrDecl cs
-- | Translate a constructor declaration.
--
-- The argument types are not taken from the syntax. They are read from
-- the constructor's type in the value environment ('arrowArgs' of
-- 'constrType').
trConstrDecl :: ConstrDecl -> TransM IL.ConstrDecl
trConstrDecl d = do
  qc  <- trQualify (constrName d)
  cty <- constrType qc
  return $ IL.ConstrDecl qc (map transType (arrowArgs cty))
  where
    -- the declared constructor name, for every syntactic form
    constrName (ConstrDecl _ c _)    = c
    constrName (ConOpDecl _ _ op _)  = op
    constrName (RecordDecl _ c _)    = c
-- | Translate an external data declaration into its qualified name and
-- the number of its type parameters.
trExternalData :: Ident -> [Ident] -> TransM IL.Decl
trExternalData tc tvs = do
  qtc <- trQualify tc
  return (IL.ExternalDataDecl qtc (length tvs))
-- | Translate an externally defined function into its qualified name and
-- translated type.
trExternal :: Var Type -> TransM IL.Decl
trExternal (Var ty f) = do
  qf <- trQualify f
  return (IL.ExternalDecl qf (transType ty))
-- | Translate a frontend type into an IL type, starting with no pending
-- type arguments.
transType :: Type -> IL.Type
transType = flip transType' []
-- | Translate a type that is applied to the given (already translated)
-- argument types, listed in source order.
--
-- Applications headed by a type constructor become a single
-- 'IL.TypeConstructor' carrying all arguments. Applications headed by
-- anything else (variable, arrow, forall) are encoded with left-nested
-- applications of the Prelude @Apply@ constructor (see 'applyType'').
-- For a constrained type only the first type in its list is used.
transType' :: Type -> [IL.Type] -> IL.Type
transType' ty args = case ty of
  TypeConstructor tc    -> IL.TypeConstructor tc args
  TypeApply ty1 ty2     -> transType' ty1 (transType ty2 : args)
  TypeConstrained tys _ -> transType' (head tys) args
  TypeVariable tv       -> applyTo (IL.TypeVariable tv)
  TypeArrow ty1 ty2     -> applyTo (IL.TypeArrow (transType ty1) (transType ty2))
  TypeForall tvs body   -> applyTo (IL.TypeForall tvs (transType body))
  where
    applyTo hd = foldl applyType' hd args
-- | Encode the application of one IL type to another with the Prelude
-- type constructor @Apply@.
applyType' :: IL.Type -> IL.Type -> IL.Type
applyType' fun arg = IL.TypeConstructor applyTc [fun, arg]
  where
    applyTc = qualifyWith preludeMIdent (mkIdent "Apply")
-- | Translate a function declaration into an IL function declaration.
--
-- Parameter names come from @argNames (mkIdent "")@ (i.e. @_1@, @_2@, ...),
-- one per argument pattern of the first equation. The remaining, unused
-- names are passed on to 'trEquation' as a name supply. Each parameter is
-- paired with the translated type of the first equation's corresponding
-- pattern. The equations' matches are combined with 'flexMatch'.
trFunction :: Ident -> Type -> [Equation Type] -> TransM IL.Decl
trFunction f ty eqs = do
  qf      <- trQualify f
  matches <- mapM (trEquation params supply) eqs
  return $ IL.FunctionDecl qf typedParams (transType ty) (flexMatch typedParams matches)
  where
    Equation _ firstLhs _ = head eqs
    (_, argPats)          = flatLhs firstLhs
    (params, supply)      = splitAt (length argPats) (argNames (mkIdent ""))
    typedParams           = zip (map (transType . typeOf) argPats) params
-- | Translate a single function equation into a 'Match'.
--
-- Each argument pattern is translated relative to its parameter name.
-- Variables bound by the patterns are renamed to those parameter-based
-- names (via 'bindRenameEnv') before the right-hand side is translated.
-- Only equations with a 'FunLhs' are supported. Any other left-hand side
-- is reported via 'internalError'.
trEquation :: [Ident]        -- ^ parameter names, one per argument pattern
           -> [Ident]        -- ^ name supply for the right-hand side
           -> Equation Type
           -> TransM Match
trEquation vs vs' (Equation _ (FunLhs _ _ ts) rhs) = do
  let patternRenaming = foldr2 bindRenameEnv Map.empty vs ts
  rhs' <- trRhs vs' patternRenaming rhs
  return (zipWith trPattern vs ts, rhs')
trEquation _ _ _
  = internalError "CurryToIL.trEquation: translation of non-FunLhs equation not defined"
-- | Maps source variables to the IL variable names that replace them.
type RenameEnv = Map.Map Ident Ident

-- | Extend a renaming so that every variable bound by a pattern is mapped
-- to the IL name used for its position by 'trPattern'.
--
-- A variable (or as-pattern alias) at position @v@ is renamed to @v@.
-- Constructor sub-patterns are positioned at @argNames v@. Literal
-- patterns bind nothing. Any other pattern form is reported via
-- 'internalError'.
bindRenameEnv :: Ident -> Pattern a -> RenameEnv -> RenameEnv
bindRenameEnv v pat env = case pat of
  LiteralPattern {}            -> env
  VariablePattern _ _ x        -> Map.insert x v env
  ConstructorPattern _ _ _ sub -> foldr2 bindRenameEnv env (argNames v) sub
  AsPattern _ x inner          -> Map.insert x v (bindRenameEnv v inner env)
  _                            -> internalError "CurryToIL.bindRenameEnv"
-- | Translate a right-hand side.
--
-- Only simple right-hand sides are supported. Their local declarations
-- field is ignored here. Guarded right-hand sides are reported via
-- 'internalError'.
trRhs :: [Ident] -> RenameEnv -> Rhs Type -> TransM IL.Expression
trRhs vs env rhs = case rhs of
  SimpleRhs _ e _  -> trExpr vs env e
  GuardedRhs _ _ _ -> internalError "CurryToIL.trRhs: GuardedRhs"
-- | Translate an expression into IL.
--
-- @vs@ is a name supply. A case expression takes the first name as the
-- variable it matches on. @env@ renames source variables bound by
-- enclosing patterns or local declarations.
--
-- * Unqualified variables found in @env@ become 'IL.Variable's with the
--   renamed identifier. All others become 'IL.Function's whose arity is
--   taken from their type in the value environment.
-- * A @let@ of free variables becomes nested 'IL.Exist'. A single binding
--   whose qualified free variables do not include its own bound variables
--   becomes 'IL.Let'. Anything else becomes 'IL.Letrec'. Only variable
--   pattern bindings are supported.
-- * In a case expression, if the generated match is itself a case on the
--   fresh variable that is not otherwise used in its alternatives, the
--   scrutinee is substituted directly. Otherwise the scrutinee is
--   let-bound, but only if the variable is actually free in the result.
-- * Unsupported forms are reported via 'internalError'.
trExpr :: [Ident] -> RenameEnv -> Expression Type -> TransM IL.Expression
trExpr _ _ (Literal _ ty l) = return $ IL.Literal (transType ty) (trLiteral l)
trExpr _ env (Variable _ ty v)
  | isQualified v = fun
  | otherwise     = case Map.lookup (unqualify v) env of
      Nothing -> fun
      Just v' -> return $ IL.Variable (transType ty) v'
  where
    fun = (IL.Function (transType ty) v . arrowArity) <$> varType v
trExpr _ _ (Constructor _ ty c)
  = (IL.Constructor (transType ty) c . arrowArity) <$> constrType c
trExpr vs env (Apply _ e1 e2)
  = IL.Apply <$> trExpr vs env e1 <*> trExpr vs env e2
trExpr vs env (Let _ ds e) = do
  e' <- trExpr vs env' e
  case ds of
    [FreeDecl _ vs']
      -> return $ foldr (\ (Var ty v) -> IL.Exist v (transType ty)) e' vs'
    [d] | all (`notElem` bv d) (qfv emptyMIdent d)
      -> flip IL.Let e' <$> trBinding d
    _ -> flip IL.Letrec e' <$> mapM trBinding ds
  where
    -- locally bound variables keep their own names and shadow outer renamings
    env' = foldr2 Map.insert env bvs bvs
    bvs  = bv ds
    trBinding (PatternDecl _ (VariablePattern _ _ v) rhs)
      = IL.Binding v <$> trRhs vs env' rhs
    trBinding p
      = internalError $ "CurryToIL.trExpr: unexpected binding: " ++ show p
trExpr (v:vs) env (Case _ ct e alts) = do
  e' <- trExpr vs env e
  let matcher = if ct == Flex then flexMatch else rigidMatch
      ty'     = transType $ typeOf e
  expr <- matcher [(ty', v)] <$> mapM (trAlt (v:vs) env) alts
  return $ case expr of
    IL.Case mode (IL.Variable _ v') alts'
      | v == v' && v `notElem` fv alts' -> IL.Case mode e' alts'
    _
      | v `elem` fv expr -> IL.Let (IL.Binding v e') expr
      | otherwise        -> expr
trExpr vs env (Typed _ e (QualTypeExpr _ _ ty)) =
  flip IL.Typed ty' <$> trExpr vs env e
  where
    ty' = transType (toType [] ty)
trExpr _ _ _ = internalError "CurryToIL.trExpr"
-- | Translate a case alternative into a single-column 'Match'.
--
-- The first name of the supply is the case variable: the pattern is
-- translated relative to it, and the variables it binds are renamed
-- accordingly. The remaining names go to the right-hand side. The list
-- pattern is lazy, so the supply is only deconstructed when needed.
trAlt :: [Ident] -> RenameEnv -> Alt Type -> TransM Match
trAlt ~(v:vs) env (Alt _ t rhs) =
  (,) [trPattern v t] <$> trRhs vs (bindRenameEnv v t env) rhs
-- | Translate a character, integer or float literal.
-- Any other literal is reported via 'internalError'.
trLiteral :: Literal -> IL.Literal
trLiteral lit = case lit of
  Char c  -> IL.Char c
  Int i   -> IL.Int i
  Float f -> IL.Float f
  _       -> internalError "CurryToIL.trLiteral"
-- | A translated pattern: its top-level IL constructor term together with
-- the nested patterns for the constructor's argument positions.
data NestedTerm = NestedTerm IL.ConstrTerm [NestedTerm] deriving (Show)

-- | The top-level IL constructor term of a nested pattern.
pattern :: NestedTerm -> IL.ConstrTerm
pattern nt = case nt of
  NestedTerm t _ -> t

-- | The nested sub-patterns of a nested pattern.
arguments :: NestedTerm -> [NestedTerm]
arguments nt = case nt of
  NestedTerm _ ts -> ts
-- | Translate a source pattern into a 'NestedTerm' at position @v@.
--
-- Variable patterns become IL variable patterns named @v@. Constructor
-- patterns name their argument positions with @argNames v@ and translate
-- each sub-pattern at its position. As-patterns are translated as their
-- inner pattern. Other pattern forms are reported via 'internalError'.
trPattern :: Ident -> Pattern Type -> NestedTerm
trPattern v pat = case pat of
  LiteralPattern _ ty l ->
    NestedTerm (IL.LiteralPattern (transType ty) (trLiteral l)) []
  VariablePattern _ ty _ ->
    NestedTerm (IL.VariablePattern (transType ty) v) []
  ConstructorPattern _ ty c sub ->
    let subNames  = argNames v
        typedVars = zip (map (transType . typeOf) sub) subNames
    in  NestedTerm (IL.ConstructorPattern (transType ty) c typedVars)
                   (zipWith trPattern subNames sub)
  AsPattern _ _ inner -> trPattern v inner
  _ -> internalError "CurryToIL.trPattern"
-- | An infinite supply of identifiers derived from @v@:
-- @v_1@, @v_2@, @v_3@, ...
argNames :: Ident -> [Ident]
argNames v = map (mkIdent . (base ++) . show) [(1 :: Integer) ..]
  where
    base = idName v ++ "_"
-- | A difference list: a function that prepends elements to a list.
type FunList a = [a] -> [a]

-- | A row of the pattern-matching matrix: one pattern per remaining
-- column, and the right-hand side to use if all of them match.
type Match = ([NestedTerm], IL.Expression)

-- | A 'Match' whose skipped leading patterns are kept in a difference list,
-- so they can be restored in front of the remaining patterns later.
type Match' = (FunList NestedTerm, [NestedTerm], IL.Expression)
-- | Compile a matrix of matches into a flexible IL expression.
--
-- With no columns left, the right-hand sides of all rows are combined
-- with 'IL.Or'. Otherwise the rows are partitioned on the top-level term
-- of the first column:
--
-- * If no row has a constructor/literal term there, matching continues on
--   the remaining columns with the variable rows.
-- * If no row has a variable term there, a flexible case on the first
--   column is built ('flexMatchInductive').
-- * In the mixed case, both translations are combined with 'IL.Or' as a
--   fallback. 'optFlexMatch' first looks for a later column in which no
--   row has a variable term, skipping the current one.
flexMatch :: [(IL.Type, Ident)]
          -> [Match]
          -> IL.Expression
flexMatch []     alts = foldl1 IL.Or (map snd alts)
flexMatch (v:vs) alts
  | notDemanded = varExp
  | isInductive = conExp
  | otherwise   = optFlexMatch (IL.Or conExp varExp) (v:) vs (map skipPat alts)
  where
    isInductive        = null varAlts
    notDemanded        = null conAlts
    -- tag each row with the top-level term of its first pattern
    (varAlts, conAlts) = partition isVarMatch (map tagAlt alts)
    varExp             = flexMatch vs (map snd varAlts)
    conExp             = flexMatchInductive id v vs (map prep conAlts)
    -- no columns have been skipped yet, so the prefix is empty
    prep (p, (ts, e))  = (p, (id, ts, e))
-- | Search the remaining columns for one in which no row has a variable
-- term, and build a flexible case on it ('flexMatchInductive').
--
-- Columns that do not qualify are skipped. Their variables are recorded in
-- @prefix@ and their patterns in each row's own prefix, so they are
-- matched again later. If all columns are exhausted, @def@ is returned.
optFlexMatch :: IL.Expression
             -> FunList (IL.Type, Ident)
             -> [(IL.Type, Ident)]
             -> [Match']
             -> IL.Expression
optFlexMatch def _ [] _ = def
optFlexMatch def prefix (v:vs) alts
  | all (not . isVarMatch) tagged = flexMatchInductive prefix v vs tagged
  | otherwise = optFlexMatch def (prefix . (v:)) vs (map skipPat' alts)
  where
    tagged = map tagAlt' alts
-- | Build a flexible 'IL.Case' on variable @v@ from rows tagged with their
-- term in @v@'s column.
--
-- Rows with equal terms are grouped into one alternative, in order of
-- first occurrence. Each alternative continues matching with 'flexMatch'
-- on the skipped columns (@prefix@), the term's argument variables
-- ('vars') and the remaining columns. Each row's skipped patterns are
-- restored in front of its remaining ones.
flexMatchInductive :: FunList (IL.Type, Ident)
                   -> (IL.Type, Ident)
                   -> [(IL.Type, Ident)]
                   -> [(IL.ConstrTerm, Match')]
                   -> IL.Expression
flexMatchInductive prefix v vs as
  = IL.Case IL.Flex (uncurry IL.Variable v) (flexMatchAlts as)
  where
    flexMatchAlts []              = []
    flexMatchAlts ((t, e) : alts) = IL.Alt t expr : flexMatchAlts others
      where
        expr = flexMatch (prefix (vars t ++ vs)) (map expandVars (e : map snd same))
        -- put the skipped patterns back in front of the remaining ones
        expandVars (pref, ts1, e') = (pref ts1, e')
        -- rows sharing this term join this alternative; the rest are handled later
        (same, others) = partition ((t ==) . fst) alts
-- | Compile a matrix of matches into a rigid IL expression.
--
-- The right-hand side of the first row is the result if no column remains
-- to be matched. No columns have been skipped initially, so every row
-- starts with an empty prefix.
rigidMatch :: [(IL.Type, Ident)] -> [Match] -> IL.Expression
rigidMatch vs alts = rigidOptMatch fallback id vs (map (\(ts, e) -> (id, ts, e)) alts)
  where
    fallback = snd (head alts)
-- | Find the first column in which the first row has a non-variable term,
-- and build a rigid case on it ('rigidMatchDemanded').
--
-- Columns whose first row has a variable term are skipped and remembered
-- in the prefixes. If no column remains, @def@ is returned.
rigidOptMatch :: IL.Expression
              -> FunList (IL.Type, Ident)
              -> [(IL.Type, Ident)]
              -> [Match']
              -> IL.Expression
rigidOptMatch def _ [] _ = def
rigidOptMatch def prefix (v : vs) alts
  | isVarMatch (head tagged) =
      rigidOptMatch def (prefix . (v :)) vs (map skipPat' alts)
  | otherwise = rigidMatchDemanded prefix v vs tagged
  where
    tagged = map tagAlt' alts
-- | Build a rigid 'IL.Case' on variable @v@ from rows tagged with their
-- term in @v@'s column.
--
-- There is one alternative per distinct term. Constructor/literal
-- alternatives come first, followed by variable alternatives. Each
-- alternative for term @t@ continues with 'rigidMatch' on the rows whose
-- term equals @t@ or is a variable. Variable rows get variable patterns
-- for @t@'s argument positions, so all rows keep the same column layout.
rigidMatchDemanded :: FunList (IL.Type, Ident)
                   -> (IL.Type, Ident)
                   -> [(IL.Type, Ident)]
                   -> [(IL.ConstrTerm, Match')]
                   -> IL.Expression
rigidMatchDemanded prefix v vs alts = IL.Case IL.Rigid (uncurry IL.Variable v)
  $ map caseAlt (consPats ++ varPats)
  where
    (varPats, consPats) = partition isVarPattern $ nub $ map fst alts
    caseAlt t = IL.Alt t expr
      where
        expr = rigidMatch (prefix $ vars t ++ vs) (matchingCases alts)
        -- rows applicable to alternative t, with their skipped patterns restored
        matchingCases a = map (expandVars (vars t)) $ filter (matches . fst) a
        matches t' = t == t' || isVarPattern t'
        expandVars vs' (p, (pref, ts1, e)) = (pref ts2, e)
          where
            -- a variable row has no sub-patterns for t's arguments; pad with variables
            ts2 | isVarPattern p = map var2Pattern vs' ++ ts1
                | otherwise      = ts1
            var2Pattern v' = NestedTerm (uncurry IL.VariablePattern v') []
-- | Is the IL constructor term a variable pattern?
isVarPattern :: IL.ConstrTerm -> Bool
isVarPattern t = case t of
  IL.VariablePattern _ _ -> True
  _                      -> False
-- | Is the tag of a tagged row a variable pattern?
isVarMatch :: (IL.ConstrTerm, a) -> Bool
isVarMatch (t, _) = isVarPattern t
-- | The typed argument variables of a constructor pattern.
-- Any other term has none.
vars :: IL.ConstrTerm -> [(IL.Type, Ident)]
vars t = case t of
  IL.ConstructorPattern _ _ args -> args
  _                              -> []
-- | Tag a row with the top-level term of its first pattern.
--
-- That pattern is replaced by its sub-patterns in front of the remaining
-- ones. An empty pattern list is reported via 'internalError', consistent
-- with the rest of this module.
tagAlt :: Match -> (IL.ConstrTerm, Match)
tagAlt (t:ts, e) = (pattern t, (arguments t ++ ts, e))
tagAlt ([] , _) = internalError "CurryToIL.tagAlt: empty pattern list"
-- | Skip the first column of a row.
--
-- The first pattern is moved into the row's prefix, so it is restored
-- later. An empty pattern list is reported via 'internalError',
-- consistent with the rest of this module.
skipPat :: Match -> Match'
skipPat (t:ts, e) = ((t:), ts, e)
skipPat ([] , _) = internalError "CurryToIL.skipPat: empty pattern list"
-- | Like 'tagAlt', but for rows that carry a prefix of skipped patterns;
-- the prefix is kept unchanged.
--
-- An empty pattern list is reported via 'internalError', consistent with
-- the rest of this module.
tagAlt' :: Match' -> (IL.ConstrTerm, Match')
tagAlt' (pref, t:ts, e') = (pattern t, (pref, arguments t ++ ts, e'))
tagAlt' (_ , [] , _ ) = internalError "CurryToIL.tagAlt': empty pattern list"
-- | Like 'skipPat', but appends the skipped pattern to an existing prefix,
-- so skipped patterns are restored in their original order.
--
-- An empty pattern list is reported via 'internalError', consistent with
-- the rest of this module.
skipPat' :: Match' -> Match'
skipPat' (pref, t:ts, e') = (pref . (t:), ts, e')
skipPat' (_ , [] , _ ) = internalError "CurryToIL.skipPat': empty pattern list"