{-# LANGUAGE PatternGuards #-}

module IRTS.Defunctionalise(module IRTS.Defunctionalise,
                            module IRTS.Lang) where

import IRTS.Lang
import Idris.Core.TT
import Idris.Core.CaseTree

import Debug.Trace
import Data.Maybe
import Data.List
import Control.Monad
import Control.Monad.State

-- | Expressions of the defunctionalised intermediate language.  'addApps'
-- produces these from 'LExp' terms.
data DExp = DV LVar -- variable reference (local or global)
          | DApp Bool Name [DExp] -- True = tail call
          | DLet Name DExp DExp -- name just for pretty printing
          | DUpdate Name DExp -- eval expression, then update var with it
          | DProj DExp Int -- projection, printed as "e!i" (Int is presumably the field index)
          | DC (Maybe LVar) Int Name [DExp] -- constructor application, copied field-for-field from LCon
          | DCase CaseType DExp [DAlt] -- case on an expression; CaseType is Shared or Updatable
          | DChkCase DExp [DAlt] -- a case where the type is unknown (for EVAL/APPLY)
          | DConst Const -- constant
          | DForeign FDesc FDesc [(FDesc, DExp)] -- foreign call; each argument is paired with a descriptor
          | DOp PrimFn [DExp] -- primitive operator applied to arguments
          | DNothing -- erased value, can be compiled to anything since it'll never
                     -- be inspected
          | DError String -- error carrying a message
  deriving Eq

-- | Case alternatives of the defunctionalised language.
data DAlt = DConCase Int Name [Name] DExp -- tag, constructor name, bound argument names, body
                                          -- (alternatives generated for EVAL/APPLY use tag -1)
          | DConstCase Const DExp -- match on a constant
          | DDefaultCase DExp -- fallback alternative
  deriving (Show, Eq)

-- | Top-level declarations after defunctionalisation.
data DDecl = DFun Name [Name] DExp -- name, arg names, definition
           | DConstructor Name Int Int -- constructor name, tag, arity
  deriving (Show, Eq)

-- | A context mapping names to defunctionalised declarations.
type DDefs = Ctxt DDecl

-- | Defunctionalise a set of lambda-lifted definitions.
--
-- Every declaration is converted with 'addApps', which also reports which
-- functions are applied lazily and which are under-applied.  From that
-- report the constructors for EVAL and APPLY are generated, together with
-- the EVAL and APPLY functions themselves.  New constructors receive
-- consecutive tags starting at @nexttag@.
defunctionalise :: Int -> LDefs -> DDefs
defunctionalise nexttag defs
    = addAlist (evalFn : applyFn : conDecls ++ converted) emptyContext
  where
    decls = toAlist defs
    fns = getFn decls

    (converted, (lazyNames, underApps))
        = runState (mapM (addApps defs) decls) ([], [])

    -- sort the new constructors so that EVAL and APPLY cons get
    -- sequential tags
    evalCons = sortBy byName $ concatMap (toCons (nub lazyNames)) fns
    applyCons = sortBy byName $ concatMap (toConsA (sort (nub underApps))) fns

    evalFn = mkEval evalCons
    applyFn = mkApply applyCons
    conDecls = declare nexttag (evalCons ++ applyCons)

    byName (a, _, _) (b, _, _) = compare a b

-- | Pick out the function declarations, pairing each function name with
-- its arity (the number of declared argument names).  Constructors are
-- dropped.
getFn :: [(Name, LDecl)] -> [(Name, Int)]
getFn xs = [ (n, length args) | (n, LFun _ _ args _) <- xs ]

-- To defunctionalise:
--
-- 1 Create a data constructor for each function
-- 2 Create a data constructor for each underapplication of a function
-- 3 Convert underapplications to their corresponding constructors
-- 4 Create an EVAL function which calls the appropriate function for data constructors
--   created as part of step 1
-- 5 Create an APPLY function which adds an argument to each underapplication (or calls
--   APPLY again for an exact application)
-- 6 Wrap overapplications in chains of APPLY
-- 7 Wrap unknown applications (i.e. applications of local variables) in chains of APPLY
-- 8 Add explicit EVAL to case, primitives, and foreign calls

-- | Convert one lambda-lifted declaration to its defunctionalised form.
--
-- Constructor declarations are copied across.  In a function body every
-- application of a global name is compared with the callee's arity (looked
-- up in @defs@): exact applications become direct calls, under-applications
-- become @U_@ constructors, and surplus arguments are fed through APPLY one
-- at a time.
--
-- The state collects what has to be generated afterwards: the names of
-- functions that were applied lazily with exactly the right number of
-- arguments, and @(function, argument count)@ pairs recorded for
-- under-applications.
addApps :: LDefs -> (Name, LDecl) -> State ([Name], [(Name, Int)]) (Name, DDecl)
addApps _ (n, LConstructor _ t a)
    = return (n, DConstructor n t a)
