{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
-- Undo pointfree transformations. Plugin code derived from Pl.hs.
module Lambdabot.Pointful (pointful) where

import Lambdabot.Parser (withParsed, prettyPrintInLine)

import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Generics
import qualified Data.Set as S
import qualified Data.Map as M
import Data.List
import Data.Maybe
import Language.Haskell.Exts.Simple as Hs

---- Utilities ----

-- Apply f repeatedly, starting from x, until one more application no
-- longer changes the value; that unchanged value is the result.
stabilize :: Eq a => (a -> a) -> a -> a
stabilize f = go
  where
    go cur
        | next == cur = cur
        | otherwise   = go next
      where
        next = f cur

-- varsBoundHere collects the names bound at the top of a syntax fragment:
-- a pattern variable, the name of a (prefix) function clause, or whatever
-- the pattern of a pattern binding binds.  Expressions are never entered,
-- so binders nested inside them are not reported; any other node is
-- searched child by child.
varsBoundHere :: Data d => d -> S.Set Name
varsBoundHere d
    | Just (PVar name)        <- cast d = S.singleton name
    | Just (Match name _ _ _) <- cast d = S.singleton name
    | Just (PatBind pat _ _)  <- cast d = varsBoundHere pat
    | Just (_ :: Exp)         <- cast d = S.empty
    | otherwise                         = S.unions (gmapQ varsBoundHere d)

-- note: the tempting idea of using a pattern synonym for the frequent
-- (cast -> Just _) patterns causes compiler crashes with ghc before
-- version 8; cf. https://ghc.haskell.org/trac/ghc/ticket/11336

-- Fold over the unqualified variable occurrences of a syntax fragment.
--
-- The first argument is applied to every occurrence together with the set
-- of names bound around it (by lambdas, lets, case alternatives, pattern
-- bindings and prefix function clauses); the second argument combines the
-- results obtained for the parts of a node.
foldFreeVars :: forall a d. Data d => (Name -> S.Set Name -> a) -> ([a] -> a) -> d -> a
foldFreeVars onVar combine root = runReader (walk root) S.empty
  where
    -- The reader environment is the set of names currently in scope.
    walk :: forall t. Data t => t -> Reader (S.Set Name) a
    walk (cast -> Just (Var (UnQual name))) =
        asks (onVar name)
    walk (cast -> Just (Lambda pats body)) =
        under [varsBoundHere pats] (walk body)
    walk (cast -> Just (Let binds body)) =
        under [varsBoundHere binds] (merge [walk binds, walk body])
    walk (cast -> Just (Alt pat rhs binds)) =
        under [varsBoundHere pat, varsBoundHere binds] (merge [walk rhs, walk binds])
    walk (cast -> Just (PatBind pat rhs binds)) =
        under [varsBoundHere pat, varsBoundHere binds] (merge [walk rhs, walk binds])
    walk (cast -> Just (Match _ pats rhs binds)) =
        under [varsBoundHere pats, varsBoundHere binds] (merge [walk rhs, walk binds])
    walk node = merge (gmapQ walk node)

    -- run all sub-traversals and combine their results
    merge :: [Reader (S.Set Name) a] -> Reader (S.Set Name) a
    merge = fmap combine . sequence

    -- run a traversal with additional names in scope
    under :: [S.Set Name] -> Reader (S.Set Name) a -> Reader (S.Set Name) a
    under scopes = local (S.union (S.unions scopes))

-- return the set of variables that occur without an enclosing binder
freeVars :: Data d => d -> S.Set Name
freeVars = foldFreeVars unbound S.unions
  where
    -- an occurrence contributes its name only when nothing around it binds it
    unbound name bound
        | name `S.member` bound = S.empty
        | otherwise             = S.singleton name

-- count how many times the given variable occurs free
countOcc :: Data d => Name -> d -> Int
countOcc target = foldFreeVars occurrence total
  where
    total = foldl' (+) 0
    -- an occurrence counts when it is the target and is not shadowed
    occurrence seen bound
        | seen == target && seen `S.notMember` bound = 1
        | otherwise                                  = 0

-- variable capture avoiding substitution
--
-- Replaces every unqualified variable that has an entry in subst by the
-- mapped expression.  bv is the set of names that local binders have to
-- stay clear of: whenever a binding construct (lambda, let, case
-- alternative, function clause, where block) is entered, its binders are
-- passed through renameBinds, and the substitution and name set returned
-- from there are the ones used inside that construct.
substAvoiding :: Data d => M.Map Name Exp -> S.Set Name -> d -> d
substAvoiding subst bv = base `extT` exp `extT` alt `extT` decl `extT` match where
    -- generic case: substitute in all immediate children
    base :: Data d => d -> d
    base = gmapT (substAvoiding subst bv)

    exp e@(Var (UnQual name)) =
        fromMaybe e (M.lookup name subst)
    exp (Lambda ps exp) =
        let (subst', bv', ps') = renameBinds subst bv ps
        in  Lambda ps' (substAvoiding subst' bv' exp)
    exp (Let bs exp) =
        let (subst', bv', bs') = renameBinds subst bv bs
        in  Let (substAvoiding subst' bv' bs') (substAvoiding subst' bv' exp)
    exp d = base d

    -- the pattern and the where-binders both scope over the alternative
    alt (Alt pat exp bs) =
        let (subst1, bv1, pat') = renameBinds subst bv pat
            (subst', bv', bs') = renameBinds subst1 bv1 bs
        in  Alt pat' (substAvoiding subst' bv' exp) (substAvoiding subst' bv' bs')

    -- the pattern itself is left alone here; only the where-binders are
    -- renamed at this point
    decl (PatBind pat exp bs) =
        let (subst', bv', bs') = renameBinds subst bv bs in
        PatBind pat (substAvoiding subst' bv' exp) (substAvoiding subst' bv' bs')
    decl d = base d

    match (Match name ps exp bs) =
        let (subst1, bv1, ps') = renameBinds subst bv ps
            (subst', bv', bs') = renameBinds subst1 bv1 bs
        in  Match name ps' (substAvoiding subst' bv' exp) (substAvoiding subst' bv' bs')
    -- infix clauses (x `f` y = ...) bind through both operand patterns;
    -- without this case such a clause is an unmatched pattern at runtime
    match (InfixMatch p name ps exp bs) =
        let (subst1, bv1, (p', ps')) = renameBinds subst bv (p, ps)
            (subst', bv', bs') = renameBinds subst1 bv1 bs
        in  InfixMatch p' name ps' (substAvoiding subst' bv' exp) (substAvoiding subst' bv' bs')

-- rename local binders (but not the nested expressions)
--
-- Walks a fragment and renames the names it binds directly (pattern
-- variables, function-clause names, the patterns of pattern bindings) so
-- that they do not clash with the names in bv.  Returns the adjusted
-- substitution, the extended name set and the renamed fragment; the
-- expressions inside the fragment are returned unchanged and still have to
-- be processed with the returned substitution.
renameBinds :: Data d => M.Map Name Exp -> S.Set Name -> d -> (M.Map Name Exp, S.Set Name, d)
renameBinds subst bv d = (subst', bv', d') where
    -- the third state component records the renamings chosen during this
    -- call, so that a name is renamed consistently; it is dropped at the end
    (d', (subst', bv', _)) = runState (go d) (subst, bv, M.empty)

    go, base :: Data d => d -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) d
    go = base `extM` pat `extM` match `extM` decl `extM` exp
    base d = gmapM go d

    pat (PVar name) = PVar `fmap` rename name
    pat d = base d

    -- only the clause name is bound at this level; the argument patterns,
    -- body and where-binders form a nested scope and are not touched
    match (Match name ps exp bs) = do
        name' <- rename name
        return $ Match name' ps exp bs
    -- same for infix clauses (x `f` y = ...); without this case such a
    -- clause is an unmatched pattern at runtime
    match (InfixMatch p name ps exp bs) = do
        name' <- rename name
        return $ InfixMatch p name' ps exp bs

    decl (PatBind pat exp bs) = do
        pat' <- go pat
        return $ PatBind pat' exp bs
    decl d = base d

    -- expressions are never entered
    exp (e :: Exp) = return e

    rename :: Name -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) Name
    rename name = do
        (subst, bv, ass) <- get
        case (name `M.lookup` ass, name `S.member` bv) of
            -- already renamed earlier in this call: reuse that choice
            (Just name', _) -> do
                 return name'
            -- no clash: keep the name, but it now shadows any substitution
            -- for it and must be avoided by later binders
            (_, False) -> do
                put (M.delete name subst, S.insert name bv, ass)
                return name
            -- clash: pick a fresh name and substitute it for the old one
            _ -> do
                let name' = freshNameAvoiding name bv
                put (M.insert name (Var (UnQual name')) subst,
                     S.insert name' bv, M.insert name name' ass)
                return name'

-- generate fresh names
--
-- The result keeps the kind of the given name (identifier or symbol).  Its
-- stem is the given name without trailing suffix characters (digits for
-- identifiers, '?' and '#' for symbols); suffixes over those characters are
-- tried shortest first until the resulting name is not forbidden.
freshNameAvoiding :: Name -> S.Set Name -> Name
freshNameAvoiding name forbidden =
    -- candidates is infinite and forbidden is finite, so head is safe here
    head (filter (`S.notMember` forbidden) candidates)
  where
    (build, text, alphabet) = case name of
        Ident  s -> (Ident,  s, "0123456789")
        Symbol s -> (Symbol, s, "?#")
    stem = dropWhileEnd (`elem` alphabet) text
    candidates =
        [ build (stem ++ suffix)
        | len <- [1 ..]
        , suffix <- sequence (replicate len alphabet)
        ]

---- Optimization (removing explicit lambdas) and restoration of infix ops ----

-- Turn a lambda on the right-hand side of a binding into argument patterns
-- on the left-hand side.  Only bindings with an unguarded right-hand side
-- and no where block are rewritten; everything else is returned unchanged.
optimizeD :: Decl -> Decl
optimizeD decl = case decl of
    -- move lambda patterns into LHS
    PatBind (PVar fname) (UnGuardedRhs (Lambda pats rhs)) Nothing ->
        absorb fname [] (S.singleton fname) pats rhs
    -- combine function binding and lambda
    FunBind [Match fname pats1 (UnGuardedRhs (Lambda pats2 rhs)) Nothing] ->
        absorb fname pats1 (varsBoundHere pats1) pats2 rhs
    _ -> decl
  where
    -- Append the lambda's patterns to the existing ones, renaming them away
    -- from the names in taken and adjusting the body accordingly.
    absorb fname outer taken inner rhs =
        let (subst, bv, inner') = renameBinds M.empty taken inner
            rhs' = substAvoiding subst bv rhs
        in  FunBind [Match fname (outer ++ inner') (UnGuardedRhs rhs') Nothing]

-- drop parentheses that wrap an entire unguarded right-hand side
optimizeRhs :: Rhs -> Rhs
optimizeRhs rhs = case rhs of
    UnGuardedRhs (Paren inner) -> UnGuardedRhs inner
    _                          -> rhs

-- One simplification step on a single expression node.  The equations are
-- tried from top to bottom and the first one that matches is applied; an
-- expression that matches none of them is returned unchanged.
optimizeE :: Exp -> Exp
-- apply ((\x z -> ...x...) y) yielding (\z -> ...y...) if there is only one x or y is simple
-- (the remaining lambda patterns are renamed away from the free variables of y)
optimizeE (App (Lambda (PVar ident : pats) body) arg) | single || simple arg =
     let (subst, bv, pats') = renameBinds (M.singleton ident arg) (freeVars arg) pats
     in  Paren (Lambda pats' (substAvoiding subst bv body))
  where
    -- x occurs free at most once in the body
    single = countOcc ident body <= 1
    -- variables and literals, possibly parenthesized
    simple e = case e of Var _ -> True; Lit _ -> True; Paren e' -> simple e'; _ -> False
-- apply ((\_ z -> ...) y) yielding (\z -> ...)
optimizeE (App (Lambda (PWildCard : pats) body) _) =
    Paren (Lambda pats body)
-- remove 0-arg lambdas resulting from application rules
optimizeE (Lambda [] b) =
    b
-- replace (\x -> \y -> z) with (\x y -> z)
-- (the inner patterns are renamed away from the names bound by the outer ones)
optimizeE (Lambda p1 (Lambda p2 body)) =
    let (subst, bv, p2') = renameBinds M.empty (varsBoundHere p1) p2
        body' = substAvoiding subst bv body
    in  Lambda (p1 ++ p2') body'
-- remove double parens
optimizeE (Paren (Paren x)) =
    Paren x
-- remove parens around applied lambdas (the pretty printer restores them)
optimizeE (App (Paren (x@Lambda{})) y) =
    App x y
-- remove lambda body parens
optimizeE (Lambda p (Paren x)) =
    Lambda p x
-- remove var, lit parens
optimizeE (Paren x@(Var _)) =
    x
optimizeE (Paren x@(Lit _)) =
    x
-- remove infix+lambda parens (lambda as the right operand)
optimizeE (InfixApp a o (Paren l@(Lambda _ _))) =
    InfixApp a o l
-- remove infix+app parens (application as either operand)
optimizeE (InfixApp (Paren a@App{}) o l) =
    InfixApp a o l
optimizeE (InfixApp a o (Paren l@App{})) =
    InfixApp a o l
-- remove left-assoc application parens
optimizeE (App (Paren (App a b)) c) =
    App (App a b) c
-- restore infix: a symbolic operator applied to two arguments
optimizeE (App (App (Var name'@(UnQual (Symbol _))) l) r) =
    (InfixApp l (QVarOp name') r)
-- eta reduce: (\ps v -> e v) becomes (\ps -> e) when v does not occur free in e
optimizeE (Lambda ps@(_:_) (App e (Var (UnQual v))))
    | free && last ps == PVar v = Lambda (init ps) e
  where free = countOcc v e == 0
-- fail
optimizeE x = x

---- Decombinatorization ----

-- One decombinatorization step on a single expression node: sections and
-- infix operator applications are brought into prefix/lambda form, and
-- (>>=) is expanded where it is obviously used in the reader monad.
-- Expressions that match none of the cases are returned unchanged.
uncomb' :: Exp -> Exp
uncomb' expr = case expr of
    -- collapse doubled parentheses
    Paren inner@(Paren _) -> inner

    -- eliminate sections
    RightSection op' arg ->
        let a = fresh "a" arg
        in  Paren (Lambda [PVar a] (InfixApp (var a) op' arg))
    LeftSection arg op' ->
        let a = fresh "a" arg
        in  Paren (Lambda [PVar a] (InfixApp arg op' (var a)))

    -- infix to prefix for canonicality
    InfixApp lf (QVarOp name') rf ->
        Paren (App (App (Var name') (Paren lf)) (Paren rf))

    -- rewrite: (>>=) (\x -> e)
    -- to:      (\ a b -> a ((\ x -> e) b) b)
    App (Var (UnQual (Symbol ">>="))) (Paren lam@(Lambda _ _)) ->
        let a = fresh "a" lam
            b = fresh "b" lam
            applied = Paren (App lam (var b))
        in  Paren (Lambda [PVar a, PVar b] (App (App (var a) applied) (var b)))

    -- rewrite: ((>>=) e1) (\x y -> e2)
    -- to:      (\a -> (\x y -> e2) (e1 a) a)
    App (App (Var (UnQual (Symbol ">>="))) e1) (Paren lam@(Lambda (_:_:_) _)) ->
        let a = fresh "a" [e1, lam]
        in  Paren (Lambda [PVar a] (App (App lam (App e1 (var a))) (var a)))

    -- fail
    _ -> expr
  where
    -- an unqualified variable reference
    var :: Name -> Exp
    var = Var . UnQual

    -- an identifier derived from stem that avoids the free variables of scope
    fresh :: Data d => String -> d -> Name
    fresh stem scope = freshNameAvoiding (Ident stem) (freeVars scope)

---- Simple combinator definitions ---
-- Combinator definitions parsed from combinatorModule, keyed by the name
-- they define; every body is wrapped in parentheses.  A parse failure or a
-- declaration that is not a simple `name = body` binding is a fatal error.
combinators :: M.Map Name Exp
combinators = M.fromList [toEntry decl | decl <- parsedDecls]
  where
    parsedDecls = case parseModule combinatorModule of
        ParseOk (Hs.Module _ _ _ decls) -> decls
        failure@(ParseFailed _ _)       -> error ("Combinator loading: " ++ show failure)

    toEntry (PatBind (PVar fname) (UnGuardedRhs body) Nothing) = (fname, Paren body)
    toEntry _ = error "Pointful Plugin error: can't convert declaration to tuple"

-- Source text of the combinator definitions that get unfolded.  Every
-- definition has the form `name = body`, because that is the only shape the
-- loader in `combinators` accepts.  (>>=) and return are deliberately left
-- undefined here; the commented lines at the end only record the
-- reader-monad reading that the rest of the module assumes for them.
--
-- (<=<) is right-to-left Kleisli composition, (f <=< g) x = g x >>= f.
combinatorModule :: String
combinatorModule = unlines [
  "(.)    = \\f g x -> f (g x)                                          ",
  "($)    = \\f x   -> f x                                              ",
  "flip   = \\f x y -> f y x                                            ",
  "const  = \\x _ -> x                                                  ",
  "id     = \\x -> x                                                    ",
  "(=<<)  = flip (>>=)                                                  ",
  "liftM2 = \\f m1 m2 -> m1 >>= \\x1 -> m2 >>= \\x2 -> return (f x1 x2) ",
  "join   = (>>= id)                                                    ",
  "ap     = liftM2 id                                                   ",
  "(>=>)  = flip (<=<)                                                  ",
  "(<=<)  = \\f g x -> g x >>= f                                        ",
  "                                                                     ",
  "-- ASSUMED reader monad                                              ",
  "-- (>>=)  = (\\f k r -> k (f r) r)                                   ",
  "-- return = const                                                    ",
  ""]

---- Top level ----

-- Replace every known combinator by its definition, keeping local binders
-- clear of the free variables of those definitions.
unfoldCombinators :: (Data a) => a -> a
unfoldCombinators term = substAvoiding combinators reserved term
  where
    reserved = freeVars combinators

-- apply the uncomb' rewrite once at every expression node
uncombOnce :: (Data a) => a -> a
uncombOnce = everywhere (mkT uncomb')
-- repeat uncombOnce until the term no longer changes
uncomb :: (Eq a, Data a) => a -> a
uncomb term = stabilize uncombOnce term

-- apply the declaration, right-hand-side and expression simplifications
-- once at every node they fit
optimizeOnce :: (Data a) => a -> a
optimizeOnce = everywhere rewrite
  where
    rewrite :: Data b => b -> b
    rewrite = mkT optimizeD `extT` optimizeRhs `extT` optimizeE
-- repeat optimizeOnce until the term no longer changes
optimize :: (Eq a, Data a) => a -> a
optimize term = stabilize optimizeOnce term

-- Entry point: the parsed input is first expanded (combinators unfolded and
-- decombinatorized until stable) and then simplified (optimized and
-- decombinatorized until stable); withParsed handles the text on both ends.
pointful :: String -> String
pointful = withParsed (simplify . expand)
  where
    expand, simplify :: (Eq a, Data a) => a -> a
    expand   = stabilize (unfoldCombinators . uncomb)
    simplify = stabilize (optimize . uncomb)

-- TODO: merge this into a proper test suite once one exists
-- test s = case parseModule s of
--   f@(ParseFailed _ _) -> fail (show f)
--   ParseOk (Hs.Module _ _ _ defs) ->
--     flip mapM_ defs $ \def -> do
--       putStrLn . prettyPrintInLine  $ def
--       putStrLn . prettyPrintInLine  . uncomb $ def
--       putStrLn . prettyPrintInLine  . optimize . uncomb $ def
--       putStrLn . prettyPrintInLine  . stabilize (optimize . uncomb) $ def
--       putStrLn ""
--
-- main = test "f = tail . head; g = head . tail; h = tail + tail; three = g . h . i; dontSub = (\\x -> x + x) 1; ofHead f = f . head; fm = flip mapM_ xs (\\x -> g x); po = (+1); op = (1+); g = (. f); stabilize = fix (ap . flip (ap . (flip =<< (if' .) . (==))) =<<)"
--