module Quipper.Utils.Template.LiftQ where
import qualified Language.Haskell.TH as TH
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.List as List
import Language.Haskell.TH (Name)
import Control.Monad.State
import Data.Map (Map)
import Data.Set (Set)
import Control.Applicative (Applicative(..))
import Control.Monad (liftM, ap)
import qualified Quipper.Utils.Template.ErrorMsgQ as Err
import Quipper.Utils.Template.ErrorMsgQ (ErrMsgQ)
-- | State threaded through every 'LiftQ' computation.
data LiftState = LiftState
  { boundVar  :: Map Name Int
    -- ^ Names currently bound by an enclosing binder. The value is a nesting
    --   counter: 0 when the name is bound once, incremented by each further
    --   'addToBoundVar' and decremented by 'removeFromBoundVar'.
  , prefix    :: Maybe String
    -- ^ Prefix prepended to sanitized names to form template names
    --   (see 'templateString'); 'Nothing' until 'setPrefix' is called.
  , monadName :: Maybe String
    -- ^ Name set by 'setMonadName' (presumably the target monad of the
    --   lifting — TODO confirm); 'Nothing' until set.
  }
-- | The initial state in which 'extractQ' starts every computation:
-- nothing is bound, and neither a prefix nor a monad name has been set.
emptyLiftState :: LiftState
emptyLiftState =
  LiftState
    { prefix    = Nothing
    , monadName = Nothing
    , boundVar  = Map.empty
    }
-- | Underlying representation of 'LiftQ': a 'StateT' over 'LiftState'
-- running in the 'ErrMsgQ' monad. Kept unapplied (no type argument) so it
-- can be used as a monad constructor on its own.
type LiftQState = StateT LiftState ErrMsgQ
-- | The lifting monad. It wraps a 'LiftQState' computation so that its
-- 'Functor' / 'Applicative' / 'Monad' instances are defined in this module
-- rather than inherited from 'StateT'.
data LiftQ a = LiftQ (LiftQState a)
-- | Sequencing runs the wrapped 'LiftQState' computation, then unwraps and
-- runs the computation returned by the continuation.
instance Monad LiftQ where
  -- Canonical form: 'return' must coincide with 'pure'. Defining 'return'
  -- by hand and setting @pure = return@ triggers
  -- -Wnoncanonical-monad-instances in modern GHC.
  return = pure
  LiftQ m >>= k = LiftQ $ do
    a <- m
    let LiftQ m' = k a
    m'

-- | 'pure' wraps the underlying 'StateT' 'pure' (state left untouched);
-- '<*>' is derived from the monadic bind.
instance Applicative LiftQ where
  pure x = LiftQ (pure x)
  (<*>) = ap

-- | 'fmap' derived from the monadic structure.
instance Functor LiftQ where
  fmap = liftM
-- | Read the current 'LiftState' without modifying it.
--
-- Equivalent to the former @mapStateT@-over-@return ()@ construction, which
-- simply returned the incoming state as both result and new state.
getState :: LiftQ LiftState
getState = LiftQ get
-- | Replace the current 'LiftState' with the given one.
setState :: LiftState -> LiftQ ()
setState s = LiftQ (put s)
-- | Run an 'ErrMsgQ' computation inside 'LiftQ', leaving the 'LiftState'
-- unchanged.
embedErrMsgQ :: ErrMsgQ a -> LiftQ a
embedErrMsgQ q = LiftQ (lift q)
-- | Run a Template Haskell 'TH.Q' computation inside 'LiftQ' (converted with
-- 'Err.embedQ'), leaving the 'LiftState' unchanged.
embedQ :: TH.Q a -> LiftQ a
embedQ q = LiftQ (lift (Err.embedQ q))
-- | Run a 'LiftQ' computation starting from 'emptyLiftState', discard the
-- final state, and hand the resulting 'ErrMsgQ' to 'Err.extractQ' to obtain
-- a plain 'TH.Q' action. The string is passed straight to 'Err.extractQ'
-- (presumably used there as an error-message prefix — see ErrorMsgQ).
extractQ :: String -> LiftQ a -> TH.Q a
extractQ s comp =
  case comp of
    LiftQ st -> Err.extractQ s (evalStateT st emptyLiftState)
