{-# LANGUAGE CPP #-}
module Transformations.Lift (lift) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>), (<*>))
#endif
import Control.Arrow (first)
import qualified Control.Monad.State as S (State, runState, gets, modify)
import Data.List
import qualified Data.Map as Map (Map, empty, insert, lookup)
import Data.Maybe (mapMaybe, fromJust)
import qualified Data.Set as Set (fromList, toList, unions)
import Curry.Base.Ident
import Curry.Base.SpanInfo
import Curry.Syntax
import Base.AnnotExpr
import Base.Expr
import Base.Messages (internalError)
import Base.SCC
import Base.Types
import Base.TypeSubst
import Base.Typing
import Base.Utils
import Env.Value
-- | Lift the local function declarations of a module to the top level.
--
-- Every top-level declaration is first abstracted with 'absDecl' inside a
-- state that starts with the module identifier, the given value environment
-- and an empty abstraction environment. The abstracted declarations are then
-- flattened with 'liftFunDecl'. The result is the rewritten module together
-- with the value environment found in the final state.
lift :: ValueEnv -> Module Type -> (Module Type, ValueEnv)
lift vEnv (Module spi ps m es is ds) =
  let startState             = LiftState m vEnv Map.empty
      (abstracted, endState) =
        S.runState (mapM (absDecl "" []) ds) startState
      topLevelDecls          = concatMap liftFunDecl abstracted
  in  (Module spi ps m es is topLevelDecls, valueEnv endState)
-- | Abstraction environment: maps the identifier of a local function to the
-- expression that replaces its occurrences, paired with the function's type.
type AbstractEnv = Map.Map Ident (Expression Type, Type)

-- | State threaded through the abstraction phase.
data LiftState = LiftState
  { moduleIdent :: ModuleIdent   -- ^ identifier of the module being lifted
  , valueEnv    :: ValueEnv      -- ^ value environment, extended for lifted functions
  , abstractEnv :: AbstractEnv   -- ^ replacements currently in effect
  }

-- | State monad used by the abstraction phase.
type LiftM a = S.State LiftState a
-- | Read the module identifier stored in the state.
getModuleIdent :: LiftM ModuleIdent
getModuleIdent = S.gets $ \LiftState { moduleIdent = m } -> m
-- | Read the current value environment from the state.
getValueEnv :: LiftM ValueEnv
getValueEnv = S.gets $ \LiftState { valueEnv = env } -> env
-- | Apply a function to the value environment stored in the state, leaving
-- the other state components untouched.
modifyValueEnv :: (ValueEnv -> ValueEnv) -> LiftM ()
modifyValueEnv f = S.modify update
  where
    update st = st { valueEnv = f (valueEnv st) }
-- | Read the abstraction environment currently in effect.
getAbstractEnv :: LiftM AbstractEnv
getAbstractEnv = S.gets $ \LiftState { abstractEnv = env } -> env
-- | Run an action with the given abstraction environment installed and
-- restore the previously active environment afterwards. Changes the action
-- makes to the other state components (e.g. the value environment) persist.
withLocalAbstractEnv :: AbstractEnv -> LiftM a -> LiftM a
withLocalAbstractEnv ae act = do
  saved <- getAbstractEnv
  install ae
  result <- act
  install saved
  return result
  where
    -- overwrite only the abstraction environment component
    install env = S.modify $ \st -> st { abstractEnv = env }
-- | Abstract a single declaration.
--
-- Function declarations have each equation abstracted (the prefix argument
-- is ignored; 'absEquation' derives its own from the function name), pattern
-- declarations have their right-hand side abstracted, and every other
-- declaration is returned unchanged.
absDecl :: String -> [Ident] -> Decl Type -> LiftM (Decl Type)
absDecl pre lvs decl = case decl of
  FunctionDecl p ty f eqs ->
    FunctionDecl p ty f <$> mapM (absEquation lvs) eqs
  PatternDecl p t rhs ->
    PatternDecl p t <$> absRhs pre lvs rhs
  _ ->
    return decl
