{-# LANGUAGE CPP, TemplateHaskell #-}
#if __GLASGOW_HASKELL__ <= 708
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
#endif
module Language.Haskell.TH.Desugar.Match (scExp, scLetDec) where
import Prelude hiding ( fail, exp )
#if __GLASGOW_HASKELL__ < 709
import Control.Applicative
#endif
import Control.Monad hiding ( fail )
import qualified Data.Set as S
import qualified Data.Map as Map
import Language.Haskell.TH.Instances ()
import Language.Haskell.TH.Syntax
import Language.Haskell.TH.Desugar.Core
import Language.Haskell.TH.Desugar.Util
import Language.Haskell.TH.Desugar.Reify
-- | Simplify the @case@ expressions inside an expression: every 'DCaseE'
-- reached through applications, lambdas, @let@s, type signatures and other
-- @case@ expressions is recompiled with 'simplCaseExp'. Any other form of
-- expression is returned unchanged.
--
-- The scrutinee and the bodies of the alternatives are simplified before the
-- @case@ itself is compiled, so @case@ expressions nested inside them are
-- processed too (this mirrors what 'scLetDec' does for the right-hand sides
-- of function clauses).
scExp :: DsMonad q => DExp -> q DExp
scExp (DAppE e1 e2) = DAppE <$> scExp e1 <*> scExp e2
scExp (DLamE names body) = DLamE names <$> scExp body
scExp (DCaseE scrut matches) = do
    scrut' <- scExp scrut
    clauses <- mapM match_to_clause matches
    case scrut' of
      -- A variable scrutinee can be matched against directly.
      DVarE name -> simplCaseExp [name] clauses
      -- Anything else is first bound to a fresh variable, so that the
      -- scrutinee expression appears only once in the output.
      _ -> do
        scrut_name <- newUniqueName "scrut"
        case_exp <- simplCaseExp [scrut_name] clauses
        return $ DLetE [DValD (DVarPa scrut_name) scrut'] case_exp
  where
    -- A case alternative is a clause with exactly one pattern.
    match_to_clause (DMatch pat body) = DClause [pat] <$> scExp body
scExp (DLetE decs body) = DLetE <$> mapM scLetDec decs <*> scExp body
scExp (DSigE body ty) = DSigE <$> scExp body <*> pure ty
scExp e = return e
-- | Simplify the @case@ expressions inside a @let@-bound declaration.
--
-- A function with at least one clause is rewritten into a single clause whose
-- arguments are fresh variables (one per pattern of the first clause) and
-- whose body is the compiled pattern match over those variables. A value
-- binding only has its right-hand side simplified. Every other declaration
-- is returned as is.
scLetDec :: DsMonad q => DLetDec -> q DLetDec
scLetDec dec = case dec of
  DFunD fun_name clauses@(DClause first_pats _ : _) -> do
    args <- forM first_pats $ \_ -> newUniqueName "_arg"
    -- simplify the right-hand sides before compiling the match itself
    simplified <- forM clauses $ \(DClause pats rhs) -> DClause pats <$> scExp rhs
    body <- simplCaseExp args simplified
    return (DFunD fun_name [DClause (map DVarPa args) body])
  DValD pat rhs -> DValD pat <$> scExp rhs
  _ -> return dec
-- | The result of compiling a pattern match: a function that, given the
-- expression to evaluate when the match fails, produces the complete
-- matching expression.
type MatchResult = DExp -> DExp
-- | Close off a 'MatchResult' by using a call to 'error' (with the message
-- @"Pattern-match failure"@) as the expression for a failed match.
matchResultToDExp :: MatchResult -> DExp
matchResultToDExp = ($ pattern_match_failure)
  where
    pattern_match_failure =
      DAppE (DVarE 'error) (DLitE (StringL "Pattern-match failure"))
-- | Compile a list of clauses matched against the given scrutinee variables
-- into a single expression. Each clause's right-hand side ignores the failure
-- expression; a match that falls through all clauses ends in the
-- pattern-match-failure expression supplied by 'matchResultToDExp'.
simplCaseExp :: DsMonad q
             => [Name]      -- the scrutinee variables
             -> [DClause]   -- the clauses to match against them
             -> q DExp
simplCaseExp vars clauses = matchResultToDExp `liftM` simplCase vars eqn_infos
  where
    eqn_infos = map to_eqn_info clauses
    to_eqn_info (DClause pats rhs) = EquationInfo pats (const rhs)
-- | One equation of a match: the patterns still to be matched, paired with a
-- right-hand side that takes the failure expression as its argument.
data EquationInfo = EquationInfo [DPat] MatchResult -- like DClause, but with a hole
-- | Compile a list of equations against a list of scrutinee variables.
-- Analogous to GHC's @match@ (in deSugar/Match.lhs).
--
-- With no variables left, the right-hand sides of the equations are chained
-- so that the failure expression of each one is the result of the next. An
-- empty list of equations yields the identity, i.e. the match simply fails.
--
-- Otherwise the first pattern of every equation is tidied against the first
-- variable, consecutive equations with the same kind of first pattern are
-- grouped together, each group is compiled on its own, and the results are
-- chained in the same way. The bindings produced by tidying are wrapped
-- around the whole result.
simplCase :: DsMonad q
          => [Name]           -- the names of the scrutinees
          -> [EquationInfo]   -- the matches (where the # of pats == length (1st arg))
          -> q MatchResult
-- 'foldr (.) id' agrees with 'foldr1 (.)' on non-empty lists but does not
-- crash on an empty one.
simplCase [] clauses = return (foldr (.) id match_results)
  where
    match_results = [ mr | EquationInfo _ mr <- clauses ]
simplCase vars@(v:_) clauses = do
    (aux_binds, tidy_clauses) <- mapAndUnzipM (tidyClause v) clauses
    let grouped = groupClauses tidy_clauses
    match_results <- match_groups grouped
    return (adjustMatchResult (foldr (.) id aux_binds) $
            foldr (.) id match_results)
  where
    -- With no equations at all, the scrutinee is still inspected (see
    -- 'matchEmpty').
    match_groups :: DsMonad q => [[(PatGroup, EquationInfo)]] -> q [MatchResult]
    match_groups [] = matchEmpty v
    match_groups gs = mapM match_group gs

    -- Dispatch on the kind of pattern shared by the whole group.
    match_group :: DsMonad q => [(PatGroup, EquationInfo)] -> q MatchResult
    match_group [] = error "Internal error in th-desugar (match_group)"
    match_group eqns@((group,_) : _) =
      case group of
        PgCon _ -> matchConFamily vars (subGroup [(c,e) | (PgCon c, e) <- eqns])
        PgLit _ -> matchLiterals vars (subGroup [(l,e) | (PgLit l, e) <- eqns])
        PgBang  -> matchBangs vars (drop_group eqns)
        PgAny   -> matchVariables vars (drop_group eqns)

    drop_group = map snd
-- | Tidy the first pattern of an equation against the variable it is matched
-- with, returning the wrapper produced by 'tidy1' together with the equation
-- carrying the tidied pattern. Analogous to GHC's @tidyEqnInfo@.
--
-- It is an internal error for the equation to have no patterns.
tidyClause :: DsMonad q => Name -> EquationInfo -> q (DExp -> DExp, EquationInfo)
tidyClause v (EquationInfo all_pats body) = case all_pats of
  [] -> error "Internal error in th-desugar: no patterns in tidyClause."
  first : rest -> do
    (wrapper, tidied) <- tidy1 v first
    return (wrapper, EquationInfo (tidied : rest) body)
-- | Tidy a single pattern that is matched against a variable. The result is
-- a simpler pattern plus a wrapper to put around the right-hand side:
--
-- * a variable pattern becomes a wildcard, with the variable bound by the
--   wrapper (see 'wrapBind');
-- * a lazy pattern becomes a wildcard, with its variables bound by the
--   declarations from 'mkSelectorDecs';
-- * a bang on a literal or constructor pattern is dropped, a bang on a lazy
--   or another bang pattern is pushed inwards, and a bang on a variable or
--   wildcard is kept;
-- * literal, constructor and wildcard patterns are left alone.
tidy1 :: DsMonad q
      => Name   -- the name of the variable that ...
      -> DPat   -- ... this pattern is matching against
      -> q (DExp -> DExp, DPat)   -- a wrapper and tidied pattern
tidy1 v pat = case pat of
  DLitPa {} -> return (id, pat)
  DVarPa var -> return (wrapBind var v, DWildPa)
  DConPa {} -> return (id, pat)
  DTildePa lazy_pat -> do
    sel_decs <- mkSelectorDecs lazy_pat v
    return (maybeDLetE sel_decs, DWildPa)
  DBangPa inner -> case inner of
    DLitPa _ -> tidy1 v inner              -- already strict
    DVarPa _ -> return (id, DBangPa inner) -- no change
    DConPa _ _ -> tidy1 v inner            -- already strict
    DTildePa p -> tidy1 v (DBangPa p)      -- discard ~ under !
    DBangPa p -> tidy1 v (DBangPa p)       -- discard ! under !
    DWildPa -> return (id, DBangPa inner)  -- no change
  DWildPa -> return (id, DWildPa)
-- | @wrapBind new old body@ wraps @body@ in a @let@ that binds @new@ to the
-- variable @old@. When both names are equal no binding is needed and @body@
-- is returned as is.
wrapBind :: Name -> Name -> DExp -> DExp
wrapBind new old body =
  if new == old
    then body
    else DLetE [DValD (DVarPa new) (DVarE old)] body
-- | Build the declarations that bind the variables of a pattern by matching
-- it against a variable (used by 'tidy1' for lazy patterns). Like GHC's
-- @mkSelectorBinds@.
--
-- * A variable pattern is bound directly to the scrutinee.
-- * A pattern binding no variables needs no declarations.
-- * A pattern binding exactly one variable gets a copy of the scrutinee, an
--   error expression (@"Irrefutable match failed"@) and a binding of the
--   variable to the compiled match, which falls back to the error expression.
-- * Otherwise the match is compiled once into a tuple of all bound variables
--   (in ascending order of name), and each variable is bound to a projection
--   out of that tuple.
mkSelectorDecs :: DsMonad q
               => DPat      -- pattern to deconstruct
               -> Name      -- variable being matched against
               -> q [DLetDec]
mkSelectorDecs (DVarPa v) name = return [DValD (DVarPa v) (DVarE name)]
mkSelectorDecs pat name = case bndrs of
    [] -> return []
    [bndr] -> do
      val_var <- newUniqueName "var"
      err_var <- newUniqueName "err"
      bind <- mk_bind val_var err_var bndr
      return [ DValD (DVarPa val_var) (DVarE name)
             , DValD (DVarPa err_var) irrefutable_failure
             , bind ]
    _ -> do
      tuple_expr <- simplCaseExp [name] [DClause [pat] all_bndrs_tuple]
      tuple_var <- newUniqueName "tuple"
      projections <- forM [0 .. arity - 1] (mk_projection tuple_var)
      return (DValD (DVarPa tuple_var) tuple_expr :
              zipWith (DValD . DVarPa) bndrs projections)
  where
    bndrs = S.toAscList (extractBoundNamesDPat pat)
    arity = length bndrs
    all_bndrs_tuple = mkTupleDExp (map DVarE bndrs)

    irrefutable_failure =
      DAppE (DVarE 'error) (DLitE (StringL "Irrefutable match failed"))

    -- a @case@ on the tuple variable that selects one of its components
    mk_projection :: DsMonad q
                  => Name   -- of the tuple
                  -> Int    -- which element to get (0-indexed)
                  -> q DExp
    mk_projection tup_name i = do
      elt_name <- newUniqueName "proj"
      let tuple_pat = DConPa (tupleDataName arity) (mk_tuple_pats elt_name i)
      return (DCaseE (DVarE tup_name) [DMatch tuple_pat (DVarE elt_name)])

    -- wildcards everywhere except for a variable at the requested position
    mk_tuple_pats :: Name   -- of the projected element
                  -> Int    -- which element to get (0-indexed)
                  -> [DPat]
    mk_tuple_pats elt_name i =
      let (before, after) = (i, arity - i - 1)
      in replicate before DWildPa ++ [DVarPa elt_name] ++ replicate after DWildPa

    -- bind the single variable to the match of 'pat' against 'scrut_var'
    mk_bind scrut_var err_var bndr_var = do
      rhs_mr <- simplCase [scrut_var] [EquationInfo [pat] (const (DVarE bndr_var))]
      return (DValD (DVarPa bndr_var) (rhs_mr (DVarE err_var)))
-- | Collect the names of all variables bound by a pattern, looking through
-- constructor, lazy and bang patterns.
extractBoundNamesDPat :: DPat -> S.Set Name
extractBoundNamesDPat pat = case pat of
  DLitPa _ -> S.empty
  DVarPa n -> S.singleton n
  DConPa _ pats -> foldr (S.union . extractBoundNamesDPat) S.empty pats
  DTildePa p -> extractBoundNamesDPat p
  DBangPa p -> extractBoundNamesDPat p
  DWildPa -> S.empty
-- | The kind of a tidied first pattern, used to decide which consecutive
-- equations can be compiled together (see 'patGroup' and 'sameGroup').
data PatGroup
  = PgAny         -- immediate match (wilds, vars, lazies)
  | PgCon Name    -- constructor pattern, tagged with the constructor's name
  | PgLit Lit     -- literal pattern, tagged with the literal
  | PgBang        -- bang pattern
-- | Tag every equation with the 'PatGroup' of its first pattern and split the
-- list into runs of consecutive equations whose tags are in the same group.
-- Like GHC's @groupEquations@.
groupClauses :: [EquationInfo] -> [[(PatGroup, EquationInfo)]]
groupClauses clauses = runs same_gp tagged
  where
    tagged = map (\eqn -> (patGroup (firstPat eqn), eqn)) clauses

    same_gp :: (PatGroup, EquationInfo) -> (PatGroup, EquationInfo) -> Bool
    same_gp (g1, _) (g2, _) = sameGroup g1 g2
-- | Classify a pattern into its 'PatGroup'. Variable and lazy patterns are an
-- internal error here: 'tidy1' replaces both by wildcards.
patGroup :: DPat -> PatGroup
patGroup pat = case pat of
  DLitPa l -> PgLit l
  DVarPa {} -> error "Internal error in th-desugar (patGroup DVarP)"
  DConPa con _ -> PgCon con
  DTildePa {} -> error "Internal error in th-desugar (patGroup DTildeP)"
  DBangPa {} -> PgBang
  DWildPa -> PgAny
-- | Do two pattern groups belong together? Only the kind of group is
-- compared: any two constructor groups match, and so do any two literal
-- groups, regardless of which constructor or literal they carry.
sameGroup :: PatGroup -> PatGroup -> Bool
sameGroup g1 g2 = case (g1, g2) of
  (PgAny, PgAny) -> True
  (PgBang, PgBang) -> True
  (PgCon _, PgCon _) -> True
  (PgLit _, PgLit _) -> True
  _ -> False
-- | Partition tagged equations by their tag. The sub-groups come out in
-- ascending order of tag; within a sub-group the equations keep the order
-- they had in the input.
subGroup :: Ord a => [(a, EquationInfo)] -> [[EquationInfo]]
subGroup tagged = map reverse (Map.elems by_tag)
  where
    -- 'Map.fromListWith' puts later equations in front of earlier ones, so
    -- every sub-group is built in reverse and flipped back afterwards
    by_tag = Map.fromListWith (++) [ (tag, [eqn]) | (tag, eqn) <- tagged ]
-- | The first pattern of an equation. Fails with an error when the equation
-- has no patterns.
firstPat :: EquationInfo -> DPat
firstPat (EquationInfo pats _) = case pats of
  pat : _ -> pat
  [] -> error "Clause encountered with no patterns -- should never happen"
-- | One alternative of a @case@ on a data constructor, as assembled by
-- 'matchOneCon' and turned into a 'DMatch' by 'mkDataConCase'.
data CaseAlt = CaseAlt { alt_con   :: Name         -- con name
                       , _alt_args :: [Name]       -- bound var names
                       , _alt_rhs  :: MatchResult  -- RHS
                       }
-- | Compile a group of equations that all start with a constructor pattern.
-- The equations arrive already partitioned by constructor; each partition
-- becomes one alternative of a @case@ on the first variable. From GHC's
-- MatchCon.lhs.
matchConFamily :: DsMonad q => [Name] -> [[EquationInfo]] -> q MatchResult
matchConFamily [] _ = error "Internal error in th-desugar (matchConFamily)"
matchConFamily (var : rest_vars) groups =
  mapM (matchOneCon rest_vars) groups >>= mkDataConCase var
-- | Compile the equations that all start with the same constructor into one
-- case alternative. Fresh variables are chosen for the constructor's
-- arguments (based on the argument patterns of the first equation), and the
-- argument patterns of every equation are then matched against those
-- variables, followed by the remaining patterns against the remaining
-- variables. Like @matchOneConLike@ from GHC's MatchCon.
matchOneCon :: DsMonad q => [Name] -> [EquationInfo] -> q CaseAlt
matchOneCon _ [] = error "Internal error in th-desugar (matchOneCon)"
matchOneCon vars eqns@(first_eqn : _) = do
    arg_vars <- selectMatchVars con_args
    match_result <- simplCase (arg_vars ++ vars) (map unpack_con eqns)
    return (CaseAlt con_name arg_vars match_result)
  where
    head_pat = firstPat first_eqn

    con_args = case head_pat of
      DConPa _ pats -> pats
      _ -> error "Internal error in th-desugar (pat_args)"

    con_name = case head_pat of
      DConPa con _ -> con
      _ -> error "Internal error in th-desugar (pat_con)"

    -- replace the leading constructor pattern by its argument patterns
    unpack_con (EquationInfo (DConPa _ args : pats) rhs) = EquationInfo (args ++ pats) rhs
    unpack_con _ = error "Internal error in th-desugar (shift)"
-- | Build a @case@ on a variable from a list of constructor alternatives.
-- The data type is looked up (by reifying the type of the first
-- alternative's constructor) to find all of its constructors; a wildcard
-- alternative leading to the failure expression is appended unless every one
-- of them is mentioned.
--
-- NOTE(review): only data and data-instance declarations contribute
-- constructors here; for any other reified declaration the constructor set
-- is empty, so the match is treated as exhaustive.
mkDataConCase :: DsMonad q => Name -> [CaseAlt] -> q MatchResult
mkDataConCase var case_alts = do
    all_ctors <- get_all_ctors (alt_con (head case_alts))
    let default_alts on_fail
          | all_ctors `S.isSubsetOf` mentioned_ctors = []
          | otherwise = [DMatch DWildPa on_fail]
    return $ \on_fail ->
      DCaseE (DVarE var) (map (mk_alt on_fail) case_alts ++ default_alts on_fail)
  where
    mk_alt on_fail (CaseAlt con args body_fn) =
      DMatch (DConPa con (map DVarPa args)) (body_fn on_fail)

    mentioned_ctors = S.fromList (map alt_con case_alts)

    -- all constructors of the data type that the given constructor belongs to
    get_all_ctors :: DsMonad q => Name -> q (S.Set Name)
    get_all_ctors con_name = do
      ty_name <- dataConNameToDataName con_name
      Just (DTyConI tycon_dec _) <- dsReify ty_name
      return (S.fromList (map get_con_name (get_cons tycon_dec)))

    get_cons tycon_dec = case tycon_dec of
      DDataD _ _ _ _ cons _ -> cons
      DDataInstD _ _ _ _ cons _ -> cons
      _ -> []

    get_con_name (DCon _ _ n _ _) = n
-- | The match result for a match without any equations: a @case@ on the
-- variable whose only alternative is a wildcard leading to the failure
-- expression.
matchEmpty :: DsMonad q => Name -> q [MatchResult]
matchEmpty var =
  return [ \on_fail -> DCaseE (DVarE var) [DMatch DWildPa on_fail] ]
-- | Compile a group of equations that all start with a literal pattern. The
-- equations arrive already partitioned by literal; each partition becomes one
-- alternative of a @case@ on the first variable, whose body matches the
-- remaining patterns against the remaining variables.
--
-- It is an internal error for the list of variables to be empty, or for a
-- partition to be empty or not to start with a literal pattern.
matchLiterals :: DsMonad q => [Name] -> [[EquationInfo]] -> q MatchResult
matchLiterals (var:vars) sub_groups
  = do alts <- mapM match_group sub_groups
       return (mkCoPrimCaseMatchResult var alts)
  where
    match_group :: DsMonad q => [EquationInfo] -> q (Lit, MatchResult)
    match_group eqns
      = do match_result <- simplCase vars (shiftEqns eqns)
           return (group_lit eqns, match_result)

    -- The literal shared by a sub-group, taken from its first equation. It is
    -- only inspected when the alternative's pattern is used, as before.
    group_lit :: [EquationInfo] -> Lit
    group_lit (eqn : _)
      | DLitPa lit <- firstPat eqn = lit
    group_lit _ = error "Internal error in th-desugar (matchLiterals: no literal pattern)"
matchLiterals [] _ = error "Internal error in th-desugar (matchLiterals)"
-- | Build a @case@ on a variable with one alternative per literal, followed
-- by a wildcard alternative leading to the failure expression. The default
-- alternative is always present.
mkCoPrimCaseMatchResult :: Name                  -- Scrutinee
                        -> [(Lit, MatchResult)]
                        -> MatchResult
mkCoPrimCaseMatchResult var match_alts on_fail =
  DCaseE (DVarE var) (lit_alts ++ [DMatch DWildPa on_fail])
  where
    lit_alts = [ DMatch (DLitPa lit) (body_fn on_fail) | (lit, body_fn) <- match_alts ]
-- | Compile a group of equations that all start with a bang pattern: the
-- bangs are stripped, the equations are matched against the same variables
-- again, and the result is wrapped in a 'seq' on the first variable.
matchBangs :: DsMonad q => [Name] -> [EquationInfo] -> q MatchResult
matchBangs [] _ = error "Internal error in th-desugar (matchBangs)"
matchBangs vars@(var : _) eqns =
  mkEvalMatchResult var `liftM`
    simplCase vars (map (decomposeFirstPat getBangPat) eqns)
-- | Apply a function to the first pattern of an equation, leaving the other
-- patterns and the right-hand side alone. It is an internal error for the
-- equation to have no patterns.
decomposeFirstPat :: (DPat -> DPat) -> EquationInfo -> EquationInfo
decomposeFirstPat extractpat (EquationInfo pats body) = case pats of
  pat : rest -> EquationInfo (extractpat pat : rest) body
  [] -> error "Internal error in th-desugar (decomposeFirstPat)"
-- | The pattern underneath a bang. Any pattern other than a bang pattern is
-- an internal error.
getBangPat :: DPat -> DPat
getBangPat pat = case pat of
  DBangPa inner -> inner
  _ -> error "Internal error in th-desugar (getBangPat)"
-- | Wrap a match result in an application of 'seq' to the given variable,
-- producing the expression @seq var body@.
mkEvalMatchResult :: Name -> MatchResult -> MatchResult
mkEvalMatchResult var body_fn on_fail =
  DAppE (DAppE (DVarE 'seq) (DVarE var)) (body_fn on_fail)
-- | Compile a group of equations whose first pattern matches immediately:
-- the first variable and the first pattern of every equation are dropped and
-- matching carries on with the rest. It is an internal error for the list of
-- variables to be empty.
matchVariables :: DsMonad q => [Name] -> [EquationInfo] -> q MatchResult
matchVariables vars eqns = case vars of
  _ : rest_vars -> simplCase rest_vars (shiftEqns eqns)
  [] -> error "Internal error in th-desugar (matchVariables)"
-- | Drop the first pattern from every equation. It is an internal error for
-- an equation to have no patterns; as before, this is only detected once the
-- remaining patterns of that equation are inspected.
shiftEqns :: [EquationInfo] -> [EquationInfo]
shiftEqns = map shift
  where
    shift (EquationInfo pats rhs) = EquationInfo (drop_first pats) rhs

    -- like 'tail', but with an error message that points at this module
    drop_first (_ : rest) = rest
    drop_first [] = error "Internal error in th-desugar (shiftEqns)"
-- | Post-process the expression produced by a match result with the given
-- function.
adjustMatchResult :: (DExp -> DExp) -> MatchResult -> MatchResult
adjustMatchResult wrap mr = wrap . mr
-- | Choose a fresh variable for each of the given patterns, in order (see
-- 'selectMatchVar'). From GHC's DsUtils.
selectMatchVars :: DsMonad q => [DPat] -> q [Name]
selectMatchVars pats = forM pats selectMatchVar
-- | Choose a fresh variable to match a pattern against. Bangs and laziness
-- annotations are looked through; a variable pattern yields a name based on
-- that variable (prefixed with an underscore), anything else a name based on
-- @"_pat"@. From GHC's DsUtils.
selectMatchVar :: DsMonad q => DPat -> q Name
selectMatchVar pat = case pat of
  DBangPa inner -> selectMatchVar inner
  DTildePa inner -> selectMatchVar inner
  DVarPa var -> newUniqueName ('_' : nameBase var)
  _ -> newUniqueName "_pat"
-- | Split a list into consecutive runs. A run starts at some element and
-- extends over all directly following elements that are related to that
-- /first/ element by the given predicate (elements are not compared with
-- their immediate neighbour). Like GHC's @runs@.
runs :: (a -> a -> Bool) -> [a] -> [[a]]
runs same = go
  where
    go [] = []
    go (x : rest) =
      case break (not . same x) rest of
        (matching, others) -> (x : matching) : go others