addApps defs (n, LFun _ _ args e)
    = liftM (\body -> (n, DFun n args body)) (aa args e)
  where
    -- Translate an expression.  @env@ lists the names bound locally: the
    -- function's arguments plus enclosing let and case binders.
    aa :: [Name] -> LExp -> State ([Name], [(Name, Int)]) DExp
    aa env (LV (Glob v))
        | v `elem` env = return (DV (Glob v))
        -- a bare global name is a call with no arguments
        | otherwise = aa env (LApp False (LV (Glob v)) [])
    aa env (LApp tc (LV (Glob fn)) actuals)
        = do actuals' <- mapM (aa env) actuals
             case lookupCtxtExact fn defs of
                 Just (LConstructor _ _ _) -> return (DApp tc fn actuals')
                 Just (LFun _ _ formals _) ->
                     fixApply tc fn actuals' (length formals)
                 -- no definition found: apply it through APPLY
                 Nothing -> return (chainAPPLY (DV (Glob fn)) actuals')
    aa env (LLazyApp fn actuals)
        = do actuals' <- mapM (aa env) actuals
             case lookupCtxtExact fn defs of
                 Just (LConstructor _ _ _) -> return (DApp False fn actuals')
                 Just (LFun _ _ formals _) ->
                     fixLazyApply fn actuals' (length formals)
                 Nothing -> return (chainAPPLY (DV (Glob fn)) actuals')
    -- forcing a lazy application directly is just an ordinary call
    aa env (LForce (LLazyApp fn actuals))
        = aa env (LApp False (LV (Glob fn)) actuals)
    aa env (LForce body)
        = do body' <- aa env body
             return (eEVAL body')
    aa env (LLet x val scope)
        = do val' <- aa env val
             scope' <- aa (x : env) scope
             return (DLet x val' scope')
    aa env (LCon loc tag con fields)
        = do fields' <- mapM (aa env) fields
             return (DC loc tag con fields')
    aa env (LProj t i)
        = do t' <- aa env t
             return (DProj (markUpdate t') i)
      where
        -- projecting from a locally bound name wraps it in DUpdate
        markUpdate = case t of
                       LV (Glob v) | v `elem` env -> DUpdate v
                       _ -> id
    aa env (LCase up scrut alts)
        = liftM2 (DCase up) (aa env scrut) (mapM (aaAlt env) alts)
    aa _ (LConst c) = return (DConst c)
    aa env (LForeign ty fdesc fargs)
        = liftM (DForeign ty fdesc) (mapM (aaF env) fargs)
    aa env (LOp op operands)
        = liftM (DOp op) (mapM (aa env) operands)
    aa _ LNothing = return DNothing
    aa _ (LError msg) = return (DError msg)

    -- Translate a foreign-call argument, keeping its descriptor.
    aaF env (desc, arg) = liftM ((,) desc) (aa env arg)

    -- Translate a case alternative; constructor binders join the
    -- environment.
    aaAlt env (LConCase tag con binders body)
         = do body' <- aa (binders ++ env) body
              return (DConCase tag con binders body')
    aaAlt env (LConstCase c body)
         = do body' <- aa env body
              return (DConstCase c body')
    aaAlt env (LDefaultCase body)
         = do body' <- aa env body
              return (DDefaultCase body')

    -- Resolve a strict application of @fn@ against its arity.
    fixApply tc fn actuals arity
        = case compare given arity of
            EQ -> return (DApp tc fn actuals)
            LT -> do recordUnder fn given arity
                     return (DApp tc (mkUnderCon fn (arity - given)) actuals)
            GT -> return (chainAPPLY (DApp tc fn (take arity actuals))
                                     (drop arity actuals))
      where given = length actuals

    -- Resolve a lazy application of @fn@ against its arity.  An exact
    -- application becomes a @P_@ constructor and @fn@ is recorded in the
    -- first state component.
    fixLazyApply fn actuals arity
        = case compare given arity of
            EQ -> do (ens, ans) <- get
                     put (fn : ens, ans)
                     return (DApp False (mkFnCon fn) actuals)
            LT -> do recordUnder fn given arity
                     return (DApp False (mkUnderCon fn (arity - given)) actuals)
            GT -> return (chainAPPLY (DApp False fn (take arity actuals))
                                     (drop arity actuals))
      where given = length actuals

    -- Record @(fn, k)@ for every k from the number of arguments supplied
    -- up to the arity.
    recordUnder fn given arity
        = do (ens, ans) <- get
             put (ens, [ (fn, k) | k <- [given .. arity] ] ++ ans)

    -- Apply @f@ to the arguments one at a time through APPLY.
    chainAPPLY f [] = f
    chainAPPLY f (arg : rest)
        = chainAPPLY (DApp False (sMN 0 "APPLY") [f, arg]) rest

    -- if anything in the DExp is projected from, we'll need to evaluate it,
    -- but we only want to do it once, rather than every time we project.

    preEval xs t = foldr bindIfNeeded t xs
      where bindIfNeeded x rest
               | needsEval x t = DLet x (DV (Glob x)) rest
               | otherwise = rest

    -- Is the global name @x@ projected from anywhere in the expression?
    needsEval x (DApp _ _ as) = any (needsEval x) as
    needsEval x (DC _ _ _ as) = any (needsEval x) as
    needsEval x (DCase _ scrut alts)
          = needsEval x scrut || any (altNeedsEval x) alts
    needsEval x (DChkCase scrut alts)
          = needsEval x scrut || any (altNeedsEval x) alts
    needsEval x (DLet bound val scope)
          | x == bound = needsEval x val
          | otherwise = needsEval x val || needsEval x scope
    needsEval x (DForeign _ _ as) = any (needsEval x . snd) as
    needsEval x (DOp _ as) = any (needsEval x) as
    needsEval x (DProj (DV (Glob x')) _) = x == x'
    needsEval _ _ = False

    altNeedsEval x (DConCase _ _ _ body) = needsEval x body
    altNeedsEval x (DConstCase _ body) = needsEval x body
    altNeedsEval x (DDefaultCase body) = needsEval x body

-- | Wrap an expression in a (non-tail) call to the generated EVAL function.
eEVAL :: DExp -> DExp
eEVAL e = DApp False evalName [e]
  where evalName = sMN 0 "EVAL"

-- | Marks whether a generated constructor is handled by EVAL or by APPLY.
data EvalApply a = EvalCase (Name -> a) -- EVAL alternative, built from the name of EVAL's argument
                 | ApplyCase a -- APPLY alternative

-- For a function name, generate a list of
-- data constructors, and whether to handle them in EVAL or APPLY

-- | Build the EVAL constructor for a function, if the function is among
-- the names in @ns@.  The alternative matches the @P_@ constructor, calls
-- the function on the constructor's fields and updates EVAL's argument
-- with the result.  Functions not in @ns@ yield nothing.
toCons :: [Name] -> (Name, Int) -> [(Name, Int, EvalApply DAlt)]
toCons ns (n, i)
    | n `elem` ns = [(con, i, EvalCase evalAlt)]
    | otherwise = []
  where
    con = mkFnCon n
    fields = take i (genArgs 0)
    evalAlt tlarg
        = DConCase (-1) con fields
            (DUpdate tlarg (DApp False n (map (DV . Glob) fields)))

-- | Build the APPLY constructors for a function.  If the function has an
-- entry in @ns@, the first such entry gives the argument count to start
-- from and 'mkApplyCase' generates the alternatives up to the arity @i@.
-- Functions without an entry yield nothing.
toConsA :: [(Name, Int)] -> (Name, Int) -> [(Name, Int, EvalApply DAlt)]
toConsA ns (n, i) = maybe [] fromCount (lookup n ns)
  where fromCount ar = mkApplyCase n ar i

-- Generate the APPLY alternatives for @fname@, one for each argument count
-- from @n@ up to (but not including) the arity @ar@.  Each alternative
-- matches the constructor holding that many arguments and rebuilds the
-- application with APPLY's "arg" appended; once no argument is missing,
-- 'mkUnderCon' gives back the function itself.
mkApplyCase fname n ar
    | n == ar = []
    | otherwise = (con, n, ApplyCase alt) : mkApplyCase fname (n + 1) ar
  where
    con = mkUnderCon fname (ar - n)
    held = take n (genArgs 0)
    alt = DConCase (-1) con held
            (DApp False (mkUnderCon fname (ar - (n + 1)))
                  (map (DV . Glob) (held ++ [sMN 0 "arg"])))

-- | Build the EVAL function: a case on its single argument with one
-- alternative per 'EvalCase' entry.  Anything else is returned unchanged
-- by the default alternative.
mkEval :: [(Name, Int, EvalApply DAlt)] -> (Name, DDecl)
mkEval xs = (evalName, DFun evalName [argName] body)
  where
    evalName = sMN 0 "EVAL"
    argName = sMN 0 "arg"
    scrutinee = DV (Glob argName)
    evalAlts = [ mk argName | (_, _, EvalCase mk) <- xs ]
    body = mkBigCase evalName 256 scrutinee
               (evalAlts ++ [DDefaultCase scrutinee])

-- | Build the APPLY function of two arguments, "fn" and "arg": a case on
-- "fn" with one alternative per 'ApplyCase' entry and a default of
-- 'DNothing'.  When there are no alternatives at all the body is just
-- 'DNothing'.
mkApply :: [(Name, Int, EvalApply DAlt)] -> (Name, DDecl)
mkApply xs = (applyName, DFun applyName [fnName, sMN 0 "arg"] body)
  where
    applyName = sMN 0 "APPLY"
    fnName = sMN 0 "fn"
    applyAlts = [ alt | (_, _, ApplyCase alt) <- xs ]
    body
      | null applyAlts = DNothing
      | otherwise = mkBigCase applyName 256 (DV (Glob fnName))
                        (applyAlts ++ [DDefaultCase DNothing])


-- | Declare a constructor for each entry, assigning consecutive tags
-- starting from @t@ in list order.
declare :: Int -> [(Name, Int, EvalApply DAlt)] -> [(Name, DDecl)]
declare t xs = zipWith mkCon [t ..] xs
  where mkCon tag (n, ar, _) = (n, DConstructor n tag ar)


-- Infinite supply of machine-generated argument names "P_c", numbered
-- upwards from @i@.
genArgs i = map (\k -> sMN k "P_c") (iterate (+ 1) i)

-- Name of the constructor standing for a lazy application of @n@:
-- the shown name prefixed with "P_".
mkFnCon n = sMN 0 $ "P_" ++ show n
-- Name of the constructor standing for @n@ with @missing@ arguments still
-- to come.  With nothing missing it is the function's own name.
mkUnderCon n missing
    | missing == 0 = n
    | otherwise = sMN missing ("U_" ++ show n)

-- | Pretty-printer for debugging output (used by 'dumpDefuns').
--
-- @show'@ threads an environment of printed names so that 'Loc' variables
-- can be shown by the name of the enclosing let binder; each 'DLet'
-- appends its binder to that environment.
instance Show DExp where
   show e = show' [] e where
     show' env (DV (Loc i)) = "var " ++ env!!i
     show' env (DV (Glob n)) = "GLOB " ++ show n
     show' env (DApp _ e args) = show e ++ "(" ++
                                   showSep ", " (map (show' env) args) ++")"
     show' env (DLet n v e) = "let " ++ show n ++ " = " ++ show' env v ++ " in " ++
                               show' (env ++ [show n]) e
     show' env (DUpdate n e) = "!update " ++ show n ++ "(" ++ show' env e ++ ")"
     show' env (DC loc i n args) = atloc loc ++ "CON " ++ show n ++ "(" ++ showSep ", " (map (show' env) args) ++ ")"
       where atloc Nothing = ""
             atloc (Just l) = "@" ++ show (LV l) ++ ":"
     -- Keep the current environment when printing the projected expression;
     -- a plain 'show' would restart with an empty one and fail on any
     -- 'Loc' variable inside it.
     show' env (DProj t i) = show' env t ++ "!" ++ show i
     show' env (DCase up e alts) = "case" ++ update ++ show' env e ++ " of {\n\t" ++
                                    showSep "\n\t| " (map (showAlt env) alts)
         where update = case up of
                           Shared -> " "
                           Updatable -> "! "
     show' env (DChkCase e alts) = "case' " ++ show' env e ++ " of {\n\t" ++
                                    showSep "\n\t| " (map (showAlt env) alts)
     show' env (DConst c) = show c
     show' env (DForeign ty n args)
           = "foreign " ++ show n ++ "(" ++ showSep ", " (map (show' env) (map snd args)) ++ ")"
     show' env (DOp f args) = show f ++ "(" ++ showSep ", " (map (show' env) args) ++ ")"
     show' env (DError str) = "error " ++ show str
     show' env DNothing = "____"

     -- Case alternatives; binders are printed by name and do not extend
     -- the environment.
     showAlt env (DConCase _ n args e)
          = show n ++ "(" ++ showSep ", " (map show args) ++ ") => "
             ++ show' env e
     showAlt env (DConstCase c e) = show c ++ " => " ++ show' env e
     showAlt env (DDefaultCase e) = "_ => " ++ show' env e

-- Divide up a large case expression so that each has a maximum of
-- 'max' branches

-- Build a case expression on @arg@ with the given branches.
--
-- NOTE: the splitting of oversized cases is not implemented yet.  Whatever
-- the number of branches, the result is a single 'DChkCase' over all of
-- them; @cn@ is unused.  The bindings in the second guard are scaffolding
-- for the intended algorithm and are never forced.
mkBigCase cn max arg branches
   | length branches <= max = DChkCase arg branches
   | otherwise = -- DChkCase arg branches -- until I think of something...
       -- divide the branches into groups of at most max (by tag),
       -- generate a new case and shrink, recursively
       let bs = sortBy tagOrd branches
           -- Split off a trailing default alternative.  This takes its
           -- input from 'bs'; the binding must not refer to itself.
           (tagged, def) = case reverse bs of
                    (DDefaultCase t : rest) -> (reverse rest, Just (DDefaultCase t))
                    _ -> (bs, Nothing)
           bss = groupsOf max tagged
           cs = map mkCase bss in
           DChkCase arg branches

    where mkCase bs = DChkCase arg bs

          -- Order alternatives: constructor cases by tag, then constant
          -- cases by value, then defaults.
          tagOrd (DConCase t _ _ _) (DConCase t' _ _ _) = compare t t'
          tagOrd (DConstCase c _) (DConstCase c' _) = compare c c'
          tagOrd (DDefaultCase _) (DDefaultCase _) = EQ

          tagOrd (DConCase _ _ _ _) (DDefaultCase _) = LT
          tagOrd (DConCase _ _ _ _) (DConstCase _ _) = LT
          tagOrd (DConstCase _ _) (DDefaultCase _) = LT

          tagOrd (DDefaultCase _) (DConCase _ _ _ _) = GT
          tagOrd (DConstCase _ _) (DConCase _ _ _ _) = GT
          tagOrd (DDefaultCase _) (DConstCase _ _) = GT

-- | Split alternatives (assumed sorted by ascending tag) into consecutive
-- groups, each covering tags in the half-open window
-- @[tag of first element, tag of first element + x)@.
--
-- An alternative with no integer tag (a default case, or a constant case
-- on a non-'I' constant) forms a group of its own.  Every group holds at
-- least one alternative, so the function terminates on any finite input
-- and the concatenation of the groups is the input list.
groupsOf :: Int -> [DAlt] -> [[DAlt]]
groupsOf _ [] = []
groupsOf x xs@(hd : tl)
    = case span inWindow xs of
        -- Nothing fits the window (untagged head, or x <= 0): emit the head
        -- alone so that progress is always made.
        ([], _) -> [hd] : groupsOf x tl
        (batch, rest) -> batch : groupsOf x rest
  where
    -- exclusive upper bound on the tags allowed in this group
    limit = fmap (+ x) (tagOf hd)

    inWindow alt = case (tagOf alt, limit) of
                     (Just j, Just l) -> j < l
                     _ -> False

    tagOf (DConstCase (I i) _) = Just i
    tagOf (DConCase t _ _ _) = Just t
    tagOf _ = Nothing

-- | Render all defunctionalised declarations as text, one per entry,
-- separated by newlines.  Functions are shown with their argument names
-- and body; constructors with their name and tag.
dumpDefuns :: DDefs -> String
dumpDefuns ds = showSep "\n" [ showDef d | (_, d) <- toAlist ds ]
  where
    showDef (DFun fn args body)
        = concat [ show fn, "(", showSep ", " (map show args), ") = \n\t"
                 , show body, "\n" ]
    showDef (DConstructor n t _)
        = concat [ "Constructor ", show n, " ", show t ]