-- | Abstract the right-hand side of a function equation.
--
-- The name prefix for functions lifted out of this equation is the name of
-- the defined function followed by a dot; the variables bound by the
-- argument patterns are added to the local variables. Only equations with a
-- 'FunLhs' are supported.
absEquation :: [Ident] -> Equation Type -> LiftM (Equation Type)
absEquation lvs eqn = case eqn of
  Equation p lhs@(FunLhs _ f ts) rhs ->
    Equation p lhs <$> absRhs (idName f ++ ".") (lvs ++ bv ts) rhs
  _ ->
    error "Lift.absEquation: no pattern match"
-- | Abstract the expression of a simple right-hand side. The local
-- declarations slot of the 'SimpleRhs' is ignored and the result is rebuilt
-- with 'simpleRhs'; any other form of right-hand side is rejected.
absRhs :: String -> [Ident] -> Rhs Type -> LiftM (Rhs Type)
absRhs pre lvs rhs = case rhs of
  SimpleRhs p e _ -> simpleRhs p <$> absExpr pre lvs e
  _               -> error "Lift.absRhs: no simple RHS"
-- | Abstract a group of local declarations together with the expression
-- they scope over.
--
-- The declarations are split into function and value declarations. The
-- function declarations are partitioned with 'scc' (using bound variables
-- and qualified free variables) and handed to 'absFunDecls' group by group;
-- the variables bound by the value declarations count as local variables.
absDeclGroup :: String -> [Ident] -> [Decl Type] -> Expression Type
             -> LiftM (Expression Type)
absDeclGroup pre lvs ds e = do
  m <- getModuleIdent
  let funGroups = scc bv (qfv m) funDecls
  absFunDecls pre (lvs ++ bv valDecls) funGroups valDecls e
  where
    (funDecls, valDecls) = partition isFunDecl ds
-- | Abstract a sequence of function declaration groups, followed by the
-- value declarations and the body expression.
--
-- With no group left, the value declarations and the body are abstracted
-- and wrapped in a @let@. Otherwise the first group is processed: the
-- typed free variables of the group that are local variables become extra
-- parameters, the abstraction environment is extended with a replacement
-- expression for every function of the group, and the remaining groups are
-- handled within that extended environment.
absFunDecls :: String -> [Ident] -> [[Decl Type]] -> [Decl Type]
            -> Expression Type -> LiftM (Expression Type)
absFunDecls pre lvs groups vds e = case groups of
  [] -> do
    vds' <- mapM (absDecl pre lvs) vds
    e'   <- absExpr pre lvs e
    return (Let NoSpanInfo vds' e')
  fds : fdss -> do
    m    <- getModuleIdent
    env  <- getAbstractEnv
    vEnv <- getValueEnv
    let -- functions defined by this group
        defined = bv fds
        -- type of every function, rebuilt from its first equation
        funTypes = map typeOfFun fds
        typeOfFun (FunctionDecl _ _ f (Equation _ (FunLhs _ _ ts) rhs : _)) =
          (f, foldr TypeArrow (typeOf rhs) (map typeOf ts))
        typeOfFun _ =
          internalError "Lift.absFunDecls.extractFty"
        -- Typed free variables of a replacement expression: the function
        -- type stored in the environment is matched against the type
        -- annotated at the use site and the resulting substitution is
        -- applied to the types of the replacement's free variables.
        replacementVars ty (re, fty) =
          let unifier = matchType fty ty idSubst
          in  [ (subst unifier ty', v) | (ty', v) <- qafv m re ]
        -- A free variable with an entry in the environment stands for the
        -- free variables of its replacement. Variables annotated with the
        -- dummy type (see 'bindF') are dropped.
        expand (ty, v) =
          filter (not . isDummyType . fst)
                 (maybe [(ty, v)] (replacementVars ty) (Map.lookup v env))
        rhsVars =
          Set.unions (map (Set.fromList . expand) (concatMap (qafv m) fds))
        -- only local variables have to be passed explicitly
        fvs = [ tv | tv <- Set.toList rhsVars, snd tv `elem` lvs ]
        -- The replacement applies the lifted function to the free
        -- variables. Its own annotation is not known yet, hence the dummy
        -- type; 'absExpr' re-annotates it when the replacement is inserted.
        bindF f = Map.insert f
          ( apply (mkFun m pre dummyType f) (map (uncurry mkVar) fvs)
          , fromJust (lookup f funTypes) )
        env' = foldr bindF env defined
        -- functions whose lifted name has no entry in the value environment
        fresh = [ f | f <- defined
                    , null (lookupValue (liftIdent pre f) vEnv) ]
    withLocalAbstractEnv env' $ do
      fds' <- mapM (absFunDecl pre fvs lvs)
                   (filter (any (`elem` fresh) . bv) fds)
      e'   <- absFunDecls pre lvs fdss vds e
      return (Let NoSpanInfo fds' e')
