module IRTS.Defunctionalise(module IRTS.Defunctionalise
, module IRTS.Lang
) where
import IRTS.Lang
import Idris.Core.TT
import Idris.Core.CaseTree
import Debug.Trace
import Data.Maybe
import Data.List
import Control.Monad
import Control.Monad.State
-- | Expressions after defunctionalisation, produced from 'LExp' by 'addApps'.
--
-- There is no lazy-application or force form here. 'addApps' encodes those
-- with generated constructors and with calls to the generated @EVAL@ /
-- @APPLY@ functions.
data DExp
    = DV LVar
      -- ^ variable reference
    | DApp Bool Name [DExp]
      -- ^ call of a named function or constructor. The flag is copied from
      -- 'LApp' and is False for generated calls.
      -- NOTE(review): presumably the tail-call flag; confirm.
    | DLet Name DExp DExp
      -- ^ let binding: name, bound value, body
    | DUpdate Name DExp
      -- ^ expression tagged with a variable name. Produced for projections
      -- from locals and in EVAL alternatives.
      -- NOTE(review): presumably the back end overwrites that variable with
      -- the result; confirm.
    | DProj DExp Int
      -- ^ projection of the given field index
    | DC (Maybe LVar) Int Name [DExp]
      -- ^ constructor application (optional location, tag, name, fields),
      -- copied from 'LCon'
    | DCase CaseType DExp [DAlt]
      -- ^ case analysis produced from 'LCase'
    | DChkCase DExp [DAlt]
      -- ^ case analysis built by 'mkBigCase' for the generated dispatchers
    | DConst Const
      -- ^ constant
    | DForeign FDesc FDesc [(FDesc, DExp)]
      -- ^ foreign call; the descriptors are copied from 'LForeign'
    | DOp PrimFn [DExp]
      -- ^ primitive operation
    | DNothing
      -- ^ placeholder value (e.g. APPLY's fall-through result)
    | DError String
      -- ^ runtime error with a message
    deriving Eq
-- | Alternatives of 'DCase' and 'DChkCase'.
data DAlt
    = DConCase Int Name [Name] DExp
      -- ^ constructor alternative: tag, constructor name, names bound to the
      -- constructor's fields, body
    | DConstCase Const DExp
      -- ^ alternative for one constant
    | DDefaultCase DExp
      -- ^ catch-all alternative
    deriving (Show, Eq)
-- | Top-level declarations after defunctionalisation.
data DDecl
    = DFun Name [Name] DExp
      -- ^ function: name, parameter names, body
    | DConstructor Name Int Int
      -- ^ constructor: name, tag, arity
    deriving (Show, Eq)

-- | A defunctionalised program: its declarations stored in a 'Ctxt' by name.
type DDefs = Ctxt DDecl
-- | Defunctionalise a whole lowered program.
--
-- Each declaration is rewritten by 'addApps', which also reports two things:
-- which functions occur in saturated lazy applications, and which partial
-- applications occur. From these, the alternatives of the generated @EVAL@,
-- @APPLY@ and @APPLY2@ functions and their constructors are built. The new
-- constructors are sorted by name and then numbered from @nexttag@ by
-- 'declare'.
--
-- The result holds EVAL, APPLY, APPLY2, the new constructor declarations and
-- all rewritten declarations.
defunctionalise :: Int -> LDefs -> DDefs
defunctionalise nexttag defs =
    addAlist (evalDecl : applyDecl : apply2Decl : conDecls ++ rewritten)
             emptyContext
  where
    decls     = toAlist defs
    fnArities = getFn decls
    (rewritten, (lazyFns, partialApps)) =
        runState (mapM (addApps defs) decls) ([], [])
    -- Sorted so that, for each function, its smallest supplied-argument count
    -- comes first; toConsA uses the first entry 'lookup' finds.
    partialApps' = sort (nub partialApps)
    lazyFns'     = nub lazyFns
    byName (a, _, _) (b, _, _) = compare a b
    evalCons   = sortBy byName (concatMap (toCons lazyFns') fnArities)
    applyCons  = sortBy byName (concatMap (toConsA partialApps') fnArities)
    evalDecl   = mkEval evalCons
    applyDecl  = mkApply applyCons
    apply2Decl = mkApply2 applyCons
    conDecls   = declare nexttag (evalCons ++ applyCons)
-- | Pair each function declaration's name with its number of declared
-- arguments. Constructor declarations are skipped.
getFn :: [(Name, LDecl)] -> [(Name, Int)]
getFn decls = [ (name, length params) | (name, LFun _ _ params _) <- decls ]
-- | Defunctionalise one top-level declaration.
--
-- Constructors are copied unchanged. A function body is rewritten by the
-- local 'aa', whose environment of locally bound names starts as the
-- function's parameters.
--
-- While rewriting, two things are recorded in the state for
-- 'defunctionalise':
--
--  * the first component collects functions that occur in a saturated lazy
--    application; they need an EVAL alternative.
--  * the second collects @(function, k)@ pairs. For a partial application
--    with @s@ arguments, every @k@ from @s@ up to the arity is recorded;
--    these pairs drive the APPLY alternatives.
addApps :: LDefs -> (Name, LDecl) -> State ([Name], [(Name, Int)]) (Name, DDecl)
addApps _ (n, LConstructor _ t a)
    = return (n, DConstructor n t a)
addApps defs (n, LFun _ _ args e)
    = do e' <- aa args e
         return (n, DFun n args e')
  where
    aa :: [Name] -> LExp -> State ([Name], [(Name, Int)]) DExp
    -- A global reference that is not a local is a zero-argument call.
    aa env (LV (Glob n))
        | n `elem` env = return $ DV (Glob n)
        | otherwise    = aa env (LApp False (LV (Glob n)) [])
    aa env (LApp tc (LV (Glob n)) args) = do
        args' <- mapM (aa env) args
        if n `elem` env
            -- A locally bound name shadows any global with the same name (as
            -- in the LV case above), so apply the local value via APPLY.
            then return $ chainAPPLY (DV (Glob n)) args'
            else case lookupCtxtExact n defs of
                Just (LConstructor _ _ _) -> return $ DApp tc n args'
                Just (LFun _ _ params _)  -> fixApply tc n args' (length params)
                Nothing                   -> return $ chainAPPLY (DV (Glob n)) args'
    aa env (LLazyApp n args) = do
        args' <- mapM (aa env) args
        case lookupCtxtExact n defs of
            Just (LConstructor _ _ _) -> return $ DApp False n args'
            Just (LFun _ _ params _)  -> fixLazyApply n args' (length params)
            Nothing                   -> return $ chainAPPLY (DV (Glob n)) args'
    -- Forcing a lazy application is just the direct application.
    aa env (LForce (LLazyApp n args)) = aa env (LApp False (LV (Glob n)) args)
    aa env (LForce e) = liftM eEVAL (aa env e)
    aa env (LLet n v sc) = liftM2 (DLet n) (aa env v) (aa (n : env) sc)
    aa env (LCon loc i n args) = liftM (DC loc i n) (mapM (aa env) args)
    -- A projection from a local is wrapped in DUpdate on that local.
    aa env (LProj t@(LV (Glob n)) i)
        | n `elem` env = do
            t' <- aa env t
            return $ DProj (DUpdate n t') i
    aa env (LProj t i) = do
        t' <- aa env t
        return $ DProj t' i
    aa env (LCase up e alts) =
        liftM2 (DCase up) (aa env e) (mapM (aaAlt env) alts)
    aa _ (LConst c) = return $ DConst c
    aa env (LForeign t n args) = liftM (DForeign t n) (mapM (aaF env) args)
    -- Covers every primitive, including LFork.
    aa env (LOp f args) = liftM (DOp f) (mapM (aa env) args)
    aa _ LNothing = return DNothing
    aa _ (LError e) = return $ DError e

    aaF env (t, e) = do
        e' <- aa env e
        return (t, e')

    -- Constructor alternatives bring their field names into scope.
    aaAlt env (LConCase i n args e) = liftM (DConCase i n args) (aa (args ++ env) e)
    aaAlt env (LConstCase c e)      = liftM (DConstCase c) (aa env e)
    aaAlt env (LDefaultCase e)      = liftM DDefaultCase (aa env e)

    -- Strict call of a known function of arity @ar@:
    --   * saturated: a direct call;
    --   * under-applied: the partial-application constructor for the
    --     missing argument count;
    --   * over-applied: call with the first @ar@ arguments, then APPLY the rest.
    fixApply tc n args ar = case compare nargs ar of
        EQ -> return $ DApp tc n args
        LT -> do
            recordPartial n nargs ar
            return $ DApp tc (mkUnderCon n (ar - nargs)) args
        GT -> return $ chainAPPLY (DApp tc n now) later
      where
        nargs        = length args
        (now, later) = splitAt ar args

    -- Lazy call of a known function of arity @ar@. When saturated, the call
    -- is delayed as the 'mkFnCon' constructor and recorded for EVAL.
    fixLazyApply n args ar = case compare nargs ar of
        EQ -> do
            modify $ \(ens, ans) -> (n : ens, ans)
            return $ DApp False (mkFnCon n) args
        LT -> do
            recordPartial n nargs ar
            return $ DApp False (mkUnderCon n (ar - nargs)) args
        -- NOTE(review): an over-saturated lazy application becomes an
        -- immediate call here, exactly like fixApply; confirm that dropping
        -- the laziness is intended.
        GT -> return $ chainAPPLY (DApp False n now) later
      where
        nargs        = length args
        (now, later) = splitAt ar args

    -- Record (n, k) for every supplied-argument count k from @supplied@ to
    -- @ar@ inclusive, so that APPLY cases exist for each intermediate stage.
    recordPartial n supplied ar =
        modify $ \(ens, ans) -> (ens, [ (n, k) | k <- [supplied .. ar] ] ++ ans)
-- | Apply @fn@ to each argument in turn through the generated @APPLY@
-- function, left to right: @chainAPPLY f [a, b]@ is @APPLY(APPLY(f, a), b)@.
-- With no arguments, @fn@ is returned unchanged.
chainAPPLY :: DExp -> [DExp] -> DExp
chainAPPLY fn args = case args of
    []         -> fn
    (a : rest) -> chainAPPLY (applyTo a) rest
  where
    applyTo a = DApp False (sMN 0 "APPLY") [fn, a]
-- | Wrap @t@ in a @let x = x@ binding for each name @x@ in the list that
-- 'needsEval' reports as projected from within @t@. The first such name
-- becomes the outermost let.
--
-- NOTE(review): there is no caller of this in the visible module section;
-- possibly dead code.
preEval :: [Name] -> DExp -> DExp
preEval names t = foldr wrap t names
  where
    wrap x inner
        | needsEval x t = DLet x (DV (Glob x)) inner
        | otherwise     = inner
-- | Whether @x@ is projected from inside an expression: does some
-- @DProj (DV (Glob x)) _@ occur where @x@ still refers to the outer binding?
--
-- A 'DLet' of @x@ shadows it in the let body. A constructor alternative whose
-- field names include @x@ shadows it in that alternative. The search
-- descends through every sub-expression, including the operands of 'DProj'
-- and 'DUpdate', so nested projections are found too.
needsEval :: Name -> DExp -> Bool
needsEval x = go
  where
    go (DApp _ _ args)     = any go args
    go (DC _ _ _ args)     = any go args
    go (DCase _ e alts)    = go e || any goAlt alts
    go (DChkCase e alts)   = go e || any goAlt alts
    go (DLet n v e)
        | n == x           = go v
        | otherwise        = go v || go e
    go (DForeign _ _ args) = any (go . snd) args
    go (DOp _ args)        = any go args
    go (DProj (DV (Glob y)) _) = y == x
    -- A projection of a compound expression may still project x inside it.
    go (DProj e _)         = go e
    go (DUpdate _ e)       = go e
    go _                   = False

    -- Alternatives share one helper. Fields named x shadow the outer x.
    goAlt (DConCase _ _ fields e) = x `notElem` fields && go e
    goAlt (DConstCase _ e)        = go e
    goAlt (DDefaultCase e)        = go e
-- | Wrap an expression in a call to the generated @EVAL@ function.
eEVAL :: DExp -> DExp
eEVAL x = DApp False evalFn [x]
  where
    evalFn = sMN 0 "EVAL"
-- | A generated alternative, labelled with the dispatch function it belongs
-- to.
data EvalApply a
    = EvalCase (Name -> a)
      -- ^ alternative for @EVAL@; receives the name of EVAL's argument
    | ApplyCase a
      -- ^ alternative for @APPLY@ (one more argument supplied)
    | Apply2Case a
      -- ^ alternative for @APPLY2@ (two more arguments supplied)
-- | The EVAL entry for function @n@ of arity @i@.
--
-- An entry is produced only when @n@ is in @ns@, the functions seen in
-- saturated lazy applications; otherwise the result is empty. The entry's
-- alternative matches the delayed-call constructor @'mkFnCon' n@, binds its
-- @i@ fields to the first 'genArgs' names, and calls @n@ on them. The call
-- is wrapped in 'DUpdate' on the argument name passed to the 'EvalCase'
-- function.
toCons :: [Name] -> (Name, Int) -> [(Name, Int, EvalApply DAlt)]
toCons ns (n, i)
    | n `notElem` ns = []
    | otherwise      = [(mkFnCon n, i, EvalCase evalAlt)]
  where
    fields = take i (genArgs 0)
    -- NOTE(review): tag 1 looks like a placeholder; the stray parentheses in
    -- the source hint it may once have been (-1). Confirm what back ends
    -- expect.
    evalAlt scrut =
        DConCase 1 (mkFnCon n) fields
                 (DUpdate scrut (DApp False n (map (DV . Glob) fields)))
-- | The APPLY / APPLY2 entries for function @n@ of arity @i@.
--
-- Entries are produced only if @n@ appears in the recorded partial
-- applications @ns@; otherwise the result is empty. 'lookup' takes the first
-- pair for @n@, so when @ns@ is sorted this is the smallest recorded count of
-- supplied arguments. 'mkApplyCase' then generates cases from that count up
-- to the arity.
toConsA :: [(Name, Int)] -> (Name, Int) -> [(Name, Int, EvalApply DAlt)]
toConsA ns (n, i) = maybe [] (\fewest -> mkApplyCase n fewest i) (lookup n ns)
-- | APPLY / APPLY2 entries for the partial applications of @fname@ (arity
-- @ar@) that have @n@, @n + 1@, …, @ar - 1@ arguments supplied.
--
-- For each count @k@ in that range, the under-applied constructor
-- @'mkUnderCon' fname (ar - k)@ (with @k@ fields) gets:
--
--  * an 'ApplyCase' that adds one argument (@arg@). It builds the
--    constructor for @k + 1@ supplied arguments; if that saturates the call,
--    'mkUnderCon' yields @fname@ itself, so this becomes a direct call.
--  * an 'Apply2Case' that adds two arguments (@arg0@, @arg1@). This is
--    emitted only when @k + 2 <= ar@.
--
-- The stop test is @n >= ar@ rather than equality, so a starting count
-- above the arity terminates instead of recursing forever.
mkApplyCase :: Name -> Int -> Int -> [(Name, Int, EvalApply DAlt)]
mkApplyCase fname n ar
    | n >= ar   = []
    | otherwise = applyOne : (applyTwo ++ mkApplyCase fname (n + 1) ar)
  where
    nm     = mkUnderCon fname (ar - n)
    fields = take n (genArgs 0)
    callWith extra missing =
        DApp False (mkUnderCon fname missing) (map (DV . Glob) (fields ++ extra))
    -- NOTE(review): tag 1 looks like a placeholder (possibly once (-1));
    -- confirm what back ends expect.
    applyOne = (nm, n, ApplyCase (DConCase 1 nm fields
                                     (callWith [sMN 0 "arg"] (ar - (n + 1)))))
    applyTwo
        | ar - (n + 2) >= 0 =
            [(nm, n, Apply2Case (DConCase 1 nm fields
                                    (callWith [sMN 0 "arg0", sMN 0 "arg1"]
                                              (ar - (n + 2)))))]
        | otherwise = []
-- | Build the generated @EVAL@ function of one argument @arg@.
--
-- Every 'EvalCase' entry is instantiated with the argument's name. A value
-- that matches none of them is returned unchanged by the default
-- alternative.
mkEval :: [(Name, Int, EvalApply DAlt)] -> (Name, DDecl)
mkEval xs = (evalName, DFun evalName [argName] body)
  where
    evalName = sMN 0 "EVAL"
    argName  = sMN 0 "arg"
    argVar   = DV (Glob argName)
    alts     = [ mkAlt argName | (_, _, EvalCase mkAlt) <- xs ]
    body     = mkBigCase evalName 256 argVar (alts ++ [DDefaultCase argVar])
-- | Build the generated @APPLY@ function, which applies closure @fn@ to one
-- argument @arg@.
--
-- It dispatches on @fn@ over the 'ApplyCase' entries; unmatched closures
-- yield 'DNothing'. With no entries at all, the body is just 'DNothing'.
mkApply :: [(Name, Int, EvalApply DAlt)] -> (Name, DDecl)
mkApply xs = (applyName, DFun applyName [fnName, sMN 0 "arg"] body)
  where
    applyName = sMN 0 "APPLY"
    fnName    = sMN 0 "fn"
    cases     = [ alt | (_, _, ApplyCase alt) <- xs ]
    body
        | null cases = DNothing
        | otherwise  = mkBigCase applyName 256 (DV (Glob fnName))
                                 (cases ++ [DDefaultCase DNothing])
-- | Build the generated @APPLY2@ function, which applies closure @fn@ to two
-- arguments @arg0@ and @arg1@.
--
-- Closures with an 'Apply2Case' alternative consume both arguments at once.
-- Anything else falls back to @APPLY(APPLY(fn, arg0), arg1)@. When there are
-- no 'Apply2Case' alternatives at all, the body is just that fallback, so
-- APPLY2 always agrees with two APPLY steps. This matters, for example, for
-- closures that are only ever missing a single argument.
mkApply2 :: [(Name, Int, EvalApply DAlt)] -> (Name, DDecl)
mkApply2 xs = (apply2Name, DFun apply2Name [fnName, arg0, arg1] body)
  where
    apply2Name = sMN 0 "APPLY2"
    applyName  = sMN 0 "APPLY"
    fnName     = sMN 0 "fn"
    arg0       = sMN 0 "arg0"
    arg1       = sMN 0 "arg1"
    var        = DV . Glob
    cases      = [ alt | (_, _, Apply2Case alt) <- xs ]
    -- One argument at a time through APPLY.
    fallback   = DApp False applyName
                     [DApp False applyName [var fnName, var arg0], var arg1]
    body
        | null cases = fallback
        | otherwise  = mkBigCase apply2Name 256 (var fnName)
                                 (cases ++ [DDefaultCase fallback])
-- | Declare the generated constructors in order. The first gets tag
-- @firstTag@ and each subsequent one the next tag; the arity is taken from
-- the entry.
declare :: Int -> [(Name, Int, EvalApply DAlt)] -> [(Name, DDecl)]
declare firstTag cons = zipWith mkDecl [firstTag ..] cons
  where
    mkDecl tag (n, ar, _) = (n, DConstructor n tag ar)
-- | Infinite supply of field names: @sMN k "P_c"@ for @k = i, i + 1, …@.
genArgs i = [ sMN k "P_c" | k <- [i ..] ]
-- | Name of the constructor that represents a delayed, saturated call of @n@
-- (built by the lazy-application rewrite and matched by EVAL).
mkFnCon n = sMN 0 (concat ["P_", show n])
-- | Name of the constructor for a partial application of @n@ with @missing@
-- arguments still to come. When nothing is missing this is @n@ itself, so
-- building it is a direct call.
mkUnderCon n missing
    | missing == 0 = n
    | otherwise    = sMN missing ("U_" ++ show n)
-- | Debug rendering of defunctionalised expressions (used by 'dumpDefuns').
--
-- @env@ holds the rendered names of enclosing 'DLet' binders, and
-- @DV (Loc i)@ is printed as the @i@-th entry. An index outside @env@ is
-- rendered as a marker instead of crashing.
instance Show DExp where
    show e = show' [] e
      where
        show' env (DV (Loc i)) = "var " ++ locName env i
        show' env (DV (Glob n)) = "GLOB " ++ show n
        show' env (DApp _ e args) = show e ++ "(" ++
                                    showSep ", " (map (show' env) args) ++ ")"
        show' env (DLet n v e) = "let " ++ show n ++ " = " ++ show' env v ++ " in " ++
                                 show' (env ++ [show n]) e
        show' env (DUpdate n e) = "!update " ++ show n ++ "(" ++ show' env e ++ ")"
        show' env (DC loc i n args) = atloc loc ++ "CON " ++ show n ++ "(" ++
                                      showSep ", " (map (show' env) args) ++ ")"
          where atloc Nothing = ""
                atloc (Just l) = "@" ++ show (LV l) ++ ":"
        -- Render the projected expression in the current environment so that
        -- locals inside it resolve against the enclosing lets.
        show' env (DProj t i) = show' env t ++ "!" ++ show i
        show' env (DCase up e alts) = "case" ++ update ++ show' env e ++ " of {\n\t" ++
                                      showSep "\n\t| " (map (showAlt env) alts)
          where update = case up of
                           Shared -> " "
                           Updatable -> "! "
        show' env (DChkCase e alts) = "case' " ++ show' env e ++ " of {\n\t" ++
                                      showSep "\n\t| " (map (showAlt env) alts)
        show' env (DConst c) = show c
        show' env (DForeign ty n args)
            = "foreign " ++ show n ++ "(" ++ showSep ", " (map (show' env . snd) args) ++ ")"
        show' env (DOp f args) = show f ++ "(" ++ showSep ", " (map (show' env) args) ++ ")"
        show' env (DError str) = "error " ++ show str
        show' env DNothing = "____"

        showAlt env (DConCase _ n args e)
            = show n ++ "(" ++ showSep ", " (map show args) ++ ") => "
              ++ show' env e
        showAlt env (DConstCase c e) = show c ++ " => " ++ show' env e
        showAlt env (DDefaultCase e) = "_ => " ++ show' env e

        -- Total replacement for env !! i.
        locName env i = case drop i env of
            (v : _) | i >= 0 -> v
            _                -> "<unbound " ++ show i ++ ">"
-- | Build the case expression that dispatches on @arg@ over @branches@.
--
-- The result is always a single 'DChkCase' over all branches. The name and
-- size-limit arguments are currently unused, because splitting very large
-- cases into groups is not implemented. They are kept so callers need not
-- change.
mkBigCase :: name -> Int -> DExp -> [DAlt] -> DExp
mkBigCase _name _limit arg branches = DChkCase arg branches
-- | Split alternatives into consecutive groups by tag. The input is assumed
-- to be sorted by tag.
--
-- A group starts at the first remaining alternative. It then takes each
-- following alternative whose integer tag (constructor tag or @I@ constant)
-- is below the starting tag plus @x@.
--
-- Alternatives without an integer tag (default cases, other constants) end
-- the current group and form a group of their own. Every group is
-- non-empty, so the recursion always makes progress, even for @x <= 0@.
groupsOf :: Int -> [DAlt] -> [[DAlt]]
groupsOf _ [] = []
groupsOf x (a : as) = (a : batch) : groupsOf x rest
  where
    (batch, rest) = case altTag a of
        Nothing -> ([], as)
        Just t  -> span (below (t + x)) as
    below limit alt = maybe False (< limit) (altTag alt)
    altTag (DConstCase (I i) _) = Just i
    altTag (DConCase t _ _ _)   = Just t
    altTag _                    = Nothing
-- | Render every declaration of a defunctionalised program for debugging,
-- separated by newlines.
--
-- Functions are shown with their parameters and body. Constructors are
-- shown with their name and tag.
dumpDefuns :: DDefs -> String
dumpDefuns ds = showSep "\n" [ render decl | (_, decl) <- toAlist ds ]
  where
    render (DFun fn params body) =
        show fn ++ "(" ++ showSep ", " (map show params) ++ ") = \n\t" ++
        show body ++ "\n"
    render (DConstructor n tag _) = "Constructor " ++ show n ++ " " ++ show tag