{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards     #-}
{-# LANGUAGE TupleSections     #-}

-- | Optimizing the outputted JavaScript(-ish) AST.

module Fay.Compiler.Optimizer where

import           Fay.Compiler.Prelude

import           Fay.Compiler.Misc
import           Fay.Types

import           Control.Monad.State             (State, modify, runState)
import           Control.Monad.Writer            (runWriter, tell)
import qualified Fay.Exts.NoAnnotation           as N
import           Language.Haskell.Exts hiding (app, name, op)

-- | A function name paired with its arity. Arity here is defined to be
-- the number of arguments that can be directly uncurried from a curried
-- lambda abstraction. So @\\x y z -> if x then (\\a -> a) else (\\a -> a)@
-- has an arity of 3, not 4.
type FuncArity = (N.QName,Int)

-- | The optimizer monad: a 'State' computation over 'OptState'.
type Optimize = State OptState

-- | State threaded through an optimizer run.
data OptState = OptState
  { optStmts   :: [JsStmt]
    -- ^ The statements the optimizer was started with (see 'runOptimizer').
  , optUncurry :: [N.QName]
    -- ^ Names of functions for which a call to the uncurried copy has been
    -- emitted. 'walkAndStripForces' prepends to this list, so it may
    -- contain duplicates.
  }

-- | Run an optimizer over the given statements.
--
-- The result is the optimizer's output followed by uncurried copies of
-- every binding the optimizer recorded in 'optUncurry' (duplicates
-- removed). Those copies are passed through 'tco' before being appended.
runOptimizer :: ([JsStmt] -> Optimize [JsStmt]) -> [JsStmt] -> [JsStmt]
runOptimizer optimizer stmts = optimized ++ tco uncurriedCopies
  where
    (optimized, finalState) = runState (optimizer stmts) (OptState stmts [])
    -- Bindings are looked up in the optimized output, not in the input.
    uncurriedCopies =
      mapMaybe (uncurryBinding optimized) (nub (optUncurry finalState))

-- | Inline x >> y to x;y in the JS output.
--
-- Every statement and expression is walked. An application for which
-- 'flatten' succeeds is replaced by the flattened result (the pieces of a
-- flattened chain are used as returned by 'flatten' and are not traversed
-- again); any other composite expression is rebuilt with its
-- sub-expressions processed recursively. Expressions without a dedicated
-- case are returned unchanged.
inlineMonad :: [JsStmt] -> [JsStmt]
inlineMonad = map go where
  -- Rebuild a statement, inlining inside each expression and nested
  -- statement list it carries.
  go stmt = case stmt of
    JsVar name exp          -> JsVar name (inline exp)
    JsIf exp stmts stmts'   -> JsIf (inline exp) (map go stmts) (map go stmts')
    JsEarlyReturn exp       -> JsEarlyReturn (inline exp)
    JsThrow exp             -> JsThrow (inline exp)
    JsWhile exp stmts       -> JsWhile (inline exp) (map go stmts)
    JsUpdate name exp       -> JsUpdate name (inline exp)
    JsSetProp a b exp       -> JsSetProp a b (inline exp)
    JsSetQName s a exp      -> JsSetQName s a (inline exp)
    JsSetModule a exp       -> JsSetModule a (inline exp)
    JsSetConstructor a exp  -> JsSetConstructor a (inline exp)
    JsSetPropExtern a b exp -> JsSetPropExtern a b (inline exp)
    JsContinue              -> JsContinue
    JsBlock stmts           -> JsBlock (map go stmts)
    JsExpStmt exp           -> JsExpStmt (inline exp)

  inline expr = case expr of
    -- Optimizations
    JsApp op args -> fromMaybe (JsApp (inline op) $ map inline args) (flatten expr)

    -- Plumbing
    JsFun nm names stmts mexp        -> JsFun nm names (map go stmts) (fmap inline mexp)

    JsNegApp exp                     -> JsNegApp (inline exp)
    JsTernaryIf exp1 exp2 exp3       -> JsTernaryIf (inline exp1) (inline exp2) (inline exp3)
    JsParen exp                      -> JsParen (inline exp)
    JsGetProp exp name               -> JsGetProp (inline exp) name
    JsLookup exp exp2                -> JsLookup (inline exp) (inline exp2)
    JsUpdateProp exp name exp2       -> JsUpdateProp (inline exp) name (inline exp2)
    JsGetPropExtern exp string       -> JsGetPropExtern (inline exp) string
    JsUpdatePropExtern exp name exp2 -> JsUpdatePropExtern (inline exp) name (inline exp2)
    JsList exps                      -> JsList (map inline exps)
    JsNew name exps                  -> JsNew name (map inline exps)
    JsThrowExp exp                   -> JsThrowExp (inline exp)
    JsInstanceOf exp name            -> JsInstanceOf (inline exp) name
    JsIndex i exp                    -> JsIndex i (inline exp)
    JsEq exp exp2                    -> JsEq (inline exp) (inline exp2)
    JsNeq exp exp2                   -> JsNeq (inline exp) (inline exp2)
    JsInfix string exp exp2          -> JsInfix string (inline exp) (inline exp2)
    -- Descend into the values of an object literal as well, so that chains
    -- nested inside them are inlined like everywhere else; keys are kept.
    JsObj keyvals                    -> JsObj [ (key, inline val) | (key, val) <- keyvals ]
    rest                             -> rest

-- | Flatten a chain of @then@ applications into a single sequence.
--
-- 'collect' yields the chain's steps with the final step first. When there
-- are at least two steps, the result is a 'thunk' around a 'JsSeq' holding
-- the steps in execution order, each one except the final wrapped with
-- 'force'. Anything shorter (or not collectable) gives 'Nothing'.
flatten :: JsExp -> Maybe JsExp
flatten exp = case collect exp of
  Just (final : earlier@(_:_)) ->
    Just . thunk . JsSeq $ map force (reverse earlier) ++ [final]
  _ -> Nothing

-- | Try to collect the steps of a nested @Fay$.then$uncurried@ chain.
--
-- The chain is followed through the first argument of each application,
-- and the steps are returned in reverse order (last step first). A
-- non-@then@ expression is a one-step chain; a @then@ application with
-- neither one nor two arguments yields 'Nothing'.
collect :: JsExp -> Maybe [JsExp]
collect (JsApp op args)
  | isThen op = case args of
      [rest, x] -> fmap (x :) (collect rest)
      [x]       -> Just [x]
      _         -> Nothing
  where
    -- Is this the qualified name Fay$.then$uncurried?
    isThen fn = case fn of
      JsName (JsNameVar (Qual _ (ModuleName _ modName) (Ident _ ident))) ->
        modName == "Fay$" && ident == "then$uncurried"
      _ -> False
collect other = Just [other]



-- | Entry point for the top-level, cross-module optimizations. At present
-- this is exactly 'stripAndUncurry', which also descends into nested
-- expressions.
optimizeToplevel :: [JsStmt] -> Optimize [JsStmt]
optimizeToplevel stmts = stripAndUncurry stmts

-- | Perform tail-call optimization.
--
-- Only 'JsVar' and 'JsSetQName' bindings are inspected, and only when the
-- bound value is a statement-less function returning a thunk around a
-- parameterless function. If that inner function returns a call to the
-- binding's own name, its body is wrapped in a @while (true)@ loop in which
-- each such return becomes parameter updates followed by 'JsContinue'.
-- Everything else is returned unchanged.
tco :: [JsStmt] -> [JsStmt]
tco = map rewriteBinding where
  rewriteBinding binding = case binding of
    JsVar name exp        -> JsVar name (inject name exp)
    JsSetQName l name exp -> JsSetQName l name (inject (JsNameVar name) exp)
    _                     -> binding

  -- Rewrite the thunked body of a matching function. The inner function's
  -- return expression, if any, is first turned into a trailing early
  -- return so that it is handled like every other return.
  inject self exp = case exp of
    JsFun nm params [] (Just (JsNew JsThunk [JsFun _ [] stmts ret])) ->
      let body  = stmts ++ maybe [] (\e -> [JsEarlyReturn e]) ret
          inner = JsFun Nothing [] (optimize params self body) Nothing
      in JsFun nm params [] (Just (JsNew JsThunk [inner]))
    _ -> exp

  -- Produce the looped body, but keep the statements as they are when no
  -- self tail call was found (the writer log records each one found).
  optimize params self stmts
    | null hits = stmts
    | otherwise = looped
    where
      (looped, hits) = runWriter $ do
        body <- swapAll stmts
        return [JsWhile (JsLit (JsBool True)) body]

      swapAll xs = concat <$> mapM swap xs

      swap stmt = case stmt of
        -- A return of a call to ourselves: rebind parameters and loop.
        -- NOTE(review): the updates are emitted one after another in
        -- parameter order; confirm that later arguments never depend on
        -- the previous value of an earlier parameter.
        JsEarlyReturn (JsApp (JsName callee) args)
          | callee == self -> do
              tell [()]
              return (zipWith (flip JsUpdate) args params ++ [JsContinue])
        -- Tail position extends into both branches of a conditional.
        JsIf cond onTrue onFalse -> do
          onTrue'  <- swapAll onTrue
          onFalse' <- swapAll onFalse
          return [JsIf cond onTrue' onFalse']
        _ -> return [stmt]

-- | Strip redundant forcing from the whole generated code.
--
-- A force applied directly to a name with a known arity is replaced by the
-- bare name. Applications are offered to 'walkAndStripForces'; when it
-- produces a rewritten call, that call is processed again, otherwise the
-- function and its arguments are processed individually. Other composite
-- expressions are rebuilt from their processed parts, and expressions
-- without a dedicated case are left alone.
stripAndUncurry :: [JsStmt] -> Optimize [JsStmt]
stripAndUncurry = applyToExpsInStmts stripFuncForces where
  stripFuncForces arities = strip where
    -- Statements inside a function body get the same treatment.
    nested = applyToExpsInStmt arities stripFuncForces

    strip exp = case exp of
      JsApp (JsName JsForce) [JsName (JsNameVar f)]
        | Just _ <- lookup f arities -> return (JsName (JsNameVar f))
      JsFun nm ps stmts body ->
        JsFun nm ps <$> mapM nested stmts
                    <*> maybe (return Nothing) (fmap Just . strip) body
      JsApp a b ->
        walkAndStripForces arities exp
          >>= maybe (JsApp <$> strip a <*> mapM strip b) strip
      JsNegApp e         -> fmap JsNegApp (strip e)
      JsTernaryIf a b c  -> JsTernaryIf <$> strip a <*> strip b <*> strip c
      JsParen e          -> fmap JsParen (strip e)
      JsUpdateProp e n a -> (\e' a' -> JsUpdateProp e' n a') <$> strip e <*> strip a
      JsList xs          -> fmap JsList (mapM strip xs)
      JsEq a b           -> JsEq <$> strip a <*> strip b
      JsInfix op a b     -> JsInfix op <$> strip a <*> strip b
      JsObj xs           -> fmap JsObj (mapM (\(key, val) -> (key,) <$> strip val) xs)
      JsNew name xs      -> fmap (JsNew name) (mapM strip xs)
      _                  -> return exp

-- | Strip redundant forcing from an application if possible.
--
-- The expression is peeled from the outside in: single-argument
-- applications contribute their argument, and forces are skipped. If the
-- innermost expression is a name whose known arity equals the number of
-- collected arguments, the name is recorded in 'optUncurry' and a direct
-- call to its uncurried copy is returned. A force at the very top of the
-- expression is kept around that call. In every other situation the result
-- is 'Nothing'.
walkAndStripForces :: [FuncArity] -> JsExp -> Optimize (Maybe JsExp)
walkAndStripForces arities = walk True [] where
  walk outermost args app = case app of
    JsApp (JsName JsForce) [inner]
      -- Only the outermost force survives, re-wrapped around the result.
      | outermost -> fmap (\ex -> JsApp (JsName JsForce) [ex]) <$> walk False args inner
      | otherwise -> walk False args inner
    JsApp fn [arg] -> walk False (arg : args) fn
    JsName (JsNameVar f)
      | Just arity <- lookup f arities
      , length args == arity -> do
          modify (\s -> s { optUncurry = f : optUncurry s })
          return (Just (JsApp (JsName (JsNameVar (renameUncurried f))) args))
    _ -> return Nothing

-- | Apply the given function to the top-level expressions of every
-- statement. The function receives the arities gathered by 'collectFuncs'
-- from the complete statement list.
applyToExpsInStmts :: ([FuncArity] -> JsExp -> Optimize JsExp) -> [JsStmt] -> Optimize [JsStmt]
applyToExpsInStmts f stmts = mapM applyOne stmts
  where
    arities  = collectFuncs stmts
    applyOne = applyToExpsInStmt arities f

-- | Apply the given function to the top-level expressions of a single
-- statement. Variable bindings, qualified-name assignments, early returns
-- and conditionals (condition plus both branches, recursively) are
-- handled; any other statement is returned untouched.
applyToExpsInStmt :: [FuncArity] -> ([FuncArity] -> JsExp -> Optimize JsExp) -> JsStmt -> Optimize JsStmt
applyToExpsInStmt funcs f = walk where
  onExp = f funcs

  walk (JsVar name exp)         = fmap (JsVar name) (onExp exp)
  walk (JsSetQName l name exp)  = fmap (JsSetQName l name) (onExp exp)
  walk (JsEarlyReturn exp)      = fmap JsEarlyReturn (onExp exp)
  walk (JsIf cond ithen ielse)  = JsIf <$> onExp cond
                                       <*> mapM walk ithen
                                       <*> mapM walk ielse
  walk other                    = pure other

-- | Collect functions and their arity from the whole codeset.
--
-- Every 'JsSetQName' binding with a positive 'expArity' is listed first,
-- followed by a fixed table of primitives in the @Fay$@ module.
collectFuncs :: [JsStmt] -> [FuncArity]
collectFuncs stmts = declared ++ prim
  where
    declared =
      [ (name, arity)
      | JsSetQName _ name exp <- stmts
      , let arity = expArity exp
      , arity > 0
      ]

    prim = [ (primName ident, 1) | ident <- ["return"] ]
        ++ [ (primName ident, 2) | ident <- binaryPrims ]

    primName ident = Qual () (ModuleName () "Fay$") (Ident () ident)

    -- "mult" is listed twice; a keyed lookup only ever sees the first.
    binaryPrims =
      ["then","bind","mult","mult","add","sub","div"
      ,"eq","neq","gt","lt","gte","lte","and","or"]

-- | Get the arity of an expression: the number of functions nested
-- through their return expressions, starting at the expression itself.
-- Anything that is not a function has arity 0.
expArity :: JsExp -> Int
expArity exp = case exp of
  JsFun _ _ _ returned -> 1 + maybe 0 expArity returned
  _                    -> 0

-- | Change foo(x)(y) to foo$uncurried(x,y).
--
-- Finds the first 'JsVar' or 'JsSetQName' binding of the given name and
-- returns a copy bound to the renamed ('renameUncurried') name, whose
-- leading chain of single-parameter, statement-less functions has been
-- merged into one multi-parameter function. 'Nothing' when no such
-- binding exists.
uncurryBinding :: [JsStmt] -> N.QName -> Maybe JsStmt
uncurryBinding stmts qname = listToMaybe (concatMap uncurried stmts)
  where
    uncurried (JsVar (JsNameVar name) body)
      | name == qname = [JsVar (JsNameVar (renameUncurried name)) (peel [] body)]
    uncurried (JsSetQName l name body)
      | name == qname = [JsSetQName l (renameUncurried name) (peel [] body)]
    uncurried _ = []

    -- Gather parameters (innermost first) until the shape stops matching,
    -- then wrap whatever remains in one function taking all of them.
    peel params (JsFun _ [param] [] (Just body)) = peel (param : params) body
    peel params inner = JsFun Nothing (reverse params) [] (Just inner)

-- | Name of the uncurried copy of a curried function: the identifier or
-- symbol part of a qualified or unqualified name gets the suffix
-- @$uncurried@. Any other kind of name is returned as is.
renameUncurried :: N.QName -> N.QName
renameUncurried qname
  | Qual _ m n <- qname = Qual () m (suffixed n)
  | UnQual _ n <- qname = UnQual () (suffixed n)
  | otherwise           = qname
  where
    suffixed (Ident _ base)  = Ident () (base ++ "$uncurried")
    suffixed (Symbol _ base) = Symbol () (base ++ "$uncurried")