{-# LANGUAGE PatternGuards, DeriveFunctor, TypeSynonymInstances #-}

module Idris.Core.CaseTree(CaseDef(..), SC, SC'(..), CaseAlt, CaseAlt'(..), ErasureInfo,
                     Phase(..), CaseTree, CaseType(..),
                     simpleCase, small, namesUsed, findCalls, findUsedArgs,
                     substSC, substAlt, mkForce) where

import Idris.Core.TT

import Control.Applicative hiding (Const)
import Control.Monad.State
import Control.Monad.Reader
import Data.Maybe
import Data.List hiding (partition)
import qualified Data.List(partition)
import Debug.Trace

-- | A compiled definition: the argument names the tree scrutinises, the
-- (strict) case tree, and the left-hand sides that 'match' recorded whenever
-- a clause fully matched while later clauses remained in that branch.
data CaseDef
    = CaseDef [Name] !SC [Term]
    deriving (Show)

-- | Case trees, parameterised over the type of the terms at their leaves.
--
-- The case-tree elaborator only produces @Case n alts@ nodes; in other
-- words, it never inspects anything other than variables.
--
-- 'ProjCase' is a more powerful construct that allows inspection of compound
-- terms. It first appears in 'prune', as an optimisation of case trees that
-- have already been built.
--
-- The intermediate representation that follows in the pipeline (LExp) allows
-- casing on arbitrary terms. Here we keep the distinction because it leaves
-- more room for optimisation.
data SC' t
    = Case CaseType Name [CaseAlt' t]
      -- ^ Case on a variable. Intended invariant: lowest tags first.
      -- NOTE(review): 'caseGroups' orders alternatives with the derived
      -- 'Ord', which compares constructor names before tags; confirm.
    | ProjCase t [CaseAlt' t]
      -- ^ Special case for projections/thunk-forcing before inspection.
    | STerm !t
      -- ^ Leaf: the right-hand side term.
    | UnmatchedCase String
      -- ^ Error message.
    | ImpossibleCase
      -- ^ Already checked to be impossible.
    deriving (Eq, Ord, Functor)
{-!
deriving instance Binary SC'
deriving instance NFData SC'
!-}

-- | Sharing mode of a 'Case' scrutinee. 'caseGroups' chooses 'Updatable'
-- when the first constructor group carries the uniqueness flag; the 'Show'
-- instance for case trees prints it as @case!@.
data CaseType
    = Updatable
    | Shared
    deriving (Eq, Ord, Show)

-- | Case trees whose leaves are Idris terms.
type SC = SC' Term

-- | One alternative of a case tree. The body of every alternative is strict.
data CaseAlt' t
    = ConCase Name Int [Name] !(SC' t)
      -- ^ Constructor name, tag, and the names bound to its arguments.
    | FnCase Name [Name] !(SC' t)
      -- ^ Reflection function.
    | ConstCase Const !(SC' t)
      -- ^ Match on a constant.
    | SucCase Name !(SC' t)
      -- ^ @n+1@ pattern; the name is bound to @n@.
    | DefaultCase !(SC' t)
      -- ^ Catch-all.
    deriving (Show, Eq, Ord, Functor)
{-!
deriving instance Binary CaseAlt'
deriving instance NFData CaseAlt'
!-}

-- | Alternatives of case trees over Idris terms.
type CaseAlt = CaseAlt' Term

-- | Multi-line rendering of a case tree, for debugging output. Alternatives
-- are placed one per line and indented four spaces per nesting level.
instance Show t => Show (SC' t) where
    show = render 1
      where
        -- Render a tree whose alternatives sit at indentation level i.
        render i (Case up n alts) =
            "case" ++ marker up ++ show n ++ " of\n" ++ indent i
                ++ showSep ("\n" ++ indent i) (map (showAlt i) alts)
        -- Laid out like 'Case': every alternative, including the first,
        -- starts on its own indented line.
        render i (ProjCase tm alts) =
            "case " ++ show tm ++ " of\n" ++ indent i
                ++ showSep ("\n" ++ indent i) (map (showAlt i) alts)
        render _ (STerm tm) = show tm
        render _ (UnmatchedCase str) = "error " ++ show str
        render _ ImpossibleCase = "impossible"

        -- Updatable scrutinees are printed as "case!".
        marker Updatable = "! "
        marker Shared    = " "

        indent i = concat (replicate i "    ")

        showAlt i (ConCase n _ args sc)
           = show n ++ "(" ++ showSep ", " (map show args) ++ ") => "
                ++ render (i + 1) sc
        showAlt i (FnCase n args sc)
           = "FN " ++ show n ++ "(" ++ showSep ", " (map show args) ++ ") => "
                ++ render (i + 1) sc
        showAlt i (ConstCase t sc)
           = show t ++ " => " ++ render (i + 1) sc
        showAlt i (SucCase n sc)
           = show n ++ "+1 => " ++ render (i + 1) sc
        showAlt i (DefaultCase sc)
           = "_ => " ++ render (i + 1) sc


-- | Case trees as produced by this module.
type CaseTree = SC

-- | A clause: its patterns, paired with its (left-hand side, right-hand side).
type Clause = ([Pat], (Term, Term))

-- | Case-builder state: left-hand sides recorded by 'match', the counter
-- used by 'getVar' for fresh names, and recorded (variable, type) pairs.
type CS = ([Term], Int, [(Name, Type)])

-- | Size of a case tree. For 'Case' and 'ProjCase' this is the size of the
-- alternatives (the projected term is not counted). For a leaf it is the size
-- of the term, and for anything else it is 1.
instance TermSize SC where
    termsize n sc = case sc of
        Case _ _ alts   -> termsize n alts
        ProjCase _ alts -> termsize n alts
        STerm t         -> termsize n t
        _               -> 1

-- | Size of an alternative: the size of its body.
instance TermSize CaseAlt where
    termsize n alt = termsize n (body alt)
      where
        body (ConCase _ _ _ s) = s
        body (FnCase _ _ s)    = s
        body (ConstCase _ s)   = s
        body (SucCase _ s)     = s
        body (DefaultCase s)   = s

-- | Can this case tree be inlined trivially? This is good for primitives in
-- particular. The tree must have a size below 10. To avoid duplicating work,
-- it must also not use any of its arguments in more than one place.
small :: Name -> [Name] -> SC -> Bool
small n args t = noRepeatedUse && termsize n t < 10
  where
    uses          = findAllUsedArgs t args
    noRepeatedUse = length uses == length (nub uses)

-- | Names occurring free in a case tree. Case scrutinees and names bound by
-- constructor/function alternatives, binders and lets are excluded.
namesUsed :: SC -> [Name]
namesUsed = nub . goSC []
  where
    goSC bound tree = case tree of
        Case _ scrut alts -> nub (concatMap (goAlt bound) alts) \\ [scrut]
        ProjCase t alts   -> nub (goTm bound t ++ concatMap (goAlt bound) alts)
        STerm t           -> nub (goTm bound t)
        _                 -> []

    -- The inner nub matters: (\\) removes only one occurrence of each name.
    goAlt bound alt = case alt of
        ConCase _ _ args sc -> nub (goSC (bound ++ args) sc) \\ args
        FnCase _ args sc    -> nub (goSC (bound ++ args) sc) \\ args
        ConstCase _ sc      -> goSC bound sc
        SucCase _ sc        -> goSC bound sc
        DefaultCase sc      -> goSC bound sc

    goTm bound tm = case tm of
        P _ x _
            | x `elem` bound -> []
            | otherwise      -> [x]
        App _ f a           -> goTm bound f ++ goTm bound a
        Proj t _            -> goTm bound t
        Bind x (Let _ v) sc -> goTm bound v ++ goTm (x : bound) sc
        Bind x _ sc         -> goTm (x : bound) sc
        _                   -> []

-- | Return all called functions, and which arguments are used
-- in each argument position for the call, in order to help reduce
-- compilation time, and trace all unused arguments.
--
-- Case scrutinees, names bound by constructor/function alternatives, and
-- names bound by binders are treated as local and never reported as calls.
-- Applications headed by a type constructor are skipped entirely.
findCalls :: SC -> [Name] -> [(Name, [[Name]])]
findCalls sc topargs = nub $ nu' topargs sc where
    nu' ps (Case _ n alts) = nub (concatMap (nua (n : ps)) alts)
    nu' ps (ProjCase t alts) = nub $ nut ps t ++ concatMap (nua ps) alts
    nu' ps (STerm t)     = nub $ nut ps t
    nu' ps _ = []

    nua ps (ConCase n i args sc) = nub (nu' (ps ++ args) sc)
    nua ps (FnCase n args sc) = nub (nu' (ps ++ args) sc)
    nua ps (ConstCase _ sc) = nu' ps sc
    nua ps (SucCase _ sc) = nu' ps sc
    nua ps (DefaultCase sc) = nu' ps sc

    nut ps (P _ n _) | n `elem` ps = []
                     | otherwise = [(n, [])] -- tmp
    nut ps fn@(App _ f a)
        -- This guard must come before the general P guard below. That guard
        -- matches every head, type constructors included, and would
        -- otherwise make this one unreachable.
        | (P (TCon _ _) _ _, _) <- unApply fn = []
        | (P _ n _, args) <- unApply fn
             = if n `elem` ps then nut ps f ++ nut ps a
                  else (n, map argNames args) : concatMap (nut ps) args
        | otherwise = nut ps f ++ nut ps a
    nut ps (Bind n (Let t v) sc) = nut ps v ++ nut (n:ps) sc
    nut ps (Proj t _) = nut ps t
    nut ps (Bind n b sc) = nut (n:ps) sc
    nut ps _ = []

    -- For one call argument: the top-level arguments it uses directly.
    argNames tm = let ns = directUse tm in
                      filter (`elem` ns) topargs

-- | Names used directly in a term, i.e. not merely passed to a function call.
-- Arguments of a call whose head is a global reference or a type constructor
-- are not counted. @prim_fork@ and @Force@ are special-cased.
directUse :: TT Name -> [Name]
directUse tm = case tm of
    P _ n _              -> [n]
    Bind n (Let ty v) sc -> nub $ directUse v ++ (directUse sc \\ [n]) ++ directUse ty
    Bind n b sc          -> nub $ directUse (binderTy b) ++ (directUse sc \\ [n])
    App _ f a            -> viaApp f a
    Proj x _             -> nub $ directUse x
    _                    -> []
  where
    viaApp f a
        -- HACK so that fork works: look inside the forked computation.
        | (P Ref (UN pfk) _, [App _ e w]) <- unApply tm, pfk == txt "prim_fork"
            = directUse e ++ directUse w
        -- Forcing a value counts as a use.
        | (P Ref (UN fce) _, [_, _, forced]) <- unApply tm, fce == txt "Force"
            = directUse forced
        -- We would need to know what the callee does with its arguments.
        | (P Ref _ _, _) <- unApply tm = []
        -- Type constructors are not used at runtime.
        | (P (TCon _ _) _ _, _) <- unApply tm = []
        | otherwise = nub $ directUse f ++ directUse a

-- | The arguments from @topargs@ that are used directly (not only through
-- function calls) anywhere in the case tree. Each is listed once.
findUsedArgs :: SC -> [Name] -> [Name]
findUsedArgs sc = nub . findAllUsedArgs sc

-- | Like 'findUsedArgs', but keeps duplicates: one entry per direct use of
-- an argument from @topargs@. Case scrutinees count as uses. 'small' relies
-- on the duplicates to detect arguments used more than once.
findAllUsedArgs :: SC -> [Name] -> [Name]
findAllUsedArgs sc topargs = filter (`elem` topargs) (nu' sc) where
    nu' (Case _ n alts) = n : concatMap nua alts
    nu' (ProjCase t alts) = directUse t ++ concatMap nua alts
    nu' (STerm t)     = directUse t
    nu' _             = []

    nua (ConCase _ _ _ sc) = nu' sc
    nua (FnCase _ _ sc)    = nu' sc
    nua (ConstCase _ sc)   = nu' sc
    nua (SucCase _ sc)     = nu' sc
    nua (DefaultCase sc)   = nu' sc

-- | Does the name occur anywhere in the case tree, either as a case
-- scrutinee or free in a leaf or projection term?
isUsed :: SC -> Name -> Bool
isUsed sc n = occursIn sc
  where
    occursIn tree = case tree of
        Case _ scrut alts -> n == scrut || any (occursIn . altBody) alts
        ProjCase t alts   -> n `elem` freeNames t || any (occursIn . altBody) alts
        STerm t           -> n `elem` freeNames t
        _                 -> False

    altBody (ConCase _ _ _ body) = body
    altBody (FnCase _ _ body)    = body
    altBody (ConstCase _ body)   = body
    altBody (SucCase _ body)     = body
    altBody (DefaultCase body)   = body

-- | Maps a name to the list of its inaccessible argument positions; the list
-- is empty if the name is not found.
type ErasureInfo = Name -> [Int]

-- | The monad in which case trees are built. The erasure information is only
-- read; the 'CS' state is threaded through.
type CaseBuilder a = ReaderT ErasureInfo (State CS) a

-- | Run a 'CaseBuilder' with the given erasure information, giving a state
-- transformer on 'CS'.
runCaseBuilder :: ErasureInfo -> CaseBuilder a -> (CS -> (a, CS))
runCaseBuilder ei bld = runState (runReaderT bld ei)

-- | The compilation phase a case tree is built for. At 'RunTime',
-- 'simpleCase' enables projections in 'prune', skips the accessibility
-- check, and strips leading lambdas.
data Phase
    = CompileTime
    | RunTime
    deriving (Show, Eq)

-- | Generate a simple case tree from a list of clauses.
-- Work Right to Left
--
-- Arguments, in order:
--   * tc: passed through to 'toPats';
--   * defcase: the tree to fall back on when no clause matches (given to
--     'match' as its error case);
--   * reflect: passed to 'toPats'; also skips the accessibility check;
--   * the compilation 'Phase';
--   * fc: the source location attached to errors;
--   * inacc: indices of inaccessible arguments;
--   * argtys: not used in this function;
--   * the clauses, as (pattern variables, lhs, rhs);
--   * the 'ErasureInfo' for the case builder.
--
-- Clauses whose right-hand side is 'Impossible' are dropped first. The tree
-- is always pruned ('prune' is given @phase == RunTime@). At 'RunTime',
-- lambdas at the top of a leaf are also turned into extra arguments
-- ('stripLambdas').
simpleCase :: Bool -> SC -> Bool ->
              Phase -> FC -> [Int] -> [Type] ->
              [([Name], Term, Term)] ->
              ErasureInfo ->
              TC CaseDef
simpleCase tc defcase reflect phase fc inacc argtys cs erInfo
      = sc' tc defcase phase fc (filter (\(_, _, r) ->
                                          case r of
                                            Impossible -> False
                                            _ -> True) cs)
          where
 -- No clauses remain once the impossible ones are dropped.
 sc' tc defcase phase fc []
                 = return $ CaseDef [] (UnmatchedCase "No pattern clauses") []
 sc' tc defcase phase fc cs
      = let proj       = phase == RunTime
            -- NOTE(review): vnames is not used.
            vnames     = fstT (head cs)
            pats       = map (\ (avs, l, r) ->
                                   (avs, toPats reflect tc l, (l, r))) cs
            chkPats    = mapM chkAccessible pats in
            case chkPats of
                OK pats ->
                    -- The number of arguments is taken from the first clause.
                    let numargs    = length (fst (head pats))
                        ns         = take numargs args
                        (ns', ps') = order [(n, i `elem` inacc) | (i,n) <- zip [0..] ns] pats
                        -- The fresh-name counter starts at numargs, so names
                        -- from getVar never reuse the argument numbers.
                        (tree, st) = runCaseBuilder erInfo
                                         (match ns' ps' defcase)
                                         ([], numargs, [])
                        t          = CaseDef ns (prune proj (depatt ns' tree)) (fstT st) in
                        if proj then return (stripLambdas t)
                                else return t
                Error err -> Error (At fc err)
    -- Argument names: machine names numbered 0, 1, ... with base "e".
    where args = map (\i -> sMN i "e") [0..]
          -- NOTE(review): defaultCase and lstT are not used.
          defaultCase True = STerm Erased
          defaultCase False = UnmatchedCase "Error"
          fstT (x, _, _) = x
          lstT (_, _, x) = x

          -- Check that all pattern variables are reachable by a case split.
          -- Otherwise, they won't make sense on the RHS.
          -- The check is skipped at RunTime and when reflecting.
          chkAccessible (avs, l, c)
               | phase == RunTime || reflect = return (l, c)
               | otherwise = do mapM_ (acc l) avs
                                return (l, c)

          -- Succeeds if n appears as a PV pattern, searching inside
          -- constructor and n+1 patterns. Reflected patterns are not
          -- searched.
          acc [] n = Error (Inaccessible n)
          acc (PV x t : xs) n | x == n = OK ()
          acc (PCon _ _ _ ps : xs) n = acc (ps ++ xs) n
          acc (PSuc p : xs) n = acc (p : xs) n
          acc (_ : xs) n = acc xs n

-- | For each 'Case', make sure every alternative is in the same type family,
-- as directed by the type of the scrutinised variable (i.e. there is no
-- implicit type casing going on). If the variable has no recorded type, the
-- alternatives' subtrees are still checked.
--
-- NOTE(review): when the type is known, the subtrees of FnCase, SucCase and
-- DefaultCase alternatives are not checked. Confirm this is intended.
checkSameTypes :: [(Name, Type)] -> SC -> Bool
checkSameTypes tys (Case _ n alts) =
    case lookup n tys of
        Just ty -> all (altOk ty) alts
        Nothing -> all (checkSameTypes tys . altBody) alts
  where
    altOk ty (ConCase cn _ _ sc) = isType cn ty && checkSameTypes tys sc
    altOk (Constant cty) (ConstCase c sc) = isConstType c cty && checkSameTypes tys sc
    altOk _ (ConstCase _ _) = False
    altOk _ _ = True

    altBody (ConCase _ _ _ sc) = sc
    altBody (FnCase _ _ sc)    = sc
    altBody (ConstCase _ sc)   = sc
    altBody (SucCase _ sc)     = sc
    altBody (DefaultCase sc)   = sc
checkSameTypes _ _ = True

-- FIXME: All we're actually doing here is checking that we haven't arrived
-- at a specific constructor for a polymorphic argument. I *think* this
-- is sufficient, but if it turns out not to be, fix it!
--
-- Issue #1718 on the issue tracker: https://github.com/idris-lang/Idris-dev/issues/1718

-- | Can the constructor (first argument, currently ignored) belong to the
-- given type? Returns True when the head of the type is a type constructor
-- or a global reference, and False for any other head.
isType :: Name -> Type -> Bool
isType _ t = case fst (unApply t) of
    P (TCon _ _) _ _ -> True
    P Ref _ _        -> True
    _                -> False

-- | Does the constant (first argument) belong to the constant type (second
-- argument)? The fixed-width B8..B64 constants are accepted for any integer
-- type.
isConstType :: Const -> Const -> Bool
isConstType (I _) (AType (ATInt ITNative)) = True
isConstType (BI _) (AType (ATInt ITBig)) = True
isConstType (Fl _) (AType ATFloat) = True
isConstType (Ch _) (AType (ATInt ITChar)) = True
isConstType (Str _) StrType = True
isConstType (B8 _) (AType (ATInt _)) = True
isConstType (B16 _) (AType (ATInt _)) = True
isConstType (B32 _) (AType (ATInt _)) = True
isConstType (B64 _) (AType (ATInt _)) = True
isConstType _ _ = False

-- | Patterns produced from clause left-hand sides by 'toPats'.
data Pat
    = PCon Bool Name Int [Pat]
      -- ^ Constructor: uniqueness flag, name, tag, argument patterns.
    | PConst Const
      -- ^ Constant.
    | PV Name Type
      -- ^ Pattern variable and its type.
    | PSuc Pat
      -- ^ Special case for n+1 on Integer.
    | PReflected Name [Pat]
      -- ^ Reflected pattern: a name applied to argument patterns.
    | PAny
      -- ^ Matches anything without inspecting it.
    | PTyPat
      -- ^ Typecase: not allowed, inspected last.
    deriving Show

-- | Convert the arguments of a clause's left-hand side into patterns, in
-- left-to-right order.
--
-- If there are repeated variables, take the *last* one (could be name
-- shadowing in a where clause, so take the most recent).
-- NOTE(review): this function does not itself treat repeated variables
-- specially; confirm where that rule is applied.
toPats :: Bool -> Bool -> Term -> [Pat]
toPats reflect tc f = toPat reflect tc (argsOf [] f)
  where
    -- Unwind the application spine, collecting arguments in order.
    argsOf acc (App _ g a) = argsOf (a : acc) g
    argsOf acc _           = acc

-- | Convert argument terms into patterns, one per term.
--
-- The equations of @toPat'@ are tried in order. Its first argument collects
-- the arguments of an application spine as it is unwound. The @tc@ flag is
-- not consulted here.
toPat :: Bool -> Bool -> [Term] -> [Pat]
toPat reflect tc = map $ toPat' []
  where
    -- Delay: only the delayed value becomes a sub-pattern; the first two
    -- arguments are matched by PAny.
    toPat' [_,_,arg] (P (DCon t a uniq) nm@(UN n) _)
        | n == txt "Delay"
        = PCon uniq nm t [PAny, PAny, toPat' [] arg]

    -- Ownership.Read: the pattern and its nested constructor patterns are
    -- marked non-unique.
    toPat' args (P (DCon t a uniq) nm@(NS (UN n) [own]) _)
        | n == txt "Read" && own == txt "Ownership"
        = PCon False nm t (map shareCons (map (toPat' []) args))
      where shareCons (PCon _ n i ps) = PCon False n i (map shareCons ps)
            shareCons p = p

    toPat' args (P (DCon t a uniq) n _)
        = PCon uniq n t $ map (toPat' []) args

    -- n + 1
    toPat' [p, Constant (BI 1)] (P _ (UN pabi) _)
        | pabi == txt "prim__addBigInt"
        = PSuc $ toPat' [] p

    toPat' []   (P Bound n ty) = PV n ty
    toPat' args (App _ f a)    = toPat' (a : args) f
    toPat' [] (Constant x) | isTypeConst x = PTyPat
                           | otherwise     = PConst x

    -- When reflecting, a function type whose codomain does not mention the
    -- bound name becomes a reflected "->" pattern.
    toPat' [] (Bind n (Pi _ t _) sc)
        | reflect && noOccurrence n sc
        = PReflected (sUN "->") [toPat' [] t, toPat' [] sc]

    toPat' args (P _ n _)
        | reflect
        = PReflected n $ map (toPat' []) args

    -- Anything else is matched without being inspected.
    toPat' _ t = PAny


-- | A maximal run of consecutive clauses whose first patterns are all
-- constructor/constant patterns ('Cons') or all variable-like ('Vars').
data Partition
    = Cons [Clause]
    | Vars [Clause]
    deriving Show

-- | Is the clause's first pattern a variable, a wildcard, or a type pattern?
isVarPat :: ([Pat], a) -> Bool
isVarPat (p : _, _) = case p of
    PV _ _ -> True
    PAny   -> True
    PTyPat -> True
    _      -> False
isVarPat _ = False

-- | Does the clause's first pattern inspect its argument (a constructor,
-- reflected, n+1, or constant pattern)?
isConPat :: ([Pat], a) -> Bool
isConPat (p : _, _) = case p of
    PCon _ _ _ _   -> True
    PReflected _ _ -> True
    PSuc _         -> True
    PConst _       -> True
    _              -> False
isConPat _ = False

-- | Split clauses into maximal runs that agree on whether their first
-- pattern is variable-like or a constructor/constant pattern. Any clause
-- that is neither (e.g. one with no patterns left) is an error.
partition :: [Clause] -> [Partition]
partition [] = []
partition ms@(m : _)
    | isVarPat m = runOf Vars isVarPat
    | isConPat m = runOf Cons isConPat
  where
    runOf wrap sameKind = let (run, rest) = span sameKind ms
                          in wrap run : partition rest
partition xs = error $ "Partition " ++ show xs

-- | Reorder the argument columns so that the one with the most distinct
-- names comes next. Otherwise take the rightmost first (i.e. pick value
-- rather than dependency).
--
-- The first argument is [(Name, IsInaccessible)]. Accessible columns sort
-- before inaccessible ones. Returns the argument names in the new order and
-- the clauses with their patterns permuted to match.

order :: [(Name, Bool)] -> [Clause] -> ([Name], [Clause])
order []  cs = ([], cs)
order ns' [] = (map fst ns', [])
order ns' cs = let patnames = transpose (map (zip ns') (map fst cs))
                   -- Only sort the arguments where there is no clash in
                   -- constructor tags between families, and no
                   -- constructor/constant clash, because otherwise we can't
                   -- reliably make the case distinction on evaluation.
                   -- Clashing columns are kept, in their original order,
                   -- after the sorted ones.
                   (patnames_ord, patnames_rest)
                        = Data.List.partition (noClash . map snd) patnames
                   -- note: sortBy . reverse is not nonsense because sortBy is stable
                   pats' = transpose (sortBy moreDistinct (reverse patnames_ord)
                                         ++ patnames_rest) in
                   (getNOrder pats', zipWith rebuild pats' cs)
  where
    -- The new name order, read off the first clause's reordered columns.
    getNOrder [] = error $ "Failed order on " ++ show (map fst ns', cs)
    getNOrder (c : _) = map (fst . fst) c

    -- Replace a clause's patterns with its reordered ones.
    rebuild patnames clause = (map snd patnames, snd clause)

    -- True when no two patterns in the column clash.
    noClash [] = True
    noClash (p : ps) = not (any (clashPat p) ps) && noClash ps

    -- A constructor clashes with a constant or an n+1 pattern, and with a
    -- different constructor that has the same tag.
    clashPat (PCon _ _ _ _) (PConst _) = True
    clashPat (PConst _) (PCon _ _ _ _) = True
    clashPat (PCon _ _ _ _) (PSuc _) = True
    clashPat (PSuc _) (PCon _ _ _ _) = True
    clashPat (PCon _ n i _) (PCon _ n' i' _) | i == i' = n /= n'
    clashPat _ _ = False

    -- this compares (+isInaccessible, -numberOfCases)
    moreDistinct xs ys = compare (snd . fst . head $ xs, numNames [] (map snd ys))
                                 (snd . fst . head $ ys, numNames [] (map snd xs))

    -- Number of distinct constructor names and constants in a column.
    numNames xs (PCon _ n _ _ : ps)
        | not (Left n `elem` xs) = numNames (Left n : xs) ps
    numNames xs (PConst c : ps)
        | not (Right c `elem` xs) = numNames (Right c : xs) ps
    numNames xs (_ : ps) = numNames xs ps
    numNames xs [] = length xs

-- | Build the case tree for the given scrutinee variables and clauses. The
-- third argument is the tree to fall back on when no clause matches.
--
-- Once all variables are consumed, the first remaining clause wins. The
-- left-hand sides of the clauses after it are appended to the 'CS' state.
match :: [Name] -> [Clause] -> SC -- error case
                            -> CaseBuilder SC
match [] (([], ret) : shadowed) _ = do
    modify (\(ts, v, ntys) -> (ts ++ map (fst . snd) shadowed, v, ntys))
    return $ case snd ret of
        Impossible -> ImpossibleCase
        tm         -> STerm tm -- run out of arguments
match vs cs err = mixture vs (partition cs) err

-- | Build the tree for a sequence of partitions. The later partitions are
-- built first and become the fall-through of the earlier ones; the given
-- error tree follows the last partition.
mixture :: [Name] -> [Partition] -> SC -> CaseBuilder SC
mixture _ [] err = return err
mixture vs (p : ps) err = mixture vs ps err >>= rule p
  where
    rule (Cons ms) = conRule vs ms
    rule (Vars ms) = varRule vs ms

-- | The inaccessible argument positions of a data constructor, looked up in
-- the 'ErasureInfo' (the only thing in the reader environment).
inaccessibleArgs :: Name -> CaseBuilder [Int]
inaccessibleArgs n = asks ($ n)

-- | The head that a group of clauses is matched on.
data ConType
    = CName Name Int -- ^ named constructor (name and tag)
    | CFn Name       -- ^ reflected function name
    | CSuc           -- ^ n+1
    | CConst Const   -- ^ constant
    deriving (Show, Eq)

-- | Clauses sharing one head: the uniqueness flag, the head, and for each
-- clause the head's argument patterns paired with the rest of the clause.
data Group
    = ConGroup Bool ConType [([Pat], Clause)]
    deriving Show

-- | Constructor rule: group the clauses by the head of their first pattern,
-- then build one case alternative per group on the first variable.
conRule :: [Name] -> [Clause] -> SC -> CaseBuilder SC
conRule vars@(_ : _) cs err = groupCons cs >>= \groups -> caseGroups vars groups err

-- | Build a 'Case' on the first variable with one alternative per group,
-- plus a trailing @DefaultCase err@. The case is 'Updatable' when the first
-- group carries the uniqueness flag.
--
-- NOTE(review): the alternatives are sorted with the derived 'Ord' on
-- 'CaseAlt''. That compares a ConCase's constructor name before its tag, so
-- the result may not follow the "lowest tags first" invariant stated on
-- 'SC''; confirm.
caseGroups :: [Name] -> [Group] -> SC -> CaseBuilder SC
caseGroups (v:vs) gs err = do g <- altGroups gs
                              return $ Case (getShared gs) v (sort g)
  where
    getShared (ConGroup True _ _ : _) = Updatable
    getShared _ = Shared

    altGroups [] = return [DefaultCase err]

    altGroups (ConGroup _ (CName n i) args : cs)
        = (:) <$> altGroup n i args <*> altGroups cs

    altGroups (ConGroup _ (CFn n) args : cs)
        = (:) <$> altFnGroup n args <*> altGroups cs

    altGroups (ConGroup _ CSuc args : cs)
        = (:) <$> altSucGroup args <*> altGroups cs

    altGroups (ConGroup _ (CConst c) args : cs)
        = (:) <$> altConstGroup c args <*> altGroups cs

    -- Constructor alternative: accessible fields are matched before the
    -- remaining variables, and inaccessible fields after them.
    altGroup n i args
         = do inacc <- inaccessibleArgs n
              (newVars, accVars, inaccVars, nextCs) <- argsToAlt inacc args
              matchCs <- match (accVars ++ vs ++ inaccVars) nextCs err
              return $ ConCase n i newVars matchCs

    -- With an empty inaccessible list, argsToAlt returns no inaccessible
    -- variables, so the [] patterns below cannot fail.
    altFnGroup n args = do (newVars, _, [], nextCs) <- argsToAlt [] args
                           matchCs <- match (newVars ++ vs) nextCs err
                           return $ FnCase n newVars matchCs

    -- An n+1 group has exactly one argument pattern per clause.
    altSucGroup args = do ([newVar], _, [], nextCs) <- argsToAlt [] args
                          matchCs <- match (newVar:vs) nextCs err
                          return $ SucCase newVar matchCs

    altConstGroup n args = do (_, _, [], nextCs) <- argsToAlt [] args
                              matchCs <- match vs nextCs err
                              return $ ConstCase n matchCs

-- | Turn the argument patterns of one constructor group into fresh variables
-- and the clauses to match next. The fresh variables are chosen from the
-- first clause's argument patterns: base "e" for variables and other
-- patterns, "i" for PAny, and "t" for PTyPat.
--
-- Returns:
--   * names of all variables arising from match
--   * names of accessible variables (subset of all variables)
--   * names of inaccessible variables (subset of all variables)
--   * clauses corresponding to (accVars ++ origVars ++ inaccVars)
argsToAlt :: [Int] -> [([Pat], Clause)] -> CaseBuilder ([Name], [Name], [Name], [Clause])
argsToAlt _ [] = return ([], [], [], [])
argsToAlt inacc rs@((r, _) : _) = do
    newVars <- getNewVars r
    let (accVars, inaccVars) = partitionAcc newVars
    return (newVars, accVars, inaccVars, addRs rs)
  where
    -- Create names for new variables arising from the given patterns.
    getNewVars :: [Pat] -> CaseBuilder [Name]
    getNewVars [] = return []
    getNewVars ((PV n t) : ns) = do v <- getVar "e"
                                    nsv <- getNewVars ns

                                    -- Record the type of the variable.
                                    --
                                    -- It seems that the ordering is not important
                                    -- and we can put (v,t) always in front of "ntys"
                                    -- (the varName-type pairs seem to represent a mapping).
                                    --
                                    -- The code that reads this is currently
                                    -- commented out, anyway.
                                    (cs, i, ntys) <- get
                                    put (cs, i, (v, t) : ntys)

                                    return (v : nsv)

    getNewVars (PAny   : ns) = (:) <$> getVar "i" <*> getNewVars ns
    getNewVars (PTyPat : ns) = (:) <$> getVar "t" <*> getNewVars ns
    getNewVars (_      : ns) = (:) <$> getVar "e" <*> getNewVars ns

    -- Partition a list of things into (accessible, inaccessible) things,
    -- according to the list of inaccessible indices.
    partitionAcc xs =
        ( [x | (i,x) <- zip [0..] xs, i `notElem` inacc]
        , [x | (i,x) <- zip [0..] xs, i    `elem` inacc]
        )

    -- Prefix each clause's remaining patterns with its accessible argument
    -- patterns and suffix them with its inaccessible ones.
    addRs [] = []
    addRs ((argPs, (ps, res)) : more) = (accPs ++ ps ++ inaccPs, res) : addRs more
      where
        (accPs, inaccPs) = partitionAcc argPs

-- | Allocate a fresh machine name with the given base from the counter in
-- the 'CS' state, and increment the counter.
getVar :: String -> CaseBuilder Name
getVar b = state fresh
  where
    fresh (t, v, ntys) = (sMN v b, (t, v + 1, ntys))

-- | Group clauses by the head of their first pattern. Groups keep the order
-- in which their heads first appear, and clauses keep their order within a
-- group. Each entry pairs the head's argument patterns with the rest of the
-- clause.
groupCons :: [Clause] -> CaseBuilder [Group]
groupCons = foldM step []
  where
    step groups (p : ps, res) = classify p (ps, res) groups

    classify p rest groups = case p of
        PCon uniq con i args -> return $ addg uniq (CName con i) args rest groups
        PConst cval          -> return $ addConG cval rest groups
        PSuc n               -> return $ addg False CSuc [n] rest groups
        PReflected fn args   -> return $ addg False (CFn fn) args rest groups
        pat                  -> fail $ show pat ++ " is not a constructor or constant (can't happen)"

    -- Append to the group for head c, or start a new group at the end. A
    -- merged group takes the uniqueness flag of the clause being added.
    addg uniq c conargs res [] = [ConGroup uniq c [(conargs, res)]]
    addg uniq c conargs res (g@(ConGroup _ c' cs) : gs)
        | c == c'   = ConGroup uniq c (cs ++ [(conargs, res)]) : gs
        | otherwise = g : addg uniq c conargs res gs

    -- Constants merge only into a non-unique group for the same constant.
    addConG con res [] = [ConGroup False (CConst con) [([], res)]]
    addConG con res (ConGroup False (CConst n) cs : gs)
        | con == n = ConGroup False (CConst n) (cs ++ [([], res)]) : gs
    addConG con res (g : gs) = g : addConG con res gs

-- | Variable rule: every clause's first pattern binds or ignores the first
-- variable. A pattern variable is substituted by that variable in the
-- clause's right-hand side, and its type is recorded in the 'CS' state.
varRule :: [Name] -> [Clause] -> SC -> CaseBuilder SC
varRule (v : vs) alts err = mapM bindVar alts >>= \alts' -> match vs alts' err
  where
    bindVar (PV p ty : ps, (lhs, res)) = do
        modify (\(cs, i, ntys) -> (cs, i, (v, ty) : ntys))
        return (ps, (lhs, subst p (P Bound v ty) res))
    bindVar (PAny : ps, res)   = return (ps, res)
    bindVar (PTyPat : ps, res) = return (ps, res)

-- | In leaf terms, replace a re-application of a constructor that was just
-- matched by the scrutinee it was matched against:
--
--   case e of S k -> f (S k)  ==> case e of S k -> f e
--
-- Only ConCase and FnCase alternatives of 'Case' nodes record matches. The
-- name list argument is not used.
depatt :: [Name] -> SC -> SC
depatt _ = walk []
  where
    walk env (STerm t)        = STerm (rewrite env t)
    walk env (Case up x alts) = Case up x (map (walkAlt env x) alts)
    walk _ other              = other

    walkAlt env x alt = case alt of
        ConCase n i args sc -> ConCase n i args (walk ((x, (n, args)) : env) sc)
        FnCase n args sc    -> FnCase n args (walk ((x, (n, args)) : env) sc)
        ConstCase c sc      -> ConstCase c (walk env sc)
        SucCase n sc        -> SucCase n (walk env sc)
        DefaultCase sc      -> DefaultCase (walk env sc)

    rewrite env t@(App s f a) = case unApply t of
        (P nt cn pty, args) -> lookupCon env nt cn pty (map (rewrite env) args)
        _                   -> App s (rewrite env f) (rewrite env a)
    rewrite _ t = t

    -- Innermost match first: the application must use the same constructor
    -- applied to exactly the variables that the match bound.
    lookupCon [] nt cn pty args' = mkApp (P nt cn pty) args'
    lookupCon ((x, (n, binders)) : rest) nt cn pty args'
        | length binders == length args' && n == cn
            && and (zipWith sameVar binders args') = P Ref x Erased
        | otherwise = lookupCon rest nt cn pty args'

    sameVar n (P _ n' _) = n == n'
    sameVar _ _          = False

-- | Simplify a case tree after it has been built.
--
-- Alternatives @DefaultCase (STerm Erased)@ and @DefaultCase ImpossibleCase@
-- are dropped. A case left with no alternatives becomes 'ImpossibleCase',
-- and a case with a single 'ConstCase' is replaced by its body. With the
-- projection flag set, two more rules apply:
--
--   * a single 'ConCase' whose bound arguments are all unused is replaced
--     by its body;
--   * a single 'SucCase' is turned into projections of the scrutinee.
--
-- Each child is pruned exactly once, by @pruneAlt@. The single-branch rules
-- reuse those already-pruned bodies. Pruning them again would double the
-- work at every such level, which is exponential in the nesting depth.
--
-- FIXME: Do this for SucCase too
-- Issue #1719 on the issue tracker:  https://github.com/idris-lang/Idris-dev/issues/1719
prune :: Bool -- ^ Convert single branches to projections (only useful at runtime)
      -> SC -> SC
prune proj (Case up n alts) = case alts' of
    [] -> ImpossibleCase

    -- Projection transformations prevent us from seeing some uses of ctor fields
    -- because they delete information about which ctor is being used.
    -- Consider:
    --   f (X x) = ...  x  ...
    -- vs.
    --   f  x    = ... x!0 ...
    --
    -- Hence, we disable this step.
    -- TODO: re-enable this in toIR
    --
    -- as@[ConCase cn i args sc]
    --     | proj -> mkProj n 0 args (prune proj sc)
    -- mkProj n i xs sc = foldr (\x -> projRep x n i) sc xs

    -- If none of the args are used in the body, however, we can just replace
    -- the case with the body (which pruneAlt has already pruned).
    [ConCase cn i args sc]
        | proj -> if any (isUsed sc) args
                     then Case up n [ConCase cn i args sc]
                     else sc

    [SucCase cn sc]
        | proj
        -> projRep cn n (-1) sc

    [ConstCase _ sc]
        -> sc

    -- Bit of a hack here! The default case will always be 0, make sure
    -- it gets caught first.
    [s@(SucCase _ _), DefaultCase dc]
        -> Case up n [ConstCase (BI 0) dc, s]

    as  -> Case up n as
  where
    alts' = filter (not . erased) $ map pruneAlt alts

    pruneAlt (ConCase cn i ns sc) = ConCase cn i ns (prune proj sc)
    pruneAlt (FnCase cn ns sc) = FnCase cn ns (prune proj sc)
    pruneAlt (ConstCase c sc) = ConstCase c (prune proj sc)
    pruneAlt (SucCase n sc) = SucCase n (prune proj sc)
    pruneAlt (DefaultCase sc) = DefaultCase (prune proj sc)

    erased (DefaultCase (STerm Erased)) = True
    erased (DefaultCase ImpossibleCase) = True
    erased _ = False

    -- Replace the variable arg by projection i of the bound variable n:
    -- cases on arg become ProjCases, and terms are substituted.
    projRep :: Name -> Name -> Int -> SC -> SC
    projRep arg n i (Case up x alts) | x == arg
        = ProjCase (Proj (P Bound n Erased) i) $ map (projRepAlt arg n i) alts
    projRep arg n i (Case up x alts)
        = Case up x (map (projRepAlt arg n i) alts)
    projRep arg n i (ProjCase t alts)
        = ProjCase (projRepTm arg n i t) $ map (projRepAlt arg n i) alts
    projRep arg n i (STerm t) = STerm (projRepTm arg n i t)
    projRep arg n i c = c

    projRepAlt arg n i (ConCase cn t args rhs)
        = ConCase cn t args (projRep arg n i rhs)
    projRepAlt arg n i (FnCase cn args rhs)
        = FnCase cn args (projRep arg n i rhs)
    projRepAlt arg n i (ConstCase t rhs)
        = ConstCase t (projRep arg n i rhs)
    projRepAlt arg n i (SucCase sn rhs)
        = SucCase sn (projRep arg n i rhs)
    projRepAlt arg n i (DefaultCase rhs)
        = DefaultCase (projRep arg n i rhs)

    projRepTm arg n i t = subst arg (Proj (P Bound n Erased) i) t

prune _ t = t

-- | While the whole tree is a single leaf that starts with a lambda, move
-- the lambda's binder into the definition's argument list and instantiate
-- the body with that argument.
stripLambdas :: CaseDef -> CaseDef
stripLambdas def@(CaseDef ns body tm) = case body of
    STerm (Bind x (Lam _) sc) ->
        let body' = STerm (instantiate (P Bound x Erased) sc)
        in stripLambdas (CaseDef (ns ++ [x]) body' tm)
    _ -> def

-- | @substSC n repl sc@ replaces the name @n@ by @repl@ throughout the case
-- tree. This covers case scrutinees, and occurrences in leaf and projection
-- terms (substituted as a bound variable). Names bound by alternatives are
-- left unchanged.
--
-- 'ProjCase' is handled like 'STerm' for its term. It previously fell
-- through to a runtime @error@.
substSC :: Name -> Name -> SC -> SC
substSC n repl (Case up n' alts)
    | n == n'   = Case up repl (map (substAlt n repl) alts)
    | otherwise = Case up n'   (map (substAlt n repl) alts)
substSC n repl (ProjCase t alts)
    = ProjCase (subst n (P Bound repl Erased) t) (map (substAlt n repl) alts)
substSC n repl (STerm t) = STerm $ subst n (P Bound repl Erased) t
substSC n repl (UnmatchedCase errmsg) = UnmatchedCase errmsg
substSC n repl  ImpossibleCase = ImpossibleCase

-- | Apply 'substSC' to the body of an alternative. The alternative's own
-- binders (constructor arguments and the n+1 variable) are kept as they are.
substAlt :: Name -> Name -> CaseAlt -> CaseAlt
substAlt n repl alt = case alt of
    ConCase cn a ns sc -> ConCase cn a ns (inBody sc)
    FnCase fn ns sc    -> FnCase fn ns (inBody sc)
    ConstCase c sc     -> ConstCase c (inBody sc)
    SucCase n' sc      -> SucCase n' (inBody sc)
    DefaultCase sc     -> DefaultCase (inBody sc)
  where
    inBody = substSC n repl

-- | @mkForce n' n t@ updates the tree @t@ under the assumption that
-- @n' = force n@ (so basically updating @n@ to @n'@): every 'Case' that
-- scrutinises @n@ scrutinises @n'@ instead.
--
-- NOTE(review): the terms in 'ProjCase' nodes and 'STerm' leaves are not
-- rewritten; confirm that this is intended.
mkForce :: Name -> Name -> SC -> SC
mkForce = mkForceSC
  where
    mkForceSC n arg (Case up x alts) | x == arg
        = Case up n $ map (mkForceAlt n arg) alts

    mkForceSC n arg (Case up x alts)
        = Case up x (map (mkForceAlt n arg) alts)

    mkForceSC n arg (ProjCase t alts)
        = ProjCase t $ map (mkForceAlt n arg) alts

    mkForceSC n arg c = c

    mkForceAlt n arg (ConCase cn t args rhs)
        = ConCase cn t args (mkForceSC n arg rhs)
    mkForceAlt n arg (FnCase cn args rhs)
        = FnCase cn args (mkForceSC n arg rhs)
    mkForceAlt n arg (ConstCase t rhs)
        = ConstCase t (mkForceSC n arg rhs)
    mkForceAlt n arg (SucCase sn rhs)
        = SucCase sn (mkForceSC n arg rhs)
    mkForceAlt n arg (DefaultCase rhs)
        = DefaultCase (mkForceSC n arg rhs)