-- | Abort the lifting with the given message, by embedding
-- 'Err.errorMsg' into 'LiftQ'.
errorMsg :: String -> LiftQ a
errorMsg = embedErrMsgQ . Err.errorMsg
-- | Record that a name has been bound by one more enclosing binder.
--
-- A name seen for the first time gets counter 0. A name that is already
-- bound has its counter incremented, so shadowing nests correctly with
-- 'removeFromBoundVar'.
addToBoundVar :: Name -> LiftQ ()
addToBoundVar n = do
  s <- getState
  -- One lookup instead of Map.member followed by the partial Map.!:
  -- insertWith passes (new, old); we keep old + 1 when the key exists.
  setState s { boundVar = Map.insertWith (\_ depth -> depth + 1) n 0 (boundVar s) }
-- | Undo one 'addToBoundVar' for a name.
--
-- When the counter is 0 (the outermost binding) the name is removed
-- entirely; otherwise the counter is decremented. Fails with
-- @"<name> is not a bound value"@ if the name is not currently bound.
removeFromBoundVar :: Name -> LiftQ ()
removeFromBoundVar n = do
  s <- getState
  -- A single Map.lookup replaces Map.member followed by the partial Map.!.
  case Map.lookup n (boundVar s) of
    Nothing    -> errorMsg (show n ++ " is not a bound value")
    Just 0     -> setState s { boundVar = Map.delete n (boundVar s) }
    Just depth -> setState s { boundVar = Map.insert n (depth - 1) (boundVar s) }
-- | Run a computation with the given name marked as bound. The name is
-- added before the computation runs and removed after it finishes; the
-- computation's result is returned.
withBoundVar :: Name -> LiftQ a -> LiftQ a
withBoundVar n comp =
  addToBoundVar n *> comp <* removeFromBoundVar n
-- | Run a computation with every name in the list marked as bound, using
-- one 'withBoundVar' wrapper per name. Each successive name wraps the
-- computation built so far, so the last name in the list is bound
-- outermost (added first, removed last).
withBoundVars :: [Name] -> LiftQ a -> LiftQ a
withBoundVars names comp = wrapAll names comp
  where
    wrapAll [] inner = inner
    wrapAll (n : rest) inner = wrapAll rest (withBoundVar n inner)
-- | Whether the name is currently bound, i.e. present in the 'boundVar' map.
isBoundVar :: Name -> LiftQ Bool
isBoundVar n = Map.member n . boundVar <$> getState
-- | Set the template-name prefix. The prefix may be set only once;
-- attempting to set it again fails with an error naming both the
-- requested and the existing prefix.
setPrefix :: String -> LiftQ ()
setPrefix newPrefix = do
  st <- getState
  maybe (setState st { prefix = Just newPrefix }) clash (prefix st)
  where
    clash existing =
      errorMsg (concat [ "cannot set the prefix to "
                       , show newPrefix
                       , ": prefix already defined as "
                       , existing ])
-- | Return the template-name prefix, failing with @"undefined prefix"@ if
-- 'setPrefix' has not been called.
getPrefix :: LiftQ String
getPrefix = getState >>= (maybe (errorMsg "undefined prefix") return . prefix)
-- | Set the monad name. It may be set only once; attempting to set it
-- again fails with an error naming both the requested and the existing
-- value.
setMonadName :: String -> LiftQ ()
setMonadName newMonad = do
  st <- getState
  maybe (setState st { monadName = Just newMonad }) clash (monadName st)
  where
    clash existing =
      errorMsg (concat [ "cannot set the monad to "
                       , show newMonad
                       , ": monad already defined as "
                       , existing ])
