module Idris.Core.CaseTree(CaseDef(..), SC, SC'(..), CaseAlt, CaseAlt'(..), ErasureInfo,
Phase(..), CaseTree, CaseType(..),
simpleCase, small, namesUsed, findCalls, findUsedArgs,
substSC, substAlt, mkForce) where
import Idris.Core.TT
import Control.Applicative hiding (Const)
import Control.Monad.State
import Control.Monad.Reader
import Data.Maybe
import Data.List hiding (partition)
import qualified Data.List(partition)
import Debug.Trace
-- | The result of compiling a set of pattern clauses.
--
-- As built by 'simpleCase' the fields are: the generated argument names,
-- the case tree over those names, and the list of terms taken from the
-- first component of the final builder state.
data CaseDef = CaseDef [Name] !SC [Term]
    deriving Show
-- | Case trees, parameterised over the type @t@ stored at the leaves.
data SC' t = Case CaseType Name [CaseAlt' t]
             -- ^ branch on the named variable
           | ProjCase t [CaseAlt' t]
             -- ^ branch on a @t@ value rather than on a variable
           | STerm !t
             -- ^ leaf holding a @t@ value
           | UnmatchedCase String
             -- ^ leaf carrying a message; rendered as @error "..."@
           | ImpossibleCase
             -- ^ leaf produced for clauses whose right-hand side is
             -- 'Impossible', and by 'prune' when no alternative survives
    deriving (Eq, Ord, Functor)
-- | Flag stored on every 'Case' node.  'caseGroups' selects 'Updatable'
-- when the first constructor group carries a @True@ flag (copied from the
-- 'PCon' pattern) and 'Shared' otherwise.
data CaseType = Updatable | Shared
    deriving (Eq, Ord, Show)
-- | Case trees whose leaves are core 'Term's.
type SC = SC' Term
-- | One alternative of a 'Case' or 'ProjCase' node.
data CaseAlt' t = ConCase Name Int [Name] !(SC' t)
                  -- ^ constructor name, the 'Int' copied from the 'PCon'
                  -- pattern, the names bound to its arguments, and the body
                | FnCase Name [Name] !(SC' t)
                  -- ^ built from 'PReflected' patterns: name, bound
                  -- argument names, body
                | ConstCase Const !(SC' t)
                  -- ^ match on a constant
                | SucCase Name !(SC' t)
                  -- ^ built from 'PSuc' patterns; rendered as @n+1 => ...@
                | DefaultCase !(SC' t)
                  -- ^ fall-through alternative
    deriving (Show, Eq, Ord, Functor)
-- | Case alternatives whose leaves are core 'Term's.
type CaseAlt = CaseAlt' Term
-- | Multi-line rendering of a case tree.  Alternatives of one node are
-- separated by a newline followed by one padding unit per nesting level.
instance Show t => Show (SC' t) where
    show = render 1
      where
        -- Render a tree whose alternatives sit at nesting level @depth@.
        render depth tree = case tree of
            Case ctype var alts ->
                "case" ++ marker ctype ++ show var ++ " of\n" ++ pad depth
                    ++ renderAlts depth alts
            ProjCase tm alts ->
                "case " ++ show tm ++ " of " ++ renderAlts depth alts
            STerm tm          -> show tm
            UnmatchedCase msg -> "error " ++ show msg
            ImpossibleCase    -> "impossible"

        -- Text placed between the keyword and the scrutinee.
        marker Updatable = "! "
        marker Shared    = " "

        pad depth = concat (replicate depth " ")

        renderAlts depth alts =
            showSep ("\n" ++ pad depth) (map (renderAlt depth) alts)

        renderAlt depth alt = case alt of
            ConCase con _ args body ->
                show con ++ "(" ++ commaSep args ++ ") => " ++ nested body
            FnCase fn args body ->
                "FN " ++ show fn ++ "(" ++ commaSep args ++ ") => " ++ nested body
            ConstCase c body -> show c ++ " => " ++ nested body
            SucCase var body -> show var ++ "+1 => " ++ nested body
            DefaultCase body -> "_ => " ++ nested body
          where
            nested sub = render (depth + 1) sub

        commaSep names = showSep ", " (map show names)
-- | Exported synonym for 'SC'.
type CaseTree = SC
-- | A clause under compilation: the patterns still to be matched, paired
-- with two terms.  'simpleCase' fills the pair with the clause's original
-- left-hand side and its right-hand side.
type Clause = ([Pat], (Term, Term))
-- | State threaded through the case builder: terms appended by 'match' when
-- it runs out of arguments, the counter 'getVar' uses for fresh names, and
-- the variable/type pairs recorded while matching.
type CS = ([Term], Int, [(Name, Type)])
-- | Size of a case tree: branching nodes defer to the size of their
-- alternatives, term leaves to the size of the term, and every other leaf
-- counts as 1.
instance TermSize SC where
    termsize n sc = case sc of
        Case _ _ alts   -> termsize n alts
        ProjCase _ alts -> termsize n alts
        STerm tm        -> termsize n tm
        _               -> 1
-- | The size of an alternative is the size of its body, whatever the kind
-- of alternative.
instance TermSize CaseAlt where
    termsize n alt = termsize n (body alt)
      where
        body (ConCase _ _ _ s) = s
        body (FnCase _ _ s)    = s
        body (ConstCase _ s)   = s
        body (SucCase _ s)     = s
        body (DefaultCase s)   = s
-- | A tree is \"small\" when none of the given arguments is directly used
-- more than once and its 'termsize' is below 10.
small :: Name -> [Name] -> SC -> Bool
small n args t = noRepeatedUse && termsize n t < 10
  where
    used          = findAllUsedArgs t args
    noRepeatedUse = length used == length (nub used)
-- | Names referenced in the terms of a case tree, without duplicates.
-- Scrutinee names, names bound by constructor/function alternatives and
-- names bound by binders inside terms are excluded.
namesUsed :: SC -> [Name]
namesUsed sc = nub (inTree [] sc)
  where
    -- @bound@ holds the names that must not be reported.
    inTree bound tree = case tree of
        Case _ scrut alts -> nub (concatMap (inAlt bound) alts) \\ [scrut]
        ProjCase tm alts  -> nub (inTerm bound tm ++ concatMap (inAlt bound) alts)
        STerm tm          -> nub (inTerm bound tm)
        _                 -> []

    inAlt bound alt = case alt of
        ConCase _ _ args body -> nub (inTree (bound ++ args) body) \\ args
        FnCase _ args body    -> nub (inTree (bound ++ args) body) \\ args
        ConstCase _ body      -> inTree bound body
        SucCase _ body        -> inTree bound body
        DefaultCase body      -> inTree bound body

    inTerm bound tm = case tm of
        P _ n _
            | n `elem` bound  -> []
            | otherwise       -> [n]
        App _ f a             -> inTerm bound f ++ inTerm bound a
        Proj t _              -> inTerm bound t
        -- For a let only the bound value and the scope are inspected.
        Bind n (Let _ v) body -> inTerm bound v ++ inTerm (n : bound) body
        Bind n _ body         -> inTerm (n : bound) body
        _                     -> []
-- | Collect the names referenced in a case tree that are not locally bound.
-- For an application whose head is such a name, each argument position is
-- paired with the top-level arguments that the argument uses directly
-- (according to 'directUse'); a bare reference gets an empty list.
findCalls :: SC -> [Name] -> [(Name, [[Name]])]
findCalls sc topargs = nub (inTree topargs sc)
  where
    -- @bound@ starts as the top-level arguments and grows with every
    -- scrutinee and every name bound by an alternative or binder.
    inTree bound tree = case tree of
        Case _ scrut alts -> nub (concatMap (inAlt (scrut : bound)) alts)
        ProjCase tm alts  -> nub (inTerm bound tm ++ concatMap (inAlt bound) alts)
        STerm tm          -> nub (inTerm bound tm)
        _                 -> []

    inAlt bound alt = case alt of
        ConCase _ _ args body -> nub (inTree (bound ++ args) body)
        FnCase _ args body    -> nub (inTree (bound ++ args) body)
        ConstCase _ body      -> inTree bound body
        SucCase _ body        -> inTree bound body
        DefaultCase body      -> inTree bound body

    inTerm bound tm = case tm of
        P _ n _
            | n `elem` bound -> []
            | otherwise      -> [(n, [])]
        -- Any application headed by a 'P' is handled by the first
        -- alternative below, whatever its name type, so no separate case
        -- for type constructors can ever be reached.
        App _ f a -> case unApply tm of
            (P _ n _, args)
                | n `elem` bound -> inTerm bound f ++ inTerm bound a
                | otherwise      -> (n, map argNames args)
                                        : concatMap (inTerm bound) args
            _ -> inTerm bound f ++ inTerm bound a
        Bind n (Let _ v) body -> inTerm bound v ++ inTerm (n : bound) body
        Proj t _              -> inTerm bound t
        Bind n _ body         -> inTerm (n : bound) body
        _                     -> []

    -- Top-level arguments used directly by one call argument, in the order
    -- of @topargs@.
    argNames tm = filter (`elem` directUse tm) topargs
-- | Names a term uses directly.  Arguments passed to an application headed
-- by a 'Ref' or by a type constructor do not count, except for the two
-- special-cased heads @prim_fork@ and @Force@.
directUse :: TT Name -> [Name]
directUse tm = case tm of
    P _ n _ -> [n]
    Bind n (Let t v) sc ->
        nub $ directUse v ++ (directUse sc \\ [n]) ++ directUse t
    Bind n b sc ->
        nub $ directUse (binderTy b) ++ (directUse sc \\ [n])
    App _ f a -> case unApply tm of
        -- @prim_fork@ applied to a single application: both halves count.
        (P Ref (UN pfk) _, [App _ e w])
            | pfk == txt "prim_fork" -> directUse e ++ directUse w
        -- @Force@ with three arguments: only the last one counts.
        (P Ref (UN fce) _, [_, _, forced])
            | fce == txt "Force" -> directUse forced
        (P Ref _ _, _)        -> []
        (P (TCon _ _) _ _, _) -> []
        _                     -> nub $ directUse f ++ directUse a
    Proj x _ -> nub $ directUse x
    _        -> []
-- | The top-level arguments that a case tree uses directly, each listed
-- once.
findUsedArgs :: SC -> [Name] -> [Name]
findUsedArgs sc = nub . findAllUsedArgs sc
-- | Like 'findUsedArgs' but keeps one entry per use, so an argument used
-- twice appears twice.  Scrutinising a variable counts as a use.
findAllUsedArgs sc topargs = [x | x <- inTree sc, x `elem` topargs]
  where
    inTree tree = case tree of
        Case _ scrut alts -> scrut : concatMap inAlt alts
        ProjCase tm alts  -> directUse tm ++ concatMap inAlt alts
        STerm tm          -> directUse tm
        _                 -> []

    inAlt alt = inTree $ case alt of
        ConCase _ _ _ body -> body
        FnCase _ _ body    -> body
        ConstCase _ body   -> body
        SucCase _ body     -> body
        DefaultCase body   -> body
-- | Does the name occur in the tree, either as a scrutinee or among the
-- 'freeNames' of a term?
isUsed :: SC -> Name -> Bool
isUsed sc n = inTree sc
  where
    inTree tree = case tree of
        Case _ scrut alts -> n == scrut || any inAlt alts
        ProjCase tm alts  -> n `elem` freeNames tm || any inAlt alts
        STerm tm          -> n `elem` freeNames tm
        _                 -> False

    inAlt alt = inTree (altBody alt)

    altBody (ConCase _ _ _ body) = body
    altBody (FnCase _ _ body)    = body
    altBody (ConstCase _ body)   = body
    altBody (SucCase _ body)     = body
    altBody (DefaultCase body)   = body
-- | Maps a name to a list of argument positions.  'caseGroups' looks up a
-- constructor name here and treats the listed positions as inaccessible.
type ErasureInfo = Name -> [Int]
-- | Monad used while building case trees: read-only access to the
-- 'ErasureInfo' on top of a state of type 'CS'.
type CaseBuilder a = ReaderT ErasureInfo (State CS) a
-- | Run a builder with the given erasure information, yielding a function
-- from the initial state to the result and the final state.
runCaseBuilder :: ErasureInfo -> CaseBuilder a -> (CS -> (a, CS))
runCaseBuilder ei = runState . flip runReaderT ei
-- | Which phase a case tree is built for.  'simpleCase' skips the
-- accessibility check and enables the projection/lambda-stripping steps
-- when the phase is 'RunTime'.
data Phase = CompileTime | RunTime
    deriving (Show, Eq)
-- | Compile pattern clauses into a 'CaseDef'.
--
-- Arguments, in order:
--
-- * a flag handed on to 'toPats';
-- * the tree to fall back on when no clause matches;
-- * whether reflected patterns are allowed (this also disables the
--   accessibility check);
-- * the phase (the check is also skipped for 'RunTime', which in addition
--   enables projection pruning and lambda stripping);
-- * the source location wrapped around any error;
-- * the positions of inaccessible arguments, used when ordering columns;
-- * the argument types (not consulted here);
-- * the clauses, each a triple of pattern variables, left-hand side and
--   right-hand side;
-- * the erasure information handed to the builder.
--
-- Clauses whose right-hand side is 'Impossible' are discarded first.
simpleCase :: Bool -> SC -> Bool ->
              Phase -> FC -> [Int] -> [Type] ->
              [([Name], Term, Term)] ->
              ErasureInfo ->
              TC CaseDef
simpleCase tc defcase reflect phase fc inacc argtys cs erInfo
    = build (filter possible cs)
  where
    possible (_, _, Impossible) = False
    possible _                  = True

    runtime = phase == RunTime

    build [] = return $ CaseDef [] (UnmatchedCase "No pattern clauses") []
    build clauses =
        case mapM chkAccessible
                  [ (avs, toPats reflect tc l, (l, r)) | (avs, l, r) <- clauses ] of
            Error err -> Error (At fc err)
            OK pats ->
                let numargs    = length (fst (head pats))
                    ns         = take numargs argNames
                    flagged    = [ (n, i `elem` inacc) | (i, n) <- zip [0..] ns ]
                    (ns', ps') = order flagged pats
                    (tree, st) = runCaseBuilder erInfo
                                     (match ns' ps' defcase)
                                     ([], numargs, [])
                    def        = CaseDef ns (prune runtime (depatt ns' tree))
                                         (firstOf st)
                in return (if runtime then stripLambdas def else def)

    -- Supply of argument names: machine names numbered from 0.
    argNames = map (\i -> sMN i "e") [0..]

    firstOf (x, _, _) = x

    -- Outside the runtime/reflection cases every pattern variable of a
    -- clause has to be found somewhere in its patterns.
    chkAccessible (avs, l, c)
        | runtime || reflect = return (l, c)
        | otherwise          = mapM_ (reachable l) avs >> return (l, c)

    -- Search the patterns, descending into constructor and successor
    -- patterns, for a variable pattern with the given name.
    reachable [] n = Error (Inaccessible n)
    reachable (p : rest) n = case p of
        PV x _ | x == n -> OK ()
        PCon _ _ _ ps   -> reachable (ps ++ rest) n
        PSuc q          -> reachable (q : rest) n
        _               -> reachable rest n
-- | Check every 'Case' node against the recorded type of its scrutinee,
-- when one is recorded: constructor alternatives need a type accepted by
-- 'isType', constant alternatives need a constant type accepted by
-- 'isConstType'.  Nodes with no recorded type are only checked
-- recursively, and anything that is not a 'Case' passes.
checkSameTypes :: [(Name, Type)] -> SC -> Bool
checkSameTypes tys (Case _ n alts) = case lookup n tys of
    Just ty -> all (altOk ty) alts
    _       -> all (checkSameTypes tys . body) alts
  where
    altOk ty alt = case alt of
        ConCase con _ _ sc -> isType con ty && checkSameTypes tys sc
        ConstCase c sc     -> case ty of
            Constant cty -> isConstType c cty && checkSameTypes tys sc
            _            -> False
        _                  -> True

    body (ConCase _ _ _ sc) = sc
    body (FnCase _ _ sc)    = sc
    body (ConstCase _ sc)   = sc
    body (SucCase _ sc)     = sc
    body (DefaultCase sc)   = sc
checkSameTypes _ _ = True
-- | True when the head of the (applied) type is a type constructor or a
-- 'Ref'.  The first argument is not inspected.
isType _ ty = case unApply ty of
    (P (TCon _ _) _ _, _) -> True
    (P Ref _ _, _)        -> True
    _                     -> False
-- | Is the second constant a type that the first constant inhabits?  The
-- fixed-width bit constants are accepted by any integer type.
isConstType c ty = case (c, ty) of
    (I _,   AType (ATInt ITNative)) -> True
    (BI _,  AType (ATInt ITBig))    -> True
    (Fl _,  AType ATFloat)          -> True
    (Ch _,  AType (ATInt ITChar))   -> True
    (Str _, StrType)                -> True
    (B8 _,  AType (ATInt _))        -> True
    (B16 _, AType (ATInt _))        -> True
    (B32 _, AType (ATInt _))        -> True
    (B64 _, AType (ATInt _))        -> True
    _                               -> False
-- | Patterns as seen by the case compiler (produced by 'toPat').
data Pat = PCon Bool Name Int [Pat]
           -- ^ data constructor: flag, name and 'Int' copied from the
           -- 'DCon' name type, plus the argument patterns
         | PConst Const
           -- ^ a constant that is not a type constant
         | PV Name Type
           -- ^ a bound variable with its type
         | PSuc Pat
           -- ^ built from @prim__addBigInt p 1@
         | PReflected Name [Pat]
           -- ^ only produced when reflection is enabled
         | PAny
           -- ^ anything 'toPat' has no other pattern for
         | PTyPat
           -- ^ a type constant
    deriving Show
-- | Turn the arguments of an applied left-hand side into patterns, in
-- left-to-right order.  A term that is not an application yields no
-- patterns.
toPats :: Bool -> Bool -> Term -> [Pat]
toPats reflect tc = reverse . toPat reflect tc . spine
  where
    -- Arguments of the application, outermost (rightmost) first.
    spine (App _ f a) = a : spine f
    spine _           = []
-- | Translate each term of the list into a pattern.
--
-- The local worker carries the arguments collected so far while it walks
-- down an application spine.  Several equations overlap on 'P' heads, so
-- their order is significant.
-- NOTE(review): the @tc@ flag is not consulted anywhere in this definition.
toPat :: Bool -> Bool -> [Term] -> [Pat]
toPat reflect tc = map $ toPat' []
  where
    -- A data constructor named @Delay@ with exactly three arguments: only
    -- the last argument becomes a sub-pattern, the first two are wildcards.
    toPat' [_,_,arg] (P (DCon t a uniq) nm@(UN n) _)
        | n == txt "Delay"
            = PCon uniq nm t [PAny, PAny, toPat' [] arg]
    -- @Read@ in the single-component namespace @Ownership@: the flag is
    -- cleared here and in every constructor pattern nested below it.
    toPat' args (P (DCon t a uniq) nm@(NS (UN n) [own]) _)
        | n == txt "Read" && own == txt "Ownership"
            = PCon False nm t (map shareCons (map (toPat' []) args))
      where shareCons (PCon _ n i ps) = PCon False n i (map shareCons ps)
            shareCons p = p
    -- Any other data constructor.
    toPat' args (P (DCon t a uniq) n _)
        = PCon uniq n t $ map (toPat' []) args
    -- @prim__addBigInt p 1@ is a successor pattern.
    toPat' [p, Constant (BI 1)] (P _ (UN pabi) _)
        | pabi == txt "prim__addBigInt"
            = PSuc $ toPat' [] p
    -- An unapplied bound variable.
    toPat' [] (P Bound n ty) = PV n ty
    -- Application: remember the argument and continue with the function.
    toPat' args (App _ f a) = toPat' (a : args) f
    -- Unapplied constants: type constants get their own pattern.
    toPat' [] (Constant x) | isTypeConst x = PTyPat
                           | otherwise = PConst x
    -- With reflection on, a binder whose name does not occur in its scope
    -- is matched as a reflected @->@ with two sub-patterns.
    toPat' [] (Bind n (Pi _ t _) sc)
        | reflect && noOccurrence n sc
            = PReflected (sUN "->") [toPat' [] t, toPat' [] sc]
    -- With reflection on, any remaining named head becomes a reflected
    -- pattern over its arguments.
    toPat' args (P _ n _)
        | reflect
            = PReflected n $ map (toPat' []) args
    -- Everything else is a wildcard.
    toPat' _ t = PAny
-- | Name of the fixed-width bit type for each width tag.
fixedN width = case width of
    IT8  -> "Bits8"
    IT16 -> "Bits16"
    IT32 -> "Bits32"
    IT64 -> "Bits64"
-- | A maximal run of consecutive clauses whose first pattern is of the
-- same kind, as produced by 'partition'.
data Partition = Cons [Clause] -- ^ first pattern satisfies 'isConPat'
               | Vars [Clause] -- ^ first pattern satisfies 'isVarPat'
    deriving Show
-- | Does the clause start with a variable, wildcard or type pattern?
-- A clause without patterns does not.
isVarPat (p : _, _) = case p of
    PV _ _ -> True
    PAny   -> True
    PTyPat -> True
    _      -> False
isVarPat _ = False
-- | Does the clause start with a constructor, reflected, successor or
-- constant pattern?  A clause without patterns does not.
isConPat (p : _, _) = case p of
    PCon _ _ _ _   -> True
    PReflected _ _ -> True
    PSuc _         -> True
    PConst _       -> True
    _              -> False
isConPat _ = False
-- | Split the clauses into maximal runs that all start with a variable-like
-- pattern or all start with a constructor-like pattern.  A clause that is
-- neither (for instance one with no patterns left) is an error.
partition :: [Clause] -> [Partition]
partition clauses = case clauses of
    [] -> []
    (c : _)
        | isVarPat c -> run Vars isVarPat
        | isConPat c -> run Cons isConPat
        | otherwise  -> error $ "Partition " ++ show clauses
  where
    -- Take the longest prefix satisfying the predicate, wrap it, and
    -- partition whatever is left.
    run wrap sameKind =
        let (prefix, rest) = span sameKind clauses
        in wrap prefix : partition rest
-- | Choose the order in which argument columns are matched.
--
-- Each name comes with a flag (set by 'simpleCase' for inaccessible
-- positions).  Columns containing clashing patterns keep their relative
-- position after all other columns.  The remaining columns are reversed
-- and then sorted so that unflagged columns come before flagged ones and,
-- among columns with the same flag, those with more distinct constructors
-- or constants come first.  The result is the reordered names and the
-- clauses with their patterns permuted to match.
order :: [(Name, Bool)] -> [Clause] -> ([Name], [Clause])
order [] cs = ([], cs)
order ns' [] = (map fst ns', [])
order ns' cs = (nameOrder rows, zipWith rebuild rows cs)
  where
    -- One list per argument position, each entry tagged with its name/flag.
    columns = transpose (map (zip ns' . fst) cs)
    (sortable, pinned)
        = Data.List.partition (noClash . map snd) columns
    -- Back to one list per clause, with the columns in their new order.
    rows = transpose (sortBy moreDistinct (reverse sortable) ++ pinned)

    nameOrder [] = error $ "Failed order on " ++ show (map fst ns', cs)
    nameOrder (row : _) = map (fst . fst) row

    rebuild row clause = (map snd row, snd clause)

    -- No pattern in the column clashes with a later one.
    noClash [] = True
    noClash (p : ps) = all (not . clashPat p) ps && noClash ps

    clashPat a b = case (a, b) of
        (PCon _ _ _ _, PConst _)       -> True
        (PConst _, PCon _ _ _ _)       -> True
        (PCon _ _ _ _, PSuc _)         -> True
        (PSuc _, PCon _ _ _ _)         -> True
        -- Same number, different constructor name.
        (PCon _ n i _, PCon _ n' i' _) -> i == i' && n /= n'
        _                              -> False

    -- The counts are deliberately taken from the opposite column, which
    -- makes the column with more distinct heads compare as smaller.
    moreDistinct xs ys = compare (flagOf xs, distinctHeads ys)
                                 (flagOf ys, distinctHeads xs)

    flagOf col = snd (fst (head col))

    -- Number of distinct constructor names and constants in a column.
    distinctHeads col = length (nub (mapMaybe headKey (map snd col)))

    headKey (PCon _ n _ _) = Just (Left n)
    headKey (PConst c)     = Just (Right c)
    headKey _              = Nothing
-- | Build the case tree for the clauses over the given variables, using
-- the last argument as the tree for inputs no clause covers.
--
-- When variables and patterns are both exhausted the first clause wins:
-- the left-hand sides of the remaining clauses are appended to the first
-- component of the state, and the winner's right-hand side becomes the
-- leaf ('ImpossibleCase' if it is 'Impossible').
match :: [Name] -> [Clause] -> SC
      -> CaseBuilder SC
match [] (([], ret) : rest) err = do
    (leftover, next, ntys) <- get
    put (leftover ++ map (fst . snd) rest, next, ntys)
    return $ case snd ret of
        Impossible -> ImpossibleCase
        rhs        -> STerm rhs
match vs cs err = mixture vs (partition cs) err
-- | Compile a sequence of partitions.  The tree for the later partitions
-- is built first and becomes the fall-through of the earlier one; with no
-- partitions left the supplied fall-through is the result.
mixture :: [Name] -> [Partition] -> SC -> CaseBuilder SC
mixture _ [] err = return err
mixture vs (p : ps) err = do
    fallthrough <- mixture vs ps err
    case p of
        Cons ms -> conRule vs ms fallthrough
        Vars ms -> varRule vs ms fallthrough
-- | Look the name up in the builder's 'ErasureInfo'.
inaccessibleArgs :: Name -> CaseBuilder [Int]
inaccessibleArgs n = asks ($ n)
-- | The key clauses are grouped by in 'groupCons'.
data ConType = CName Name Int -- ^ from a 'PCon' pattern: name and its 'Int'
             | CFn Name       -- ^ from a 'PReflected' pattern
             | CSuc           -- ^ from a 'PSuc' pattern
             | CConst Const   -- ^ from a 'PConst' pattern
    deriving (Show, Eq)
-- | All clauses that start with the same 'ConType'.
data Group = ConGroup Bool
                      -- ^ flag taken from the pattern ('False' for
                      -- everything that is not a 'PCon')
                      ConType
                      -- ^ what the clauses have in common
                      [([Pat], Clause)]
                      -- ^ per clause: the sub-patterns of its first
                      -- pattern, and the rest of the clause
    deriving Show
-- | Handle a run of clauses starting with constructor-like patterns:
-- group them by constructor and emit a case split on the first variable.
-- Only defined for a non-empty variable list.
conRule :: [Name] -> [Clause] -> SC -> CaseBuilder SC
conRule vars@(_ : _) cs err =
    groupCons cs >>= \groups -> caseGroups vars groups err
-- | Build a 'Case' node on the first variable with one alternative per
-- group plus a default alternative holding the fall-through tree.  The
-- alternatives are sorted before they are stored.
caseGroups :: [Name] -> [Group] -> SC -> CaseBuilder SC
caseGroups (v:vs) gs err = do
    alts <- mapM groupAlt gs
    return $ Case sharing v (sort (alts ++ [DefaultCase err]))
  where
    -- Only the flag of the first group decides the case type.
    sharing = case gs of
        (ConGroup True _ _ : _) -> Updatable
        _                       -> Shared

    groupAlt (ConGroup _ ctype rows) = case ctype of
        -- Constructor: accessible fields are matched before the remaining
        -- variables, inaccessible ones after them.
        CName n i -> do
            inacc <- inaccessibleArgs n
            (newVars, accVars, inaccVars, nextCs) <- argsToAlt inacc rows
            ConCase n i newVars
                <$> match (accVars ++ vs ++ inaccVars) nextCs err
        CFn n -> do
            (newVars, _, [], nextCs) <- argsToAlt [] rows
            FnCase n newVars <$> match (newVars ++ vs) nextCs err
        CSuc -> do
            ([newVar], _, [], nextCs) <- argsToAlt [] rows
            SucCase newVar <$> match (newVar : vs) nextCs err
        CConst c -> do
            (_, _, [], nextCs) <- argsToAlt [] rows
            ConstCase c <$> match vs nextCs err
-- | Prepare the clauses of one constructor group for matching below the
-- constructor.
--
-- Given the inaccessible field positions and the group's rows, returns
-- the fresh variables for all fields (named after the first row's
-- sub-patterns), those at accessible positions, those at inaccessible
-- positions, and the clauses with the sub-patterns spliced in: accessible
-- ones in front of the remaining patterns, inaccessible ones behind.
argsToAlt :: [Int] -> [([Pat], Clause)] -> CaseBuilder ([Name], [Name], [Name], [Clause])
argsToAlt _ [] = return ([], [], [], [])
argsToAlt inacc rows@((firstArgs, _) : _) = do
    fresh <- freshVars firstArgs
    let (accVars, inaccVars) = splitByAccess fresh
    return (fresh, accVars, inaccVars, map splice rows)
  where
    -- One fresh variable per sub-pattern; the name hint depends on the
    -- kind of pattern.
    freshVars :: [Pat] -> CaseBuilder [Name]
    freshVars [] = return []
    freshVars (p : rest) = case p of
        PV _ ty -> do
            v  <- getVar "e"
            vs <- freshVars rest
            -- Recorded after the recursive call, so this variable ends up
            -- in front of the entries for the later ones.
            (cs, i, ntys) <- get
            put (cs, i, (v, ty) : ntys)
            return (v : vs)
        PAny   -> (:) <$> getVar "i" <*> freshVars rest
        PTyPat -> (:) <$> getVar "t" <*> freshVars rest
        _      -> (:) <$> getVar "e" <*> freshVars rest

    -- Split by position into (accessible, inaccessible).
    splitByAccess xs =
        ( [x | (i, x) <- indexed, i `notElem` inacc]
        , [x | (i, x) <- indexed, i `elem` inacc]
        )
      where
        indexed = zip [0..] xs

    splice (conArgs, (ps, res)) = (front ++ ps ++ back, res)
      where
        (front, back) = splitByAccess conArgs
-- | Make a fresh machine-generated name from the state's counter and the
-- given hint, and advance the counter.
getVar :: String -> CaseBuilder Name
getVar hint = do
    (terms, counter, ntys) <- get
    put (terms, counter + 1, ntys)
    return (sMN counter hint)
-- | Group clauses by the constructor, function, successor or constant
-- their first pattern matches.  Groups appear in order of first
-- occurrence and each keeps its clauses in source order.
groupCons :: [Clause] -> CaseBuilder [Group]
groupCons cs = foldM step [] cs
  where
    step groups (p : ps, res) = addGroup p ps res groups

    addGroup p ps res groups = case p of
        PCon uniq con i args -> return $ addg uniq (CName con i) args (ps, res) groups
        PConst cval          -> return $ addConG cval (ps, res) groups
        PSuc n               -> return $ addg False CSuc [n] (ps, res) groups
        PReflected fn args   -> return $ addg False (CFn fn) args (ps, res) groups
        pat -> fail $ show pat ++ " is not a constructor or constant (can't happen)"

    -- Append to the group with the same key, or start a new group at the
    -- end.  An existing group takes over the flag of the clause added.
    addg uniq c conargs res groups = case groups of
        [] -> [ConGroup uniq c [(conargs, res)]]
        (g@(ConGroup _ c' rows) : gs)
            | c == c'   -> ConGroup uniq c (rows ++ [(conargs, res)]) : gs
            | otherwise -> g : addg uniq c conargs res gs

    -- The same for constants, which have no sub-patterns; only groups
    -- whose flag is 'False' are considered.
    addConG con res groups = case groups of
        [] -> [ConGroup False (CConst con) [([], res)]]
        (ConGroup False (CConst n) rows : gs)
            | con == n -> ConGroup False (CConst n) (rows ++ [([], res)]) : gs
        (g : gs) -> g : addConG con res gs
-- | Handle a run of clauses starting with variable-like patterns: drop the
-- first pattern of each clause and go on with the remaining variables.
-- For a variable pattern the pattern variable is replaced by the variable
-- being matched in the right-hand side, and its type is recorded in the
-- state.
varRule :: [Name] -> [Clause] -> SC -> CaseBuilder SC
varRule (v : vs) alts err = do
    alts' <- mapM dropFirst alts
    match vs alts' err
  where
    dropFirst (pat : ps, body) = case pat of
        PV p ty -> do
            (cs, i, ntys) <- get
            put (cs, i, (v, ty) : ntys)
            let (lhs, res) = body
            return (ps, (lhs, subst p (P Bound v ty) res))
        PAny   -> return (ps, body)
        PTyPat -> return (ps, body)
-- | Replace, in the leaves of a tree, any application that rebuilds a value
-- already matched on a path to that leaf by a reference to the matched
-- variable.  An application qualifies when its head name equals the
-- matched constructor (or function) name and its arguments are exactly
-- the names bound by that alternative.  The first argument is unused.
depatt :: [Name] -> SC -> SC
depatt ns tm = walk [] tm
  where
    -- @seen@ maps scrutinee variables to the constructor name and bound
    -- names of the alternative taken, innermost first.
    walk seen tree = case tree of
        STerm rhs      -> STerm (applyMaps seen rhs)
        Case up x alts -> Case up x (map (walkAlt seen x) alts)
        _              -> tree

    walkAlt seen x alt = case alt of
        ConCase n i args body ->
            ConCase n i args (walk ((x, (n, args)) : seen) body)
        FnCase n args body ->
            FnCase n args (walk ((x, (n, args)) : seen) body)
        ConstCase c body -> ConstCase c (walk seen body)
        SucCase n body   -> SucCase n (walk seen body)
        DefaultCase body -> DefaultCase (walk seen body)

    applyMaps seen term = case term of
        App s f a -> case unApply term of
            (P nt cn pty, args) ->
                rebuild seen nt cn pty (map (applyMaps seen) args)
            _ -> App s (applyMaps seen f) (applyMaps seen a)
        _ -> term

    -- Use the first recorded match that fits; otherwise rebuild the
    -- application from the rewritten arguments.
    rebuild seen nt cn pty args' = case find fits seen of
        Just (x, _) -> P Ref x Erased
        Nothing     -> mkApp (P nt cn pty) args'
      where
        fits (_, (n, bound)) =
            length bound == length args'
                && n == cn
                && and (zipWith same bound args')
        same v (P _ v' _) = v == v'
        same _ _          = False
-- | Simplify a case tree.
--
-- Every alternative is pruned recursively, and default alternatives whose
-- body is an erased term or 'ImpossibleCase' are dropped.  What remains is
-- then simplified:
--
-- * no alternatives: the node becomes 'ImpossibleCase';
-- * a single constructor alternative (projection mode only): the node is
--   replaced by its body unless the body uses one of the bound names;
-- * a single successor alternative (projection mode only): the node is
--   replaced by its body, with the bound variable rewritten to a
--   projection out of the scrutinee;
-- * a single constant alternative: the node is replaced by its body;
-- * a successor alternative followed by a default: the default becomes an
--   explicit alternative for the constant 0, placed first.
--
-- Anything that is not a 'Case' node is returned unchanged.
prune :: Bool -- ^ Convert single branches to projections (only useful at runtime)
      -> SC -> SC
prune proj (Case up n alts) = case alts' of
    [] -> ImpossibleCase
    as@[ConCase cn i args sc]
        | proj -> let sc' = prune proj sc in
                      if any (isUsed sc') args
                         then Case up n [ConCase cn i args sc']
                         else sc'
    [SucCase cn sc]
        | proj
        -- The variable bound by a successor alternative stands for the
        -- predecessor of the scrutinee.  Idris encodes "predecessor of" as
        -- projection index -1; a non-negative index would be read as a
        -- constructor field instead.
        -> projRep cn n (-1) $ prune proj sc
    [ConstCase _ sc]
        -> prune proj sc
    -- With only a successor alternative left, the default can only stand
    -- for 0; spell that out and put it first.
    [s@(SucCase _ _), DefaultCase dc]
        -> Case up n [ConstCase (BI 0) dc, s]
    as  -> Case up n as
  where
    alts' = filter (not . erased) $ map pruneAlt alts

    pruneAlt (ConCase cn i ns sc) = ConCase cn i ns (prune proj sc)
    pruneAlt (FnCase cn ns sc) = FnCase cn ns (prune proj sc)
    pruneAlt (ConstCase c sc) = ConstCase c (prune proj sc)
    pruneAlt (SucCase n sc) = SucCase n (prune proj sc)
    pruneAlt (DefaultCase sc) = DefaultCase (prune proj sc)

    erased (DefaultCase (STerm Erased)) = True
    erased (DefaultCase ImpossibleCase) = True
    erased _ = False

    -- Replace every use of @arg@ by projection @i@ of @n@.  A case on
    -- @arg@ itself becomes a 'ProjCase' on that projection.
    projRep :: Name -> Name -> Int -> SC -> SC
    projRep arg n i (Case up x alts) | x == arg
        = ProjCase (Proj (P Bound n Erased) i) $ map (projRepAlt arg n i) alts
    projRep arg n i (Case up x alts)
        = Case up x (map (projRepAlt arg n i) alts)
    projRep arg n i (ProjCase t alts)
        = ProjCase (projRepTm arg n i t) $ map (projRepAlt arg n i) alts
    projRep arg n i (STerm t) = STerm (projRepTm arg n i t)
    projRep arg n i c = c

    projRepAlt arg n i (ConCase cn t args rhs)
        = ConCase cn t args (projRep arg n i rhs)
    projRepAlt arg n i (FnCase cn args rhs)
        = FnCase cn args (projRep arg n i rhs)
    projRepAlt arg n i (ConstCase t rhs)
        = ConstCase t (projRep arg n i rhs)
    projRepAlt arg n i (SucCase sn rhs)
        = SucCase sn (projRep arg n i rhs)
    projRepAlt arg n i (DefaultCase rhs)
        = DefaultCase (projRep arg n i rhs)

    projRepTm arg n i t = subst arg (Proj (P Bound n Erased) i) t
prune _ t = t
-- | While the tree is a single term that is a lambda, turn the bound
-- variable into an additional argument of the definition and continue
-- with the instantiated body.
stripLambdas :: CaseDef -> CaseDef
stripLambdas def = case def of
    CaseDef ns (STerm (Bind x (Lam _) sc)) tm ->
        let body = instantiate (P Bound x Erased) sc
        in stripLambdas (CaseDef (ns ++ [x]) (STerm body) tm)
    _ -> def
-- | Replace a variable by another one throughout a case tree, both as a
-- scrutinee and inside terms.  'ProjCase' nodes are not supported and
-- raise an error.
substSC :: Name -> Name -> SC -> SC
substSC n repl tree = case tree of
    Case up scrut alts ->
        Case up (if n == scrut then repl else scrut)
                (map (substAlt n repl) alts)
    STerm t              -> STerm $ subst n (P Bound repl Erased) t
    UnmatchedCase errmsg -> UnmatchedCase errmsg
    ImpossibleCase       -> ImpossibleCase
    _ -> error $ "unsupported in substSC: " ++ show tree
-- | Apply 'substSC' to the body of an alternative.  Names bound by the
-- alternative itself are left as they are.
substAlt :: Name -> Name -> CaseAlt -> CaseAlt
substAlt n repl alt = case alt of
    ConCase cn a ns sc -> ConCase cn a ns (inside sc)
    FnCase fn ns sc    -> FnCase fn ns (inside sc)
    ConstCase c sc     -> ConstCase c (inside sc)
    -- NOTE(review): both branches keep the bound name; it is never renamed.
    SucCase n' sc
        | n == n'      -> SucCase n (inside sc)
        | otherwise    -> SucCase n' (inside sc)
    DefaultCase sc     -> DefaultCase (inside sc)
  where
    inside = substSC n repl
-- | Redirect every case split on the second name to the first name,
-- throughout the tree.  Terms at the leaves and the scrutinee terms of
-- 'ProjCase' nodes are not changed.
mkForce :: Name -> Name -> SC -> SC
mkForce n arg = go
  where
    go tree = case tree of
        Case up x alts ->
            Case up (if x == arg then n else x) (map goAlt alts)
        ProjCase t alts -> ProjCase t (map goAlt alts)
        _               -> tree

    goAlt alt = case alt of
        ConCase cn t args rhs -> ConCase cn t args (go rhs)
        FnCase cn args rhs    -> FnCase cn args (go rhs)
        ConstCase t rhs       -> ConstCase t (go rhs)
        SucCase sn rhs        -> SucCase sn (go rhs)
        DefaultCase rhs       -> DefaultCase (go rhs)