{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE RankNTypes #-}

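-- | This module describes stripped-down Template Haskell abstract
-- syntax trees (ASTs) for a subset of Haskell, together with functions
-- that convert Template Haskell syntax into these trees, lift the trees
-- into a monad, and convert them back into Template Haskell.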

module Quipper.Utils.Template.Lifting where

import Control.Monad.State

import qualified Data.Map as Map
import Data.Map (Map)

import qualified Data.List as List

import Data.Maybe (catMaybes)

import qualified Data.Set as Set
import Data.Set (Set)

import qualified Language.Haskell.TH as TH
import Language.Haskell.TH (Name)

-- Get the monad to build the lifting.
import Quipper.Utils.Template.LiftQ


-- * Abstract syntax trees of a simplified language

-- | The body of a clause or case alternative. Guarded bodies are not
-- represented, so a body is simply an expression; one consequence is
-- that a \"where\" construct becomes equivalent to a simple \"let\".
type Body = Exp

-- | Literals of the simplified language. There is no string literal:
-- strings are expanded into lists of character literals when
-- converting from Template Haskell.
data Lit
  = CharL Char          -- ^ Character literal.
  | IntegerL Integer    -- ^ Integer literal.
  | RationalL Rational  -- ^ Real-number literal, kept as a 'Rational'.
  deriving (Show)


-- | Patterns. Constructor patterns are stored uniformly with 'ConP':
-- a prefix pattern such as @Just x@ and an infix pattern such as @h:t@
-- both become a constructor name applied to a list of sub-patterns.
data Pat =
    LitP Lit          -- ^ Literal.
  | VarP Name         -- ^ Variable name.
  | TupP [Pat]        -- ^ Tuple.
  | WildP             -- ^ Wildcard @_@.
  | ListP [Pat]       -- ^ List as @[...]@.
  | ConP Name [Pat]   -- ^ Constructor applied to sub-patterns.
    deriving (Show)

-- | One alternative of a case distinction: a pattern together with the
-- body used when that pattern matches.
data Match
  = Match Pat Body
  deriving (Show)

-- | A value declaration binding one name to a body. Used for
-- first-level declarations as well as inside @let@ ('LetE').
data Dec
  = ValD Name Body
  deriving (Show)

-- | Expressions of the simplified language.
data Exp
  = VarE Name         -- ^ Variable name.
  | ConE Name         -- ^ Data-constructor name.
  | LitE Lit          -- ^ Literal.
  | AppE Exp Exp      -- ^ Application of a function to one argument.
  | LamE Name Exp     -- ^ Lambda abstraction over a single variable.
  | TupE [Exp]        -- ^ Tuple.
  | CondE Exp Exp Exp -- ^ If-then-else.
  | LetE [Dec] Exp    -- ^ Let-construct.
  | CaseE Exp [Match] -- ^ Case distinction.
  | ListE [Exp]       -- ^ List: @[...]@.
  | ReturnE           -- ^ Hardcoded constant for @'return'@.
  | MAppE             -- ^ Hardcoded constant for @'>>='@.
  deriving (Show)


-- ** Syntactic sugar to recover do-notation.

-- | One binding step @x <- expr@ of a @do@ block: the bound variable
-- and the monadic expression it is taken from.
data BindS =
  BindS Name Exp

-- | A simple @do@: a list of monadic bindings @x <- e@ followed by a
-- final computation. Each binding turns into an application of the
-- '>>=' constant ('MAppE') to @e@ and to a lambda abstraction over @x@
-- whose body is the rest of the block.
doE :: [BindS] -> Exp -> Exp
doE binds final = go binds
  where
    go [] = final
    go (BindS var value : rest) = AppE (AppE MAppE value) (LamE var (go rest))


-- * Variable substitution


-- | Get the set of variable names bound by a pattern. Variables nested
-- inside tuple, list and constructor patterns are all collected;
-- literal and wildcard patterns bind nothing.
getVarNames :: Pat -> Set Name
getVarNames (VarP n) = Set.singleton n
getVarNames (TupP pats) = Set.unions $ map getVarNames pats
getVarNames (ListP pats) = Set.unions $ map getVarNames pats
-- Constructor patterns (including infix ones such as h:t) bind the
-- variables of their sub-patterns. Without this case, substitution
-- would leak into variables they shadow, and where-bound constructor
-- patterns would not be normalized.
getVarNames (ConP _ pats) = Set.unions $ map getVarNames pats
getVarNames (LitP _) = Set.empty
getVarNames WildP = Set.empty

-- | Substitution in a @'Match'@: replace variable @n@ by @s@ in the
-- body, unless the pattern rebinds @n@, in which case the match is
-- returned unchanged.
substMatch :: Name -> Exp -> Match -> Match
substMatch n s m@(Match p e)
  | n `Set.member` getVarNames p = m
  | otherwise                    = Match p (substExp n s e)


-- | Substitution in a @'Dec'@: the body is left untouched when the
-- declaration itself binds the substituted name.
substDec :: Name -> Exp -> Dec -> Dec
substDec n s dec@(ValD m e)
  | m == n    = dec
  | otherwise = ValD m (substExp n s e)

-- | Substitution in an @'Exp'@: replace the free occurrences of the
-- variable @n@ by @s@. Binders that rebind @n@ (lambdas, case
-- patterns, let declarations) stop the substitution. Binders are never
-- renamed, so the substitution is not capture-avoiding.
substExp :: Name -> Exp -> Exp -> Exp
substExp n s (VarE m) | n == m    = s
                      | otherwise = VarE m
substExp _ _ (ConE m) = ConE m
substExp _ _ (LitE l) = LitE l
substExp n s (AppE e1 e2) = AppE (substExp n s e1) (substExp n s e2)
substExp n s (LamE m body) | n == m    = LamE m body
                           | otherwise = LamE m $ substExp n s body
substExp n s (TupE exps) = TupE $ map (substExp n s) exps
substExp n s (CondE e1 e2 e3) = CondE (substExp n s e1) (substExp n s e2) (substExp n s e3)
-- A let is recursive: if any of its declarations binds n, then n is
-- shadowed in every declaration and in the body, so nothing changes.
substExp n s (LetE decs body)
  | n `elem` [m | ValD m _ <- decs] = LetE decs body
  | otherwise = LetE (map (substDec n s) decs) (substExp n s body)
substExp n s (CaseE scrut matches) = CaseE (substExp n s scrut) $ map (substMatch n s) matches
substExp n s (ListE exps) = ListE $ map (substExp n s) exps
substExp _ _ ReturnE = ReturnE
substExp _ _ MAppE   = MAppE


-- | Substitution of several variables in one go. The substitutions are
-- applied one after the other, in ascending order of variable names
-- (the order of 'Map.toList'). A replacement introduced by an earlier
-- step is therefore itself subject to the later steps.
mapSubstExp :: (Map Name Exp) -> Exp -> Exp
mapSubstExp substs e0 =
  List.foldl' (\e (x, y) -> substExp x y e) e0 (Map.toList substs)


-- * Downgrading Template Haskell to AST

-- | Downgrading TH literals to @'Exp'@. A string literal becomes a list
-- of character literals. Any other unsupported literal kind is reported
-- with 'errorMsg'.
litTHtoExpAST :: TH.Lit -> LiftQ Exp
litTHtoExpAST lit = case lit of
  TH.CharL c     -> return (charE c)
  TH.StringL s   -> return (ListE (map charE s))
  TH.IntegerL i  -> return (LitE (IntegerL i))
  TH.RationalL r -> return (LitE (RationalL r))
  _              -> errorMsg ("lifting not handled for " ++ show lit)
  where
    charE = LitE . CharL

-- | Downgrading TH literals to @'Pat'@. A string literal becomes a list
-- pattern of character literals. Any other unsupported literal kind is
-- reported with 'errorMsg'.
litTHtoPatAST :: TH.Lit -> LiftQ Pat
litTHtoPatAST lit = case lit of
  TH.CharL c     -> return (charP c)
  TH.StringL s   -> return (ListP (map charP s))
  TH.IntegerL i  -> return (LitP (IntegerL i))
  TH.RationalL r -> return (LitP (RationalL r))
  _              -> errorMsg ("lifting not handled for " ++ show lit)
  where
    charP = LitP . CharL


-- | Take a list of patterns coming from a @where@ section and output a
-- list of fresh names, one per pattern, for normalized @let@s. Also
-- give a mapping for substituting inside the expressions: each variable
-- @x@ of a pattern @p@ whose fresh name is @f@ is mapped to the
-- expression @case f of p -> x@. Assumes all names in the list of
-- patterns are distinct; otherwise the first pattern binding a name
-- wins.
normalizePatInExp :: [Pat] -> LiftQ ([Name], Map Name Exp)
normalizePatInExp pats = do
  freshNames <- mapM newName (replicate (length pats) "normalizePat")
  let varSets = map getVarNames pats
      -- Each pattern together with the variables it binds and its fresh name.
      table = zip3 varSets freshNames pats
      -- The first entry whose pattern binds the given variable.
      owner x = List.find (\(vars, _, _) -> Set.member x vars) table
      entry x = do
        (_, fresh, pat) <- owner x
        return (x, CaseE (VarE fresh) [Match pat (VarE x)])
  case mapM entry (concatMap Set.toList varSets) of
    Nothing -> errorMsg "error in patterns..."
    Just l  -> return (freshNames, Map.fromList l)


-- | Build a @let@-expression out of pieces: an expression and the
-- @(pattern, definition)@ pairs of its @where@ (or @let@) section.
-- Each definition is bound to a fresh name. Pattern variables used
-- inside the definitions are rewritten through the mapping from
-- 'normalizePatInExp'. The expression is placed under a single case
-- distinction that matches all the patterns at once against the fresh
-- names.
whereToLet :: Exp -> [(Pat,Exp)] -> LiftQ Exp
whereToLet body [] = return body
whereToLet body bindings = do
  let (pats, defs) = unzip bindings
  (freshNames, pmap) <- normalizePatInExp pats
  let decs = zipWith (\f d -> ValD f (mapSubstExp pmap d)) freshNames defs
      scrutinee = TupE (map VarE freshNames)
  return (LetE decs (CaseE scrutinee [Match (TupP pats) body]))

-- | Build a @'Match'@ out of a TH clause. The clause's argument
-- patterns are grouped into one tuple pattern, and its @where@
-- declarations are folded into the body with 'whereToLet'.
clauseToMatch :: TH.Clause -> LiftQ Match
clauseToMatch (TH.Clause thPats thBody thDecs) =
  mapM patTHtoAST thPats >>= \pats ->
  bodyTHtoAST thBody >>= \body ->
  mapM decTHtoAST thDecs >>= \decs ->
  whereToLet body decs >>= \e ->
  return (Match (TupP pats) e)

-- | From a list of TH clauses, make a case-distinction wrapped in a
-- lambda abstraction. 'clausesLengthPats' gives the number of fresh
-- argument variables. The case distinction is on the tuple of these
-- variables, with one alternative per clause.
clausesToLambda :: [TH.Clause] -> LiftQ Exp
clausesToLambda clauses = do
  arity <- clausesLengthPats clauses
  args <- mapM (\_ -> newName "x") [1 .. arity]
  alternatives <- mapM clauseToMatch clauses
  let scrutinee = TupE (map VarE args)
  return (foldr LamE (CaseE scrutinee alternatives) args)


-- | Downgrade expressions.
--
-- * Infix applications and operator sections become ordinary
--   applications; a right section becomes a lambda over a fresh
--   variable.
-- * Lambdas go through 'clausesToLambda' and @let@ through
--   'whereToLet'.
-- * Type signatures are dropped.
--
-- Any other construct is reported with 'errorMsg'.
expTHtoAST :: TH.Exp -> LiftQ Exp
expTHtoAST thExp = case thExp of
  TH.VarE v -> return (VarE v)
  TH.ConE c -> return (ConE c)
  TH.LitE l -> litTHtoExpAST l
  TH.AppE f a -> do
    f' <- expTHtoAST f
    a' <- expTHtoAST a
    return (AppE f' a')
  -- x op y  ==>  op x y
  TH.InfixE (Just x) op (Just y) -> do
    x' <- expTHtoAST x
    op' <- expTHtoAST op
    y' <- expTHtoAST y
    return (AppE (AppE op' x') y')
  -- (op y)  ==>  lambda v. op v y
  TH.InfixE Nothing op (Just y) -> do
    op' <- expTHtoAST op
    y' <- expTHtoAST y
    v <- newName "x"
    return (LamE v (AppE (AppE op' (VarE v)) y'))
  -- (x op)  ==>  op x
  TH.InfixE (Just x) op Nothing -> do
    x' <- expTHtoAST x
    op' <- expTHtoAST op
    return (AppE op' x')
  -- (op)  ==>  op
  TH.InfixE Nothing op Nothing -> expTHtoAST op
  TH.LamE pats body -> clausesToLambda [TH.Clause pats (TH.NormalB body) []]
  TH.TupE es -> do
    es' <- mapM expTHtoAST es
    return (TupE es')
  TH.CondE c t e -> do
    c' <- expTHtoAST c
    t' <- expTHtoAST t
    e' <- expTHtoAST e
    return (CondE c' t' e')
  TH.LetE decs body -> do
    decs' <- mapM decTHtoAST decs
    body' <- expTHtoAST body
    whereToLet body' decs'
  TH.CaseE scrut alts -> do
    scrut' <- expTHtoAST scrut
    alts' <- mapM matchTHtoAST alts
    return (CaseE scrut' alts')
  TH.ListE es -> do
    es' <- mapM expTHtoAST es
    return (ListE es')
  TH.SigE e _ -> expTHtoAST e
  _ -> errorMsg ("lifting not handled for " ++ show thExp)


-- | Downgrade match-constructs. The alternative's @where@ declarations
-- are folded into its body with 'whereToLet'.
matchTHtoAST :: TH.Match -> LiftQ Match
matchTHtoAST (TH.Match thPat thBody thDecs) =
  patTHtoAST thPat >>= \pat ->
  bodyTHtoAST thBody >>= \body ->
  mapM decTHtoAST thDecs >>= \decs ->
  whereToLet body decs >>= \e ->
  return (Match pat e)

-- | Downgrade bodies into expressions. Only plain bodies are supported;
-- a guarded body is reported with 'errorMsg'.
bodyTHtoAST :: TH.Body -> LiftQ Exp
bodyTHtoAST body = case body of
  TH.NormalB e   -> expTHtoAST e
  TH.GuardedB gs -> errorMsg ("guarded body not allowed in lifting: " ++ show gs)

-- | Downgrade patterns. An infix constructor pattern such as @h:t@
-- becomes a constructor pattern with two sub-patterns. Unsupported
-- patterns are reported with 'errorMsg'.
patTHtoAST :: TH.Pat -> LiftQ Pat
patTHtoAST thPat = case thPat of
  TH.LitP l       -> litTHtoPatAST l
  TH.VarP n       -> return (VarP n)
  TH.WildP        -> return WildP
  TH.TupP ps      -> subPats TupP ps
  TH.ListP ps     -> subPats ListP ps
  TH.ConP n ps    -> subPats (ConP n) ps
  TH.InfixP l n r -> subPats (ConP n) [l, r]
  _               -> errorMsg ("non-implemented lifting: " ++ show thPat)
  where
    -- Downgrade the sub-patterns left to right, then rebuild.
    subPats build ps = do
      ps' <- mapM patTHtoAST ps
      return (build ps')




-- | Downgrade first-level declarations.
--
-- * Functions and variable bindings are converted and renamed to their
--   template name with 'makeDecTemplate'.
-- * Type signatures yield 'Nothing' and are therefore skipped.
-- * Anything else yields a computation that fails with 'errorMsg'.
firstLevelDecTHtoAST :: TH.Dec -> Maybe (LiftQ Dec)
firstLevelDecTHtoAST thDec = case thDec of
  TH.FunD name clauses -> Just $
    clausesToLambda clauses >>= \e -> makeDecTemplate (ValD name e)
  TH.ValD (TH.VarP name) body decs -> Just $ do
    body' <- bodyTHtoAST body
    decs' <- mapM decTHtoAST decs
    e <- whereToLet body' decs'
    makeDecTemplate (ValD name e)
  TH.ValD _ _ _ -> Just $
    errorMsg ("only variables and functions can be lifted as first-level declarations")
  TH.SigD _ _ -> Nothing
  _ -> Just $ errorMsg ("non-implemented lifting: " ++ show thDec)


-- | Downgrade any declarations (including the ones in @where@-constructs)
-- into a @(pattern, expression)@ pair.
--
-- * A function becomes a variable pattern bound to a lambda.
-- * A pattern binding keeps its pattern, and its own @where@
--   declarations are folded into its body.
decTHtoAST :: TH.Dec -> LiftQ (Pat,Exp)
decTHtoAST thDec = case thDec of
  TH.FunD name clauses ->
    clausesToLambda clauses >>= \e -> return (VarP name, e)
  TH.ValD thPat thBody thDecs -> do
    pat <- patTHtoAST thPat
    body <- bodyTHtoAST thBody
    decs <- mapM decTHtoAST thDecs
    e <- whereToLet body decs
    return (pat, e)
  _ -> errorMsg ("non-implemented lifting: " ++ show thDec)




-- * Upgrade AST to Template Haskell

-- | Abstract syntax tree of the type of the function 'return',
-- instantiated at the monad named by 'getMonadName':
-- @forall x. x -> m x@.
typReturnE :: LiftQ TH.Type
typReturnE =
  getMonadName >>= \monadName ->
  let m = TH.conT (mkName monadName)
  in embedQ [t| forall x. x -> $(m) x |]

-- | Abstract syntax tree of the type of the function '>>=',
-- instantiated at the monad named by 'getMonadName':
-- @forall x y. m x -> (x -> m y) -> m y@.
typMAppE :: LiftQ TH.Type
typMAppE =
  getMonadName >>= \monadName ->
  let m = TH.conT (mkName monadName)
  in embedQ [t| forall x y. $(m) x -> (x -> $(m) y) -> $(m) y |]


-- | Upgrade literals: each constructor maps to its Template Haskell
-- counterpart.
litASTtoTH :: Lit -> TH.Lit
litASTtoTH lit = case lit of
  CharL c     -> TH.CharL c
  IntegerL i  -> TH.IntegerL i
  RationalL r -> TH.RationalL r

-- | Upgrade patterns to Template Haskell, recursively.
patASTtoTH :: Pat -> TH.Pat
patASTtoTH pat = case pat of
  LitP l    -> TH.LitP (litASTtoTH l)
  VarP n    -> TH.VarP n
  TupP ps   -> TH.TupP (map patASTtoTH ps)
  WildP     -> TH.WildP
  ListP ps  -> TH.ListP (map patASTtoTH ps)
  ConP n ps -> TH.ConP n (map patASTtoTH ps)

-- | Upgrade match-constructs. The resulting alternative has a plain
-- (unguarded) body and no @where@ declarations.
matchASTtoTH :: Match -> LiftQ TH.Match
matchASTtoTH (Match pat body) =
  expASTtoTH body >>= \e ->
  return (TH.Match (patASTtoTH pat) (TH.NormalB e) [])

-- | Upgrade declarations into a Template Haskell variable binding with
-- a plain body and no @where@ declarations.
decASTtoTH :: Dec -> LiftQ TH.Dec
decASTtoTH (ValD name body) =
  expASTtoTH body >>= \e ->
  return (TH.ValD (TH.VarP name) (TH.NormalB e) [])


-- | Upgrade expressions. The hardcoded constants 'ReturnE' and 'MAppE'
-- are looked up with 'TH.lookupValueName' and annotated with the types
-- built by 'typReturnE' and 'typMAppE'. If the name cannot be found,
-- 'errorMsg' is raised.
expASTtoTH :: Exp -> LiftQ TH.Exp
expASTtoTH expr = case expr of
  VarE n -> return (TH.VarE n)
  ConE n -> return (TH.ConE n)
  LitE l -> return (TH.LitE (litASTtoTH l))
  AppE f a -> do
    f' <- expASTtoTH f
    a' <- expASTtoTH a
    return (TH.AppE f' a')
  LamE n body -> do
    body' <- expASTtoTH body
    return (TH.LamE [TH.VarP n] body')
  TupE es -> do
    es' <- mapM expASTtoTH es
    return (TH.TupE es')
  CondE c t e -> do
    c' <- expASTtoTH c
    t' <- expASTtoTH t
    e' <- expASTtoTH e
    return (TH.CondE c' t' e')
  LetE decs body -> do
    decs' <- mapM decASTtoTH decs
    body' <- expASTtoTH body
    return (TH.LetE decs' body')
  CaseE scrut alts -> do
    scrut' <- expASTtoTH scrut
    alts' <- mapM matchASTtoTH alts
    return (TH.CaseE scrut' alts')
  ListE es -> do
    es' <- mapM expASTtoTH es
    return (TH.ListE es')
  ReturnE -> annotatedConstant typReturnE "return" "\'return\' undefined"
  MAppE   -> annotatedConstant typMAppE ">>=" "\'>>=\' undefined"
  where
    -- Build the type first, then look up the value name in the current
    -- scope via TH.lookupValueName and attach the type as a signature.
    annotatedConstant typ name missing = do
      t <- typ
      found <- embedQ (TH.lookupValueName name)
      case found of
        Just v  -> return (TH.SigE (TH.VarE v) t)
        Nothing -> errorMsg missing




-- * Lifting AST terms (into AST terms)

-- | Variable referring to the lifting function for integers. The name
-- @template_integer@ is built with 'mkName', so it is resolved wherever
-- the lifted code ends up.
liftIntegerL :: Exp
liftIntegerL = VarE (mkName "template_integer")

-- | Variable referring to the lifting function for reals. The name
-- @template_rational@ is built with 'mkName', so it is resolved
-- wherever the lifted code ends up.
liftRationalL :: Exp
liftRationalL = VarE (mkName "template_rational")

-- | Lifting literals. A character is wrapped with @return@; integers
-- and reals are passed to 'liftIntegerL' and 'liftRationalL'
-- respectively.
liftLitAST :: Lit -> LiftQ Exp
liftLitAST lit = return $ case lit of
  CharL _     -> AppE ReturnE (LitE lit)
  IntegerL _  -> AppE liftIntegerL (LitE lit)
  RationalL _ -> AppE liftRationalL (LitE lit)

-- | Lifting patterns: patterns are left unchanged.
liftPatAST :: Pat -> LiftQ Pat
liftPatAST = return

-- | Lifting match-constructs: the body is lifted and the pattern is
-- kept as is.
liftMatchAST :: Match -> LiftQ Match
liftMatchAST (Match pat body) =
  liftExpAST body >>= \body' -> return (Match pat body')

-- | Lifting declarations: the body is lifted and the declared name is
-- kept.
liftDecAST :: Dec -> LiftQ Dec
liftDecAST (ValD name body) =
  liftExpAST body >>= \body' -> return (ValD name body')

-- | Lifting first-level declarations. The declared name is registered
-- with 'withBoundVar' while its body is lifted, so that 'liftExpAST'
-- treats occurrences of that name as bound.
liftFirstLevelDecAST :: Dec -> LiftQ Dec
liftFirstLevelDecAST (ValD name body) =
  withBoundVar name (liftExpAST body >>= \body' -> return (ValD name body'))

-- | Lifting expressions: turn an expression into one of monadic type.
--
-- * A variable with a template (from 'lookForTemplate') is replaced by
--   that template. A variable registered as bound is kept as is. Any
--   other variable is wrapped in @return@.
-- * A constructor must have a template, otherwise lifting fails.
-- * Applications, tuples, lists and case scrutinees are sequenced with
--   @>>=@ (see 'doE').
-- * @if@ is delegated to @template_if@.
-- * 'ReturnE' and 'MAppE' are never produced by 'expTHtoAST', so
--   meeting them here is reported with 'errorMsg'.
liftExpAST :: Exp -> LiftQ Exp

liftExpAST (VarE x) = do
  template_name <- lookForTemplate x
  case template_name of
    Nothing -> do
      b <- isBoundVar x
      if b
        then return $ VarE x
        else return $ AppE ReturnE $ VarE x
    Just t  -> return $ VarE t

liftExpAST (ConE n) = do
  template_name <- lookForTemplate n
  case template_name of
    Nothing -> do
      t <- templateString $ TH.nameBase n
      errorMsg ("variable " ++ t ++ " undefined")
    Just t  -> return $ VarE t

liftExpAST (LitE l) = liftLitAST l

liftExpAST (AppE e1 e2) = do
  e1' <- liftExpAST e1
  e2' <- liftExpAST e2
  n1 <- newName "app1"
  n2 <- newName "app2"
  return $ doE [BindS n1 e1', BindS n2 e2'] $ AppE (VarE n1) (VarE n2)

liftExpAST (LamE n body) = do
  body' <- liftExpAST body
  return $ AppE ReturnE $ LamE n body'

liftExpAST (TupE exps) = do
  exps' <- mapM liftExpAST exps
  fresh_names <- mapM newName $ replicate (length exps) "tupe"
  return $
    doE (map (uncurry BindS) $ zip fresh_names exps')
        (AppE ReturnE $ TupE $ map VarE fresh_names)

liftExpAST (CondE e1 e2 e3) = do
  e1' <- liftExpAST e1
  e2' <- liftExpAST e2
  e3' <- liftExpAST e3
  return $ AppE (AppE (AppE (VarE (mkName "template_if")) e1') e2') e3'

-- The names declared by the let are bound (already monadic) both in
-- the declarations and in the body.
liftExpAST (LetE decs body) =
  let existing_names = [n | ValD n _ <- decs]
  in
   withBoundVars existing_names $ do
     decs' <- mapM liftDecAST decs
     body' <- liftExpAST body
     return $
       LetE decs' body'

liftExpAST (CaseE scrut matches) = do
  scrut' <- liftExpAST scrut
  matches' <- mapM liftMatchAST matches
  fresh_name <- newName "varfromcase"
  return $ doE [BindS fresh_name scrut']
               $ CaseE (VarE fresh_name) matches'

liftExpAST (ListE exps) = do
  exps' <- mapM liftExpAST exps
  fresh_names <- mapM newName $ replicate (length exps) "varfromlist"
  return $
    doE (map (uncurry BindS) $ zip fresh_names exps')
       $ AppE ReturnE $ ListE $ map VarE fresh_names

-- These constants only appear in the output of lifting; report them
-- instead of crashing with an anonymous 'undefined'.
liftExpAST ReturnE =
  errorMsg "liftExpAST: unexpected 'return' constant (ReturnE) in expression to lift"
liftExpAST MAppE =
  errorMsg "liftExpAST: unexpected '>>=' constant (MAppE) in expression to lift"


-- | Make a declaration into a template-declaration. The declared name
-- is replaced by its template name (from 'makeTemplateName'), both in
-- the binding itself and in references to it inside the body.
makeDecTemplate :: Dec -> LiftQ Dec
makeDecTemplate (ValD name body) =
  makeTemplateName name >>= \tname ->
  return (ValD tname (substExp name (VarE tname) body))


-- * Various pretty printing functions


-- | Pretty-print Template Haskell declarations after a round trip
-- through the simplified AST: downgrade, then upgrade, with no lifting
-- step. Declarations that 'firstLevelDecTHtoAST' skips, such as type
-- signatures, are dropped.
prettyPrintAST :: TH.Q [TH.Dec] -> IO ()
prettyPrintAST x = prettyPrint $ do
  thDecs <- embedQ x
  astDecs <- sequence [d | Just d <- map firstLevelDecTHtoAST thDecs]
  mapM decASTtoTH astDecs

-- | Pretty-print the lifted version of a Template Haskell expression:
-- downgrade, lift, then upgrade back to Template Haskell.
prettyPrintLiftExpTH :: TH.Q (TH.Exp) -> IO ()
prettyPrintLiftExpTH x = prettyPrint $
  embedQ x >>= expTHtoAST >>= liftExpAST >>= expASTtoTH

-- | Pretty-print the lifted version of an expression of the simplified
-- AST, upgraded to Template Haskell.
prettyPrintLiftExpAST :: LiftQ (Exp) -> IO ()
prettyPrintLiftExpAST x = prettyPrint (x >>= liftExpAST >>= expASTtoTH)


-- * The main lifting functions.


-- | Lift a list of declarations. The first argument is the name of
-- the monad to lift into. The template prefix is set to @template_@.
-- Declarations skipped by 'firstLevelDecTHtoAST' (type signatures) do
-- not appear in the result.
decToMonad :: String -> TH.Q [TH.Dec] -> TH.Q [TH.Dec]
decToMonad s x = extractQ "decToMonad: " $ do
  setMonadName s
  setPrefix "template_"
  thDecs <- embedQ x
  astDecs <- sequence [d | Just d <- map firstLevelDecTHtoAST thDecs]
  lifted <- mapM liftFirstLevelDecAST astDecs
  mapM decASTtoTH lifted

-- | Lift an expression. The first argument is the name of the monad
-- to lift into. The template prefix is set to @template_@.
expToMonad :: String -> TH.Q TH.Exp -> TH.Q TH.Exp
expToMonad s x = extractQ "expToMonad: " $ do
  setMonadName s
  setPrefix "template_"
  thExp <- embedQ x
  astExp <- expTHtoAST thExp
  liftExpAST astExp >>= expASTtoTH