-- | Abstract a single local function declaration.
--
-- For a 'FunctionDecl' the function is renamed with 'liftIdent', the given
-- typed free variables are added as leading variable patterns to every
-- equation, the equations are abstracted, and the lifted function is entered
-- into the value environment with a type computed from its first (extended)
-- equation. For an 'ExternalDecl' every variable is processed by 'absVar'.
--
-- Parameters: name prefix, typed free variables to add, local variables,
-- and the declaration to abstract.
absFunDecl :: String -> [(Type, Ident)] -> [Ident] -> Decl Type
           -> LiftM (Decl Type)
absFunDecl pre fvs lvs (FunctionDecl p _ f eqs) = do
  m <- getModuleIdent
  -- 'absDecl' keeps the position, annotation and name of a 'FunctionDecl'
  -- and only rewrites its equations. The final type can therefore be passed
  -- in directly (no 'undefined' placeholder) and the result returned as is.
  d <- absDecl pre lvs (FunctionDecl p ty'' f' eqs')
  modifyValueEnv $ bindGlobalInfo
    (\qf tySc -> Value qf False (eqnArity firstEq) tySc) m f' (polyType ty'')
  return d
  where
    f'   = liftIdent pre f
    eqs' = map addVars eqs
    -- total replacement for @head eqs'@
    firstEq = case eqs' of
      eq : _ -> eq
      []     -> internalError
                  "Lift.absFunDecl: function declaration without equations"
    -- type of the lifted function, read off its first extended equation
    ty' = case firstEq of
      Equation _ (FunLhs _ _ ts') rhs' ->
        foldr TypeArrow (typeOf rhs') (map typeOf ts')
      _ ->
        internalError "Lift.absFunDecl: no function left-hand side"
    ty'' = genType ty'
    -- renumber the type variables in order of first occurrence from 0
    genType ty''' = subst (foldr2 bindSubst idSubst tvs tvs') ty'''
      where
        tvs  = nub (typeVars ty''')
        tvs' = map TypeVariable [0 ..]
    -- prepend the free variables as additional parameters
    addVars (Equation p' (FunLhs _ _ ts) rhs) =
      Equation p'
        (FunLhs NoSpanInfo f'
          (map (uncurry (VariablePattern NoSpanInfo)) fvs ++ ts))
        rhs
    addVars _ = error "Lift.absFunDecl.addVars: no pattern match"
absFunDecl pre _ _ (ExternalDecl p vs) =
  ExternalDecl p <$> mapM (absVar pre) vs
absFunDecl _ _ _ _ = error "Lift.absFunDecl: no pattern match"
-- | Rename the variable of an external declaration with 'liftIdent' and
-- enter the renamed identifier into the value environment, using the
-- variable's type and the arity given by 'arrowArity'.
absVar :: String -> Var Type -> LiftM (Var Type)
absVar pre (Var ty f) = do
  m <- getModuleIdent
  let mkInfo qf tySc = Value qf False (arrowArity ty) tySc
  modifyValueEnv (bindGlobalInfo mkInfo m lifted (polyType ty))
  return (Var ty lifted)
  where
    lifted = liftIdent pre f
-- | Abstract an expression.
--
-- Literals, constructors and qualified variables are returned unchanged. An
-- unqualified variable with an entry in the abstraction environment is
-- replaced by its replacement expression: the replacement is re-annotated
-- for the type at the use site, the substitution obtained by matching the
-- stored function type against that type is applied, and the result is
-- abstracted again. Applications, case expressions and typed expressions
-- are traversed; @let@ expressions are passed to 'absDeclGroup'. Any other
-- expression is an internal error.
absExpr :: String -> [Ident] -> Expression Type -> LiftM (Expression Type)
absExpr pre lvs expr = case expr of
  Literal _ _ _     -> return expr
  Constructor _ _ _ -> return expr
  Variable _ ty v
    | isQualified v -> return expr
    | otherwise     -> do
        env <- getAbstractEnv
        case Map.lookup (unqualify v) env of
          Nothing -> return expr
          Just (replacement, fty) ->
            let unifier = matchType fty ty idSubst
            in  absExpr pre lvs
                  (fmap (subst unifier) (retype ty replacement))
  Apply spi e1 e2 ->
    Apply spi <$> absExpr pre lvs e1 <*> absExpr pre lvs e2
  Let _ ds e ->
    absDeclGroup pre lvs ds e
  Case spi ct e bs ->
    Case spi ct <$> absExpr pre lvs e <*> mapM (absAlt pre lvs) bs
  Typed spi e ty ->
    (\e' -> Typed spi e' ty) <$> absExpr pre lvs e
  _ ->
    internalError $ "Lift.absExpr: " ++ show expr
  where
    -- Give the function at the head of an application spine the annotation
    -- that makes the whole application have the requested type.
    retype ty' (Variable spi _ v') = Variable spi ty' v'
    retype ty' (Apply spi e1 e2)   =
      Apply spi (retype (TypeArrow (typeOf e2) ty') e1) e2
    retype _ _ = internalError "Lift.absExpr.absType"
-- | Abstract the right-hand side of a case alternative; the variables bound
-- by the alternative's pattern are added to the local variables.
absAlt :: String -> [Ident] -> Alt Type -> LiftM (Alt Type)
absAlt pre lvs (Alt p t rhs) = Alt p t <$> absRhs pre (lvs ++ bv t) rhs
-- | Flatten a function declaration: the result starts with the declaration
-- whose equations no longer contain local functions, followed by the
-- declarations lifted out of its equations (each passed through
-- 'renameFunDecl'). Other declarations are returned as a singleton.
liftFunDecl :: Eq a => Decl a -> [Decl a]
liftFunDecl decl = case decl of
  FunctionDecl p a f eqs ->
    let (eqs', liftedGroups) = unzip (map liftEquation eqs)
    in  FunctionDecl p a f eqs' : map renameFunDecl (concat liftedGroups)
  _ ->
    [decl]
-- | Lift the functions out of a value declaration. A pattern declaration
-- has its right-hand side processed by 'liftRhs'; a free-variable
-- declaration is returned unchanged with nothing lifted. Other declarations
-- are not expected here.
liftVarDecl :: Eq a => Decl a -> (Decl a, [Decl a])
liftVarDecl decl = case decl of
  PatternDecl p t rhs ->
    let (rhs', lifted) = liftRhs rhs
    in  (PatternDecl p t rhs', lifted)
  FreeDecl _ _ ->
    (decl, [])
  _ ->
    error "Lift.liftVarDecl: no pattern match"
-- | Lift the functions out of an equation's right-hand side, returning the
-- rewritten equation and the lifted declarations.
liftEquation :: Eq a => Equation a -> (Equation a, [Decl a])
liftEquation (Equation p lhs rhs) =
  let (rhs', lifted) = liftRhs rhs
  in  (Equation p lhs rhs', lifted)
-- | Lift the functions out of a simple right-hand side. The expression is
-- processed by 'liftExpr' and the right-hand side is rebuilt with
-- 'simpleRhs'; other forms of right-hand side are rejected.
liftRhs :: Eq a => Rhs a -> (Rhs a, [Decl a])
liftRhs rhs = case rhs of
  SimpleRhs p e _ ->
    let (e', lifted) = liftExpr e
    in  (simpleRhs p e', lifted)
  _ ->
    error "Lift.liftRhs: no pattern match"
-- | Split a declaration group into the value declarations that stay in
-- place (first component) and the declarations to be lifted (second
-- component): the flattened function declarations of the group, followed
-- by everything lifted out of the value declarations.
liftDeclGroup :: Eq a => [Decl a] -> ([Decl a], [Decl a])
liftDeclGroup ds = (valDecls', fromFuns ++ fromVals)
  where
    (funDecls, valDecls)   = partition isFunDecl ds
    (valDecls', valLifted) = unzip (map liftVarDecl valDecls)
    fromFuns               = concatMap liftFunDecl funDecls
    fromVals               = concat valLifted
-- | Lift the function declarations out of an expression. Returns the
-- expression without local functions together with the lifted
-- declarations, collected left to right. @let@ expressions are rebuilt with
-- 'mkLet' from the declarations that 'liftDeclGroup' leaves in place.
liftExpr :: Eq a => Expression a -> (Expression a, [Decl a])
liftExpr expr = case expr of
  Literal _ _ _     -> (expr, [])
  Variable _ _ _    -> (expr, [])
  Constructor _ _ _ -> (expr, [])
  Apply spi e1 e2 ->
    let (e1', ds1) = liftExpr e1
        (e2', ds2) = liftExpr e2
    in  (Apply spi e1' e2', ds1 ++ ds2)
  Let _ ds e ->
    let (ds', ds1) = liftDeclGroup ds
        (e', ds2)  = liftExpr e
    in  (mkLet ds' e', ds1 ++ ds2)
  Case spi ct e alts ->
    let (e', ds')     = liftExpr e
        (alts', dss') = unzip (map liftAlt alts)
    in  (Case spi ct e' alts', ds' ++ concat dss')
  Typed spi e ty ->
    let (e', ds) = liftExpr e
    in  (Typed spi e' ty, ds)
  _ ->
    internalError "Lift.liftExpr"
-- | Lift the functions out of the right-hand side of a case alternative.
liftAlt :: Eq a => Alt a -> (Alt a, [Decl a])
liftAlt (Alt p t rhs) =
  let (rhs', lifted) = liftRhs rhs
  in  (Alt p t rhs', lifted)
-- | Association list used when renaming duplicate parameters: an annotated
-- identifier (annotation and original name) is mapped to its new name.
type RenameMap a = [((a, Ident), Ident)]
-- | Rename the duplicate parameters in every equation of a function
-- declaration (see 'renameEquation'); other declarations are unchanged.
renameFunDecl :: Eq a => Decl a -> Decl a
renameFunDecl decl = case decl of
  FunctionDecl p a f eqs ->
    FunctionDecl p a f [ renameEquation eq | eq <- eqs ]
  _ ->
    decl
-- | Rename the duplicate variable patterns on an equation's left-hand side
-- and apply the resulting renaming to its right-hand side.
renameEquation :: Eq a => Equation a -> Equation a
renameEquation (Equation p lhs rhs) =
  let (renaming, lhs') = renameLhs lhs
  in  Equation p lhs' (renameRhs renaming rhs)
-- | Rename duplicate variable patterns among the arguments of a function
-- left-hand side. The arguments are processed from right to left by
-- 'renamePattern'. Returns the renaming together with the new left-hand
-- side. Only 'FunLhs' is supported.
renameLhs :: Eq a => Lhs a -> (RenameMap a, Lhs a)
renameLhs lhs = case lhs of
  FunLhs spi f ts ->
    let (rm, ts') = foldr renamePattern ([], []) ts
    in  (rm, FunLhs spi f ts')
  _ ->
    error "Lift.renameLhs"
-- | Folding step of 'renameLhs'. A variable pattern whose name already
-- occurs among the variable patterns collected so far gets a new name (the
-- old name extended by a dot and the current size of the renaming), and the
-- renaming is recorded. Every other pattern is kept as it is.
renamePattern :: Eq a => Pattern a -> (RenameMap a, [Pattern a])
              -> (RenameMap a, [Pattern a])
renamePattern pat (rm, ts) = case pat of
  VariablePattern spi a v
    | v `elem` varPatNames ts ->
        let suffix = "." ++ show (length rm)
            v'     = updIdentName (++ suffix) v
        in  (((a, v), v') : rm, VariablePattern spi a v' : ts)
  _ ->
    (rm, pat : ts)
-- | Apply a renaming to the expression of a simple right-hand side; other
-- forms of right-hand side are rejected.
renameRhs :: Eq a => RenameMap a -> Rhs a -> Rhs a
renameRhs rm rhs = case rhs of
  SimpleRhs p e _ -> simpleRhs p (renameExpr rm e)
  _               -> error "Lift.renameRhs"
-- | Apply a renaming to an expression. An unqualified variable is replaced
-- when the pair of its annotation and name has an entry in the renaming;
-- literals, constructors and qualified variables are unchanged. The
-- remaining supported expression forms are traversed.
renameExpr :: Eq a => RenameMap a -> Expression a -> Expression a
renameExpr rm expr = case expr of
  Literal _ _ _     -> expr
  Constructor _ _ _ -> expr
  Variable spi a v
    | isQualified v -> expr
    | otherwise     ->
        maybe expr (Variable spi a . qualify) (lookup (a, unqualify v) rm)
  Typed spi e ty     -> Typed spi (go e) ty
  Apply spi e1 e2    -> Apply spi (go e1) (go e2)
  Let spi ds e       -> Let spi (map (renameDecl rm) ds) (go e)
  Case spi ct e alts -> Case spi ct (go e) (map (renameAlt rm) alts)
  _                  -> error "Lift.renameExpr"
  where
    go = renameExpr rm
-- | Apply a renaming to the right-hand side of a pattern declaration; all
-- other declarations are returned unchanged.
renameDecl :: Eq a => RenameMap a -> Decl a -> Decl a
renameDecl rm decl = case decl of
  PatternDecl p t rhs -> PatternDecl p t (renameRhs rm rhs)
  _                   -> decl
-- | Apply a renaming to the right-hand side of a case alternative; the
-- pattern is left as it is.
renameAlt :: Eq a => RenameMap a -> Alt a -> Alt a
renameAlt rm (Alt p t rhs) =
  let rhs' = renameRhs rm rhs
  in  Alt p t rhs'
-- | 'True' for function declarations and external declarations, 'False'
-- for every other declaration.
isFunDecl :: Decl a -> Bool
isFunDecl decl = case decl of
  FunctionDecl _ _ _ _ -> True
  ExternalDecl _ _     -> True
  _                    -> False
-- | Build a variable expression, without span information, that refers to
-- the lifted version of a function: the identifier is renamed with
-- 'liftIdent' and qualified with the given module.
mkFun :: ModuleIdent -> String -> a -> Ident -> Expression a
mkFun m pre a f = Variable NoSpanInfo a (qualifyWith m (liftIdent pre f))
-- | Name of a lifted function: the prefix followed by the shown form of
-- the original identifier, carrying over the original's unique number.
liftIdent :: String -> Ident -> Ident
liftIdent prefix x = renameIdent prefixed (idUnique x)
  where
    prefixed = mkIdent (prefix ++ showIdent x)
-- | Names of all variable patterns in a list of patterns, in order; other
-- patterns are skipped.
varPatNames :: [Pattern a] -> [Ident]
varPatNames pats = [ name | Just name <- map varPatName pats ]
-- | The identifier of a variable pattern, or 'Nothing' for any other
-- pattern.
varPatName :: Pattern a -> Maybe Ident
varPatName pat = case pat of
  VariablePattern _ _ i -> Just i
  _                     -> Nothing
-- | Placeholder type recognised by 'isDummyType'. Its body is 'undefined'
-- and must never be evaluated; only the outer shape (a 'TypeForall' with
-- an empty variable list) is inspected.
dummyType :: Type
dummyType = TypeForall [] neverInspected
  where
    neverInspected = undefined
-- | 'True' exactly for types of the form @TypeForall [] _@, which is the
-- shape of 'dummyType'. The body of the type is not inspected.
isDummyType :: Type -> Bool
isDummyType ty = case ty of
  TypeForall [] _ -> True
  _               -> False