{-# LANGUAGE PatternSynonyms #-}
module Lambdabot.Plugin.Haskell.Pretty (prettyPlugin) where
import Lambdabot.Plugin
import Data.List
import qualified Language.Haskell.Exts.Simple as Hs
import Language.Haskell.Exts.Simple hiding (Module, Pretty)
-- | The monad in which this plugin's command handlers run: 'ModuleT'
-- applied to @()@ and 'LB'.
-- NOTE(review): the @()@ is presumably the plugin's (empty) private state,
-- matching @Module ()@ in the type of 'prettyPlugin' -- confirm against 'ModuleT'.
type Pretty = ModuleT () LB
-- | The plugin definition: 'newModule' with a single registered command,
-- @pretty@, whose handler is 'prettyCmd'.
prettyPlugin :: Module ()
prettyPlugin = newModule { moduleCmds = return [prettyCommand] }
  where
    -- The one command this plugin provides, with its help text and handler.
    prettyCommand = (command "pretty")
        { help    = say "pretty <expr>. Display haskell code in a pretty-printed manner"
        , process = prettyCmd
        }
-- | Handler for the @pretty@ command.
--
-- Leading spaces, tabs and @>@ characters are dropped from the argument.
-- The remainder is first parsed as the body of a module (declarations); if
-- that fails it is parsed again as an expression bound to the placeholder
-- name @__expr__@. The first successful parse is rendered with 'doPretty'
-- and each resulting line is sent with 'say'.
--
-- When both parses fail, the error of the declaration parse is reported,
-- with its column reduced by the length of the synthetic module header.
prettyCmd :: String -> Cmd Pretty ()
prettyCmd rest = mapM_ say outputLines
  where
    source     = dropWhile (`elem` " \t>") rest
    declHeader = "module Main where "
    exprHeader = "module Main where __expr__ = "

    -- Both attempts are lazy; the expression parse is only forced when the
    -- declaration parse has failed.
    asDecls = parseModule (declHeader ++ source ++ "\n")
    asExpr  = parseModule (exprHeader ++ source ++ "\n")

    outputLines = case asDecls of
        ParseOk parsed -> doPretty parsed
        ParseFailed (SrcLoc _ _ col) msg -> case asExpr of
            ParseOk parsed -> doPretty parsed
            _              ->
                [show msg ++ " at column " ++ show (col - length declHeader)]
-- | Render every top-level declaration of a parsed module.
--
-- Each declaration is pretty-printed with a layout mode whose
-- 'onsideIndent' is derived from the width of the declaration's head. The
-- rendered declarations are joined with newlines, split back into lines,
-- and every line is prefixed with a single space.
--
-- The synthetic binding @__expr__ = e@ is rendered as the bare expression
-- @e@ with an 'onsideIndent' of 0.
--
-- A module value that is not an ordinary 'Hs.Module' produces no output
-- instead of a pattern-match failure.
doPretty :: Hs.Module -> [String]
doPretty (Hs.Module _ _ _ decls) =
    map (" " ++) . lines . intercalate "\n" . map prettyDecl $ decls
  where
    -- Indent used whenever no head width can be computed.
    defaultLen = 4

    -- Width of a declaration's head; an empty match list falls back to the
    -- default rather than calling 'maximum' on @[]@.
    declLen (FunBind mtches) = case map matchLen mtches of
        []   -> defaultLen
        lens -> maximum lens
    declLen (PatBind pat _ _) = patLen pat
    declLen _                 = defaultLen

    patLen (PVar nm) = nameLen nm
    patLen _         = defaultLen

    nameLen (Ident s) = length s + 1
    nameLen _         = defaultLen

    -- Prefix and infix function clauses are both handled; for an infix
    -- clause the left operand is counted together with the other patterns.
    matchLen (Match nm pats _ _)          = headLen nm pats
    matchLen (InfixMatch lhs nm pats _ _) = headLen nm (lhs : pats)

    -- Name width plus pattern widths plus one; heads wider than 16 use the
    -- default indent instead.
    headLen nm pats =
        let l = nameLen nm + sum (map patLen pats) + 1
        in  if l > 16 then defaultLen else l

    -- Layout mode for an ordinary declaration.
    makeMode decl = defaultMode
        { doIndent     = 3
        , caseIndent   = 4
        , onsideIndent = declLen decl
        }

    -- Layout mode for a bare expression.
    exprMode = defaultMode
        { doIndent     = 3
        , caseIndent   = 4
        , onsideIndent = 0
        }

    prettyDecl (PatBind (PVar (Ident "__expr__")) (UnGuardedRhs e) Nothing)
        = prettyPrintWithMode exprMode e
    prettyDecl d = prettyPrintWithMode (makeMode d) d
-- NOTE(review): presumably unreachable from 'prettyCmd', which always parses
-- input behind a @module Main where@ header -- confirm.
doPretty _ = []