module Lambdabot.Pointful (pointful) where
import Lambdabot.Parser (withParsed, prettyPrintInLine)
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Generics
import qualified Data.Set as S
import qualified Data.Map as M
import Data.List
import Data.Maybe
import Language.Haskell.Exts.Simple as Hs
-- | Iterate @f@ starting from @x@ until one application leaves the value
-- unchanged, and return that fixed point. Does not terminate if the
-- iteration never reaches a fixed point.
stabilize :: Eq a => (a -> a) -> a -> a
stabilize f = loop
  where
    loop current
      | next == current = current
      | otherwise = loop next
      where
        next = f current
-- | Names bound at the top level of a syntax fragment.
--
-- * A variable pattern binds its own name.
-- * A function clause ('Match' or 'InfixMatch') binds only the function
--   name; its parameters belong to the clause's own scope.
-- * A pattern binding binds the variables of its pattern.
-- * Expressions are not entered, so binders nested inside them are not
--   reported.
-- * Any other node contributes the union over its immediate children.
varsBoundHere :: Data d => d -> S.Set Name
varsBoundHere (cast -> Just (PVar name)) = S.singleton name
varsBoundHere (cast -> Just (Match name _ _ _)) = S.singleton name
-- Infix clauses (`x <+> y = ...`) bind the operator name, exactly like
-- prefix clauses. Without this equation the generic case would report the
-- operand pattern variables instead, and miss the operator.
varsBoundHere (cast -> Just (InfixMatch _ name _ _ _)) = S.singleton name
varsBoundHere (cast -> Just (PatBind pat _ _)) = varsBoundHere pat
varsBoundHere (cast -> Just (_ :: Exp)) = S.empty
varsBoundHere d = S.unions (gmapQ varsBoundHere d)
-- | Generic fold over the unqualified variable occurrences of a syntax tree.
--
-- Each occurrence @Var (UnQual name)@ is mapped with @var name bound@.
-- Here @bound@ is the set of names bound at that point by enclosing
-- lambdas, lets, case alternatives, pattern bindings and function clauses
-- (computed with 'varsBoundHere'). The per-node results are combined with
-- @sum@. Whether an occurrence counts as free is decided by @var@, which
-- receives the bound set.
foldFreeVars :: forall a d. Data d => (Name -> S.Set Name -> a) -> ([a] -> a) -> d -> a
foldFreeVars var sum e = runReader (go e) S.empty
  where
    go :: forall d. Data d => d -> Reader (S.Set Name) a
    go (cast -> Just (Var (UnQual name))) =
        asks (var name)
    go (cast -> Just (Lambda ps exp)) =
        bind [varsBoundHere ps] $ go exp
    go (cast -> Just (Let bs exp)) =
        bind [varsBoundHere bs] $ collect [go bs, go exp]
    go (cast -> Just (Alt pat exp bs)) =
        bind [varsBoundHere pat, varsBoundHere bs] $ collect [go exp, go bs]
    go (cast -> Just (PatBind pat exp bs)) =
        bind [varsBoundHere pat, varsBoundHere bs] $ collect [go exp, go bs]
    go (cast -> Just (Match _ ps exp bs)) =
        bind [varsBoundHere ps, varsBoundHere bs] $ collect [go exp, go bs]
    -- Infix clauses (`x <+> y = ...`) bind their operand patterns exactly
    -- like prefix clauses. Without this equation they would fall through to
    -- the generic case, and their parameters would be reported as free.
    go (cast -> Just (InfixMatch p _ ps exp bs)) =
        bind [varsBoundHere p, varsBoundHere ps, varsBoundHere bs] $ collect [go exp, go bs]
    go d = collect (gmapQ go d)

    -- Run the child traversals and combine their results with @sum@.
    collect :: forall m. Monad m => [m a] -> m a
    collect ms = sum `liftM` sequence ms

    -- Extend the bound-name environment for the given sub-traversal.
    bind :: forall a b. Ord a => [S.Set a] -> Reader (S.Set a) b -> Reader (S.Set a) b
    bind ss = local (S.unions ss `S.union`)
-- | The set of unqualified variable names occurring free in a syntax tree:
-- every occurrence not bound by an enclosing binder, as tracked by
-- 'foldFreeVars'.
freeVars :: Data d => d -> S.Set Name
freeVars = foldFreeVars occurrence S.unions
  where
    occurrence name bound
      | name `S.member` bound = S.empty
      | otherwise = S.singleton name
-- | Number of free occurrences of the given name in a syntax tree.
-- Occurrences under a binder that rebinds the same name are not counted.
countOcc :: Data d => Name -> d -> Int
countOcc target = foldFreeVars hit (foldl' (+) 0)
  where
    -- 1 for an unshadowed occurrence of the target name, 0 otherwise.
    hit name bound = fromEnum (name == target && not (name `S.member` bound))
-- | Capture-avoiding substitution over any 'Data' value.
--
-- Every unqualified variable occurrence whose name is a key of @subst@ is
-- replaced by the mapped expression. @bv@ is the set of names a binder must
-- not keep. Whenever a lambda, let, case alternative, pattern binding or
-- function clause introduces a binder, 'renameBinds' handles it:
--
-- * a binder that clashes with @bv@ is renamed, and the substitution is
--   extended for its scope;
-- * a binder that merely shadows a key of @subst@ removes that key within
--   its scope.
--
-- Callers in this module pass the free variables of the substituted
-- expressions as @bv@.
substAvoiding :: Data d => M.Map Name Exp -> S.Set Name -> d -> d
substAvoiding subst bv = base `extT` exp `extT` alt `extT` decl `extT` match
  where
    -- Generic case: recurse into every immediate child.
    base :: Data d => d -> d
    base = gmapT (substAvoiding subst bv)

    exp e@(Var (UnQual name)) =
        fromMaybe e (M.lookup name subst)
    exp (Lambda ps body) =
        let (subst', bv', ps') = renameBinds subst bv ps
        in Lambda ps' (substAvoiding subst' bv' body)
    exp (Let bs body) =
        let (subst', bv', bs') = renameBinds subst bv bs
        in Let (substAvoiding subst' bv' bs') (substAvoiding subst' bv' body)
    exp d = base d

    -- Pattern variables, then where-bound names, are renamed; both renamings
    -- apply to the right-hand side and to the where bindings.
    alt (Alt pat rhs bs) =
        let (subst1, bv1, pat') = renameBinds subst bv pat
            (subst', bv', bs') = renameBinds subst1 bv1 bs
        in Alt pat' (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')

    -- Only the where bindings are renamed here. When the binding sits inside
    -- a let, its own pattern was already renamed by the let equation above.
    decl (PatBind pat rhs bs) =
        let (subst', bv', bs') = renameBinds subst bv bs
        in PatBind pat (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')
    decl d = base d

    match (Match name ps rhs bs) =
        let (subst1, bv1, ps') = renameBinds subst bv ps
            (subst', bv', bs') = renameBinds subst1 bv1 bs
        in Match name ps' (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')
    -- Infix clauses (`x <+> y = ...`) bind their operand patterns the same
    -- way prefix clauses do. Without this equation 'match' is non-exhaustive,
    -- and any input containing an infix function definition crashes.
    match (InfixMatch p name ps rhs bs) =
        let (subst1, bv1, (p', ps')) = renameBinds subst bv (p, ps)
            (subst', bv', bs') = renameBinds subst1 bv1 bs
        in InfixMatch p' name ps' (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')
-- | Rename the binders at the top level of a binding construct so that none
-- of them takes a name in @bv@.
--
-- Returns the updated substitution, the updated avoid-set and the renamed
-- value. Each binder name is handled as follows:
--
-- * already seen earlier in this call: reuse the name chosen then, so all
--   clauses of a multi-clause function keep the same name;
-- * not in @bv@: keep it, add it to @bv@, and drop it from @subst@ (it
--   shadows any substitution for that name);
-- * in @bv@: pick a fresh name with 'freshNameAvoiding', add that to @bv@,
--   and extend @subst@ so occurrences of the old name become the new one.
--
-- Binders are variable patterns, function-clause names ('Match' and
-- 'InfixMatch') and the patterns of pattern bindings. Traversal stops at
-- expressions and does not enter clause or binding right-hand sides.
renameBinds :: Data d => M.Map Name Exp -> S.Set Name -> d -> (M.Map Name Exp, S.Set Name, d)
renameBinds subst bv d = (subst', bv', d')
  where
    (d', (subst', bv', _)) = runState (go d) (subst, bv, M.empty)

    go, base :: Data d => d -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) d
    go = base `extM` pat `extM` match `extM` decl `extM` exp
    base d = gmapM go d

    pat (PVar name) = PVar `fmap` rename name
    pat d = base d

    -- Only the function name is a binder of the enclosing group; the
    -- clause's parameters are handled by 'substAvoiding' for its own scope.
    match (Match name ps rhs bs) = do
        name' <- rename name
        return $ Match name' ps rhs bs
    match (InfixMatch p name ps rhs bs) = do
        name' <- rename name
        return $ InfixMatch p name' ps rhs bs

    decl (PatBind p rhs bs) = do
        p' <- go p
        return $ PatBind p' rhs bs
    decl d = base d

    -- Expressions (e.g. inside view patterns) bind nothing here.
    exp (e :: Exp) = return e

    rename :: Name -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) Name
    rename name = do
        (subst, bv, ass) <- get
        case (name `M.lookup` ass, name `S.member` bv) of
            (Just name', _) ->
                return name'
            (_, False) -> do
                -- Record the identity renaming as well. Otherwise the next
                -- clause of the same function would find the name in @bv@
                -- and be renamed to a different, fresh name.
                put (M.delete name subst, S.insert name bv, M.insert name name ass)
                return name
            _ -> do
                let name' = freshNameAvoiding name bv
                put (M.insert name (Var (UnQual name')) subst,
                     S.insert name' bv, M.insert name name' ass)
                return name'
-- | Produce a name derived from @name@ that is not a member of @forbidden@.
--
-- The suffix alphabet is decimal digits for identifiers and @?@/@#@ for
-- operator symbols. Any trailing run of those characters is stripped from
-- @name@. Suffixes over the same alphabet are then tried shortest-first
-- (for identifiers: "0".."9", "00", "01", ...), and the first resulting
-- name outside @forbidden@ is returned. A suffix is always appended, even
-- when @name@ itself is not forbidden.
freshNameAvoiding :: Name -> S.Set Name -> Name
freshNameAvoiding name forbidden =
    head [candidate | candidate <- map (mk . (stem ++)) suffixes, not (candidate `S.member` forbidden)]
  where
    (mk, raw, alphabet) = case name of
        Ident n -> (Ident, n, "0123456789")
        Symbol n -> (Symbol, n, "?#")
    stem = reverse (dropWhile (`elem` alphabet) (reverse raw))
    suffixes = concatMap (\len -> replicateM len alphabet) [1 ..]
-- | Move lambda binders from a declaration's right-hand side into its
-- parameter list.
--
-- * @f = \\ps -> e@ becomes @f ps = e@.
-- * @f qs = \\ps -> e@ becomes @f qs ps = e@.
--
-- The absorbed binders are renamed away from the names already bound at
-- that position (@f@ in the first case, the bound names of @qs@ in the
-- second). Every other declaration is returned unchanged.
optimizeD :: Decl -> Decl
optimizeD decl = case decl of
    PatBind (PVar fname) (UnGuardedRhs (Lambda pats rhs)) Nothing ->
        absorb fname [] (S.singleton fname) pats rhs
    FunBind [Match fname pats1 (UnGuardedRhs (Lambda pats2 rhs)) Nothing] ->
        absorb fname pats1 (varsBoundHere pats1) pats2 rhs
    _ -> decl
  where
    -- Build the single-clause function, appending the renamed lambda
    -- binders after the existing parameters.
    absorb :: Name -> [Pat] -> S.Set Name -> [Pat] -> Exp -> Decl
    absorb fname outer avoid inner body =
        let (subst, bv, inner') = renameBinds M.empty avoid inner
        in FunBind [Match fname (outer ++ inner') (UnGuardedRhs (substAvoiding subst bv body)) Nothing]
-- | Drop a redundant outer pair of parentheses from an unguarded right-hand
-- side; any other right-hand side is returned unchanged.
optimizeRhs :: Rhs -> Rhs
optimizeRhs rhs = case rhs of
    UnGuardedRhs (Paren inner) -> UnGuardedRhs inner
    _ -> rhs
-- | A single local simplification step on expressions. 'optimizeOnce'
-- applies it bottom-up across a tree, and 'optimize' iterates that to a
-- fixed point. Equations are tried in order. An expression that matches
-- none of them, or whose guard fails, is returned unchanged by the last one.
optimizeE :: Exp -> Exp
-- Beta reduction: (\x ps -> body) arg  ==>  (\ps -> body[arg/x]).
-- Only done when x occurs at most once in body, or when arg is a variable
-- or literal (possibly parenthesised), so a non-trivial argument is never
-- duplicated. The remaining binders ps are renamed away from arg's free
-- variables so they cannot capture them.
optimizeE (App (Lambda (PVar ident : pats) body) arg) | single || simple arg =
    let (subst, bv, pats') = renameBinds (M.singleton ident arg) (freeVars arg) pats
    in Paren (Lambda pats' (substAvoiding subst bv body))
  where
    -- countOcc counts only free (unshadowed) occurrences of ident.
    single = countOcc ident body <= 1
    simple e = case e of Var _ -> True; Lit _ -> True; Paren e' -> simple e'; _ -> False
-- Applying a lambda whose first binder is a wildcard discards the argument.
optimizeE (App (Lambda (PWildCard : pats) body) _) =
    Paren (Lambda pats body)
-- A lambda with no binders left is just its body. Such lambdas come, for
-- example, from the beta rules above or from eta reduction below.
optimizeE (Lambda [] b) =
    b
-- Merge nested lambdas: \p1 -> \p2 -> body  ==>  \p1 p2 -> body.
-- Binders of p2 that clash with names bound by p1 are renamed first.
optimizeE (Lambda p1 (Lambda p2 body)) =
    let (subst, bv, p2') = renameBinds M.empty (varsBoundHere p1) p2
        body' = substAvoiding subst bv body
    in Lambda (p1 ++ p2') body'
-- Collapse doubled parentheses.
optimizeE (Paren (Paren x)) =
    Paren x
-- Unwrap a parenthesised lambda in function position. This exposes it to
-- the beta rules above on the next pass.
optimizeE (App (Paren (x@Lambda{})) y) =
    App x y
-- Drop parentheses around a lambda body.
optimizeE (Lambda p (Paren x)) =
    Lambda p x
-- Strip parentheses around a bare variable or literal.
optimizeE (Paren x@(Var _)) =
    x
optimizeE (Paren x@(Lit _)) =
    x
-- Drop parentheses around a lambda used as the right operand of an infix
-- operator.
optimizeE (InfixApp a o (Paren l@(Lambda _ _))) =
    InfixApp a o l
-- Drop parentheses around a prefix application used as either infix operand.
optimizeE (InfixApp (Paren a@App{}) o l) =
    InfixApp a o l
optimizeE (InfixApp a o (Paren l@App{})) =
    InfixApp a o l
-- Application is left-associative: ((a b) c)  ==>  a b c.
optimizeE (App (Paren (App a b)) c) =
    App (App a b) c
-- Restore infix form for an unqualified operator applied to two arguments:
-- (op) l r  ==>  l op r.
optimizeE (App (App (Var name'@(UnQual (Symbol _))) l) r) =
    (InfixApp l (QVarOp name') r)
-- Eta reduction: \ps v -> e v  ==>  \ps -> e, when v is the last binder
-- and does not occur free in e.
optimizeE (Lambda ps@(_:_) (App e (Var (UnQual v))))
    | free && last ps == PVar v = Lambda (init ps) e
  where free = countOcc v e == 0
optimizeE x = x
-- | A single rewrite step that turns point-free building blocks into
-- explicit lambdas or prefix applications:
--
-- * @((e))@ becomes @(e)@.
-- * Sections @(op arg)@ and @(arg op)@ become lambdas.
-- * An infix application of a variable operator becomes a prefix
--   application with parenthesised operands.
-- * Partial applications of @(>>=)@ are unfolded under a reader-monad
--   reading (see the ASSUMED note in 'combinatorModule').
--
-- Fresh binder names are chosen with 'freshNameAvoiding' so that they
-- avoid the free variables of the expressions they wrap. Any other
-- expression is returned unchanged.
uncomb' :: Exp -> Exp
uncomb' expr = case expr of
    Paren (Paren e) -> Paren e
    RightSection op' arg ->
        let a = freshA arg
        in Paren (Lambda [PVar a] (InfixApp (var a) op' arg))
    LeftSection arg op' ->
        let a = freshA arg
        in Paren (Lambda [PVar a] (InfixApp arg op' (var a)))
    InfixApp lf (QVarOp name') rf ->
        Paren (App (App (Var name') (Paren lf)) (Paren rf))
    -- (>>=) m  ==>  \a b -> a (m b) b
    App (Var (UnQual (Symbol ">>="))) (Paren lam@Lambda{}) ->
        let a = freshA lam
            b = freshNameAvoiding (Ident "b") (freeVars lam)
        in Paren (Lambda [PVar a, PVar b] (App (App (var a) (Paren (App lam (var b)))) (var b)))
    -- e1 >>= k  (k a lambda of 2+ binders)  ==>  \a -> k (e1 a) a
    App (App (Var (UnQual (Symbol ">>="))) e1) (Paren lam@(Lambda (_:_:_) _)) ->
        let a = freshNameAvoiding (Ident "a") (freeVars [e1, lam])
        in Paren (Lambda [PVar a] (App (App lam (App e1 (var a))) (var a)))
    _ -> expr
  where
    var = Var . UnQual
    freshA :: Exp -> Name
    freshA e = freshNameAvoiding (Ident "a") (freeVars e)
-- | Each combinator defined in 'combinatorModule', mapped from its name to
-- its parenthesised definition body.
--
-- Calls 'error' if the module text fails to parse, or if it contains a
-- declaration that is not a plain @name = body@ pattern binding without a
-- where clause.
combinators :: M.Map Name Exp
combinators = M.fromList (map toEntry decls)
  where
    decls = case parseModule combinatorModule of
        ParseOk (Hs.Module _ _ _ ds) -> ds
        failure@(ParseFailed _ _) -> error ("Combinator loading: " ++ show failure)
    toEntry (PatBind (PVar fname) (UnGuardedRhs body) Nothing) = (fname, Paren body)
    toEntry _ = error "Pointful Plugin error: can't convert declaration to tuple"
-- | Haskell source text defining, as lambdas, the combinators that
-- 'pointful' unfolds. It is parsed by 'combinators'. Every line must be a
-- plain @name = body@ binding, a blank line, or a comment.
--
-- Kleisli composition follows @(f <=< g) x = g x >>= f@. This line
-- previously read @f >>= g x@, which has the operands reversed and
-- mistyped.
combinatorModule :: String
combinatorModule = unlines [
  "(.) = \\f g x -> f (g x) ",
  "($) = \\f x -> f x ",
  "flip = \\f x y -> f y x ",
  "const = \\x _ -> x ",
  "id = \\x -> x ",
  "(=<<) = flip (>>=) ",
  "liftM2 = \\f m1 m2 -> m1 >>= \\x1 -> m2 >>= \\x2 -> return (f x1 x2) ",
  "join = (>>= id) ",
  "ap = liftM2 id ",
  "(>=>) = flip (<=<) ",
  "(<=<) = \\f g x -> g x >>= f ",
  " ",
  "-- ASSUMED reader monad ",
  "-- (>>=) = (\\f k r -> k (f r) r) ",
  "-- return = const ",
  ""]
-- | Replace every free occurrence of a known combinator (see 'combinators')
-- with its lambda definition. Binders are renamed where they would capture
-- a free variable of the combinator bodies.
unfoldCombinators :: (Data a) => a -> a
unfoldCombinators term = substAvoiding combinators avoided term
  where
    avoided = freeVars combinators
-- | One bottom-up pass of the @uncomb'@ rewrite over every expression node.
uncombOnce :: (Data a) => a -> a
uncombOnce = everywhere (mkT uncomb')
-- | Apply 'uncombOnce' repeatedly until the value stops changing.
uncomb :: (Eq a, Data a) => a -> a
uncomb term = stabilize uncombOnce term
-- | One bottom-up simplification pass. Declarations, right-hand sides and
-- expressions each get their own rewrite; all other nodes are left alone.
optimizeOnce :: (Data a) => a -> a
optimizeOnce = everywhere step
  where
    step :: Data b => b -> b
    step = mkT optimizeD `extT` optimizeRhs `extT` optimizeE
-- | Apply 'optimizeOnce' repeatedly until the value stops changing.
optimize :: (Eq a, Data a) => a -> a
optimize term = stabilize optimizeOnce term
-- | Rewrite point-free Haskell source text into pointful form.
--
-- The transformation runs in two phases, each iterated until the term
-- stops changing:
--
-- 1. The combinators from 'combinatorModule' are unfolded into lambdas,
--    interleaved with 'uncomb' (which expands sections and operator
--    applications).
-- 2. The result is simplified with 'optimize', again interleaved with
--    'uncomb'.
--
-- Parsing and printing are delegated to 'withParsed' from
-- "Lambdabot.Parser"; its exact behaviour is not visible here.
pointful :: String -> String
pointful = withParsed (\term -> stabilize (optimize . uncomb) (stabilize (unfoldCombinators . uncomb) term))