-- | Return the monad name, failing with @"undefined monad"@ if
-- 'setMonadName' has not been called.
getMonadName :: LiftQ String
getMonadName = getState >>= (maybe (errorMsg "undefined monad") return . monadName)
-- | Local alias for 'TH.mkName': build a (non-fresh) 'Name' from a string.
mkName :: String -> Name
mkName = TH.mkName
-- | Generate a fresh Template Haskell name ('TH.newName'), lifted into
-- 'LiftQ' via 'embedQ'.
newName :: String -> LiftQ Name
newName = embedQ . TH.newName
-- | Turn an identifier or operator name into a string that can be used
-- inside an alphanumeric identifier.
--
-- Every symbol character from the table below is replaced by
-- @symb_\<name\>_@ (e.g. @"+"@ becomes @"symb_plus_"@). All other characters
-- are copied unchanged. For example, @sanitizeString "\<$\>"@ is
-- @"symb_oangle_symb_dollar_symb_cangle_"@.
sanitizeString :: String -> String
sanitizeString = concatMap sanitizeChar
  where
    -- Replace one character by its spelled-out form, if it has one.
    sanitizeChar :: Char -> String
    sanitizeChar c =
      case Map.lookup c symbolNames of
        Just s  -> "symb_" ++ s ++ "_"
        Nothing -> [c]

    -- Spelled-out names for ASCII symbol characters, keyed by Char.
    -- The table does not depend on the input, so it is built once rather
    -- than re-mapped for every character.
    -- '.' is included because it occurs in operator names (e.g. "." or
    -- ".&.") but is not valid in an identifier. TH.mkName would also
    -- treat it as a module-qualifier separator.
    symbolNames :: Map.Map Char String
    symbolNames = Map.fromList
      [ ('!', "exclamation")
      , ('"', "doublequote")
      , ('#', "sharp")
      , ('$', "dollar")
      , ('%', "percent")
      , ('&', "ampersand")
      , ('\'', "quote")
      , ('(', "oparent")
      , (')', "cparent")
      , ('*', "star")
      , ('+', "plus")
      , (',', "comma")
      , ('-', "minus")
      , ('.', "dot")
      , ('/', "slash")
      , (':', "colon")
      , (';', "semicolon")
      , ('<', "oangle")
      , ('=', "equal")
      , ('>', "cangle")
      , ('?', "question")
      , ('@', "at")
      , ('[', "obracket")
      , ('\\', "backslash")
      , (']', "cbracket")
      , ('^', "caret")
      , ('`', "graveaccent")
      , ('{', "obrace")
      , ('|', "vbar")
      , ('}', "cbrace")
      , ('~', "tilde")
      ]
-- | Template-name string for a base name: the current prefix (see
-- 'getPrefix') followed by the 'sanitizeString'-ed name.
templateString :: String -> LiftQ String
templateString base = (++ sanitizeString base) <$> getPrefix
-- | Look up, among the value names in scope, the template counterpart of
-- the given name (built from its base name by 'templateString'). Returns
-- 'Nothing' when no such value is in scope.
lookForTemplate :: Name -> LiftQ (Maybe Name)
lookForTemplate n =
  templateString (TH.nameBase n) >>= embedQ . TH.lookupValueName
-- | Build (without looking it up) the template counterpart of a name,
-- using 'TH.mkName' on the 'templateString' of its base name.
makeTemplateName :: Name -> LiftQ Name
makeTemplateName = fmap TH.mkName . templateString . TH.nameBase
-- | Debugging helper: run a 'LiftQ' computation with 'extractQ' (passing
-- @"prettyPrint: "@), execute it with 'TH.runQ', and print the
-- Template Haskell pretty-printing of the result to standard output.
prettyPrint :: TH.Ppr a => LiftQ a -> IO ()
prettyPrint comp = do
  ast <- TH.runQ (extractQ "prettyPrint: " comp)
  putStrLn (TH.pprint ast)
-- | The argument patterns of a function clause.
clauseGetPats :: TH.Clause -> [TH.Pat]
clauseGetPats clause =
  case clause of
    TH.Clause argPats _body _decs -> argPats
-- | Whether all elements of a list are equal to each other.
--
-- The empty list and singleton lists trivially satisfy this. Every element
-- is compared with the head. Evaluation stops at the first mismatch
-- ('all' short-circuits, unlike the former @foldl (&&) True@).
equalNEListElts :: Eq a => [a] -> Bool
equalNEListElts [] = True
equalNEListElts (h : rest) = all (== h) rest
-- | The number of argument patterns shared by all clauses of a definition.
--
-- Fails (via 'errorMsg') with @"empty clause"@ on an empty list. It fails
-- with @"patterns in clause are not of equal size"@ when the clauses do not
-- all have the same number of patterns.
clausesLengthPats :: [TH.Clause] -> LiftQ Int
clausesLengthPats [] = errorMsg "empty clause"
clausesLengthPats clauses@(firstClause : _)
  | all (== arity) arities = return arity
  | otherwise = errorMsg "patterns in clause are not of equal size"
  where
    -- Pattern-matching the first clause avoids the partial 'head'.
    arity = length (clauseGetPats firstClause)
    arities = map (length . clauseGetPats) clauses