{-|
Module      : Idris.Elab.Rewrite
Description : Code to elaborate rewrite rules.

License     : BSD3
Maintainer  : The Idris Community.
-}
{-# LANGUAGE PatternGuards #-}
{-# OPTIONS_GHC -fwarn-incomplete-patterns #-}
module Idris.Elab.Rewrite(elabRewrite, elabRewriteLemma) where

import Idris.AbsSyntax
import Idris.Core.Elaborate
import Idris.Core.Evaluate
import Idris.Core.TT
import Idris.Delaborate
import Idris.Docstrings
import Idris.Error

import Control.Monad
import Control.Monad.State.Strict

-- | Elaborate a rewrite of the current goal using an equality proof.
--
-- Arguments, in order:
--
--   * the term elaborator to call back into;
--   * the current 'IState';
--   * the source location attached to the syntax generated here;
--   * an optional name of the substitution function to use (resolved by
--     'findSubstFn'; when absent, 'findSubstFn' picks one);
--   * the rewrite rule, whose elaborated type must be an application of
--     @=@ to four arguments;
--   * the scope term;
--   * an optional term which, when present, is combined with the scope
--     into a generated @PLet@ over a fresh @rewrite_result@ name.
--
-- Fails with 'NotEquality' when the rule's normalised type is not of the
-- expected @=@ shape, and with 'NoRewriting' when abstracting the goal
-- leaves it unchanged.
elabRewrite :: (PTerm -> ElabD ()) -> IState ->
               FC -> Maybe Name -> PTerm -> PTerm -> Maybe PTerm -> ElabD ()
-- Old version, no rewrite rule given
{-
elabRewrite elab ist fc Nothing rule sc newg
        = do attack
             tyn <- getNameFrom (sMN 0 "rty")
             claim tyn RType
             valn <- getNameFrom (sMN 0 "rval")
             claim valn (Var tyn)
             letn <- getNameFrom (sMN 0 "_rewrite_rule")
             letbind letn (Var tyn) (Var valn)
             focus valn
             elab rule
             compute
             g <- goal
             rewrite (Var letn)
             g' <- goal
             when (g == g') $ lift $ tfail (NoRewriting g)
             case newg of
                 Nothing -> elab sc
                 Just t -> doEquiv elab fc t sc
             solve
             -}
elabRewrite elab ist fc substfn_in rule sc_in newg
        = do attack
             -- When 'newg' is given, replace the scope by a PLet built from
             -- a fresh name, 't' and 'sc_in', whose body is just a
             -- reference to that fresh name.
             sc <- case newg of
                      Nothing -> return sc_in
                      Just t -> do
                         letn <- getNameFrom (sMN 0 "rewrite_result")
                         return $ PLet fc RigW letn fc t sc_in
                                       (PRef fc [] letn)
             -- Claim a hole for the rule's type and one for its value,
             -- let-bind the value, then elaborate the rule into the value
             -- hole.
             tyn <- getNameFrom (sMN 0 "rty")
             claim tyn RType
             valn <- getNameFrom (sMN 0 "rval")
             claim valn (Var tyn)
             letn <- getNameFrom (sMN 0 "_rewrite_rule")
             letbind letn RigW (Var tyn) (Var valn)
             focus valn
             elab rule
             compute
             g <- goal
             (tmv, rule_in) <- get_type_val (Var letn)
             env <- get_env
             -- Normalise the second component before inspecting its shape.
             let ttrule = normalise (tt_ctxt ist) env rule_in
             rname <- unique_hole (sMN 0 "replaced")
             case unApply ttrule of
                  -- Only an application of "=" to exactly four arguments is
                  -- accepted as a rewrite rule.
                  (P _ (UN q) _, [lt, rt, l, r]) | q == txt "=" ->
                     do substfn <- findSubstFn substfn_in ist lt rt
                        -- Abstract the goal over 'l': occurrences of 'l' in
                        -- 'g' become the bound variable 'rname'.
                        let pred_tt = mkP (P Bound rname rt) l r g
                        -- If nothing was replaced there is nothing to
                        -- rewrite, so report that as an error.
                        when (g == pred_tt) $ lift $ tfail (NoRewriting l r g)
                        let pred = PLam fc rname fc Placeholder
                                        (delab ist pred_tt)
                        -- Build the application of the substitution function
                        -- to the predicate, the rule and the scope, and
                        -- elaborate that instead.
                        let rewrite = addImplBound ist (map fstEnv env) (PApp fc (PRef fc [] substfn)
                                           [pexp (stripImpls pred),
                                            pexp (stripImpls rule), pexp sc])
--                         trace ("LHS: " ++ show l ++ "\n" ++
--                                "RHS: " ++ show r ++ "\n" ++
--                                "REWRITE: " ++ show rewrite ++ "\n" ++
--                                "GOAL: " ++ show (delab ist g)) $
                        elab rewrite
                        solve
                  _ -> lift $ tfail (NotEquality tmv ttrule)
      where
        -- @mkP lt l r tm@ returns @tm@ with every subterm equal to @l@
        -- replaced by @lt@. Subterms equal to @r@ are left untouched and
        -- not descended into. Only applications and binders are traversed;
        -- any other term is returned unchanged.
        mkP :: TT Name -> -- Left term, top level
               TT Name -> TT Name -> TT Name -> TT Name
        mkP lt l r ty | l == ty = lt
        mkP lt l r (App s f a)
            = let f' = if (r /= f) then mkP lt l r f else f
                  a' = if (r /= a) then mkP lt l r a else a in
                  App s f' a'
        mkP lt l r (Bind n b sc)
             = let b' = mkPB b
                   sc' = if (r /= sc) then mkP lt l r sc else sc in
                   Bind n b' sc'
                  -- A Let binder has both its type and its value rewritten;
                  -- every other binder only has its type rewritten.
            where mkPB (Let rig t v)
                       = let t' = if (r /= t) then mkP lt l r t else t
                             v' = if (r /= v) then mkP lt l r v else v in
                             Let rig t' v'
                  mkPB b = let ty = binderTy b
                               ty' = if (r /= ty) then mkP lt l r ty else ty in
                                     b { binderTy = ty' }
        mkP lt l r x = x

        -- If we're rewriting a dependent type, the rewrite may break the
        -- indices.
        -- So, for any function argument which can be determined by looking
        -- at the indices (i.e. any constructor guarded elsewhere in the type)
        -- replace it with a _. e.g. (++) a n m xs ys becomes (++) _ _ _ xs ys
        -- because a, n, and m are constructor guarded later in the type of ++

        -- FIXME: Currently this is an approximation which just strips
        -- implicits. This is perfectly fine for most cases, but not as
        -- general as it should be.
        stripImpls tm = mapPT phApps tm

        -- In an application, replace the term carried by every implicit
        -- (PImp) argument with a Placeholder; other arguments and other
        -- kinds of term are left alone.
        phApps (PApp fc f args) = PApp fc f (map removeImp args)
          where removeImp tm@(PImp{}) = tm { getTm = Placeholder }
                removeImp t = t
        phApps tm = tm

-- | Pick the function used to perform a rewrite between two types.
--
-- When a name is supplied it is looked up in the context: a unique match
-- is returned, no match is reported as 'NoSuchVariable', and several
-- matches as 'CantResolveAlts'.
--
-- When no name is supplied: equal types use @rewrite__impl@; otherwise, if
-- both types are headed by the same name and a lemma named by
-- 'rewrite_name' exists for it, that lemma is used; anything else is an
-- error.
findSubstFn :: Maybe Name -> IState -> Type -> Type -> ElabD Name
findSubstFn (Just requested) ist _ _ =
    case lookupTyName requested (tt_ctxt ist) of
         [] -> lift (tfail (NoSuchVariable requested))
         [(found, _)] -> return found
         candidates -> lift (tfail (CantResolveAlts (map fst candidates)))
findSubstFn Nothing ist lt rt
    | lt == rt = return (sUN "rewrite__impl")
    -- TODO: Find correct rewriting lemma, if it exists, and tidy up the
    -- error below
    | Just lemma <- sharedLemma
    , Just _ <- lookupTyExact lemma (tt_ctxt ist) = return lemma
    | otherwise = lift (tfail (Msg heterogeneous))
  where
    -- The lemma name for the head both types have in common, if any.
    sharedLemma =
        case (fst (unApply lt), fst (unApply rt)) of
             (P _ lcon _, P _ rcon _) | lcon == rcon -> Just (rewrite_name lcon)
             _ -> Nothing

    heterogeneous =
        "Can't rewrite heterogeneous equality on types " ++
        show (delab ist lt) ++ " and " ++ show (delab ist rt)

-- | Name of the rewriting lemma associated with a type constructor: a
-- machine-generated name made of the constructor's 'show'n name followed
-- by @_rewrite_lemma@.
rewrite_name :: Name -> Name
rewrite_name tcon = sMN 0 lemmaText
  where lemmaText = show tcon ++ "_rewrite_lemma"

-- | Classification of one argument position of a type constructor, as
-- assigned by 'getParamInfo'.
data ParamInfo
    = Index          -- ^ explicit argument whose position is not a parameter position
    | Param          -- ^ explicit argument at a parameter position
    | ImplicitIndex  -- ^ non-explicit argument whose position is not a parameter position
    | ImplicitParam  -- ^ non-explicit argument at a parameter position
    deriving Show

-- | Classify the arguments of a type constructor.
--
-- Walks the Pi binders of the given type in step with the given argument
-- list, counting positions from the given starting index. A position is a
-- parameter when it occurs in the final list and an index otherwise;
-- arguments that are not 'PExp' get the @Implicit@ variant. Classification
-- stops as soon as the type is no longer a Pi binder or the argument list
-- runs out.
getParamInfo :: Type -> [PArg] -> Int -> [Int] -> [ParamInfo]
getParamInfo (Bind _ (Pi _ _ _ _) rest) (arg : args) pos paramPositions
    = classify arg : getParamInfo rest args (pos + 1) paramPositions
  where
    atParamPos = pos `elem` paramPositions

    classify PExp{}
        | atParamPos = Param
        | otherwise = Index
    classify _
        | atParamPos = ImplicitParam
        | otherwise = ImplicitIndex
getParamInfo _ _ _ _ = []

-- | 'True' for the two parameter classifications ('Param' and
-- 'ImplicitParam'), 'False' for the two index classifications.
isParam :: ParamInfo -> Bool
isParam info = case info of
    Param -> True
    ImplicitParam -> True
    Index -> False
    ImplicitIndex -> False

-- | Make a rewriting lemma for the given type constructor.
--
-- No lemma is generated when every argument of the constructor is
-- classified as a parameter, since equality on the type is then
-- homogeneous.
--
-- If generating the lemma fails, the failure is deliberately ignored: it
-- may well fail for some very complex types. We can fix this later; for
-- now this gives us a lot more working rewrites.
elabRewriteLemma :: ElabInfo -> Name -> Type -> Idris ()
elabRewriteLemma info n cty_in = do
    ist <- getIState
    case lookupCtxtExact n (idris_datatypes ist) of
         Nothing -> fail $ "Can't happen, elabRewriteLemma for " ++ show n
         Just ti ->
            let cty = normalise (tt_ctxt ist) [] cty_in
                pinfo = getParamInfo cty (argsOf ist) 0 (param_pos ti)
                buildLemma = mkLemma info (rewrite_name n) n pinfo cty
            in unless (all isParam pinfo) $
                  -- Best effort: swallow any error raised while building.
                  idrisCatch buildLemma (const (return ()))
  where
    -- Recorded implicit-argument information for the constructor, or an
    -- endless supply of explicit placeholders when none is recorded.
    argsOf ist = case lookupCtxtExact n (idris_implicits ist) of
                      Just declared -> declared
                      Nothing -> repeat (pexp Placeholder)

-- | Declare and define the rewriting lemma for a type constructor.
--
-- Arguments, in order: elaboration info, the name to give the lemma, the
-- type constructor, the classification of the constructor's arguments, and
-- the constructor's type.
--
-- The generated declaration takes two implicit arguments @x@ and @y@ (the
-- constructor applied to shared parameter names and to separate @left@ /
-- @right@ index names), then explicitly a predicate @P@, a proof of
-- @x = y@ and a proof of @P y@, and returns @P x@. Its single clause
-- matches the equality proof against @Refl@ and returns the @P y@ proof.
mkLemma :: ElabInfo -> Name -> Name -> [ParamInfo] -> Type -> Idris ()
mkLemma info lemma tcon ps ty = do
    ist <- getIState
    let motiveTy = bindIdxs ist ps ty (freshNames "pred") $
                       PPi expl (sMN 0 "rep") fc (tconAt "pred") (PType fc)
        lemTy = PPi impl xName fc (tconAt "left") $
                PPi impl yName fc (tconAt "right") $
                PPi expl motive fc motiveTy $
                PPi expl eqName fc (app (ref (sUN "=")) [ref xName, ref yName]) $
                PPi expl proofName fc (app (ref motive) [ref yName]) $
                    app (ref motive) [ref xName]

    -- Elaborate the type of the lemma
    rec_elabDecl info EAll info
         (PTy emptyDocstring [] defaultSyntax fc [] lemma fc lemTy)
    -- Elaborate the definition
    rec_elabDecl info EAll info
         (PClauses fc [] lemma [PClause fc lemma lemLHS [] lemRHS []])
  where
    fc = emptyFC

    -- Machine-generated names bound by the lemma.
    xName = sMN 0 "x"
    yName = sMN 0 "y"
    motive = sMN 0 "P"
    eqName = sMN 0 "eq"
    proofName = sMN 0 "prop"

    -- Small builders for the syntax generated here.
    ref name = PRef fc [] name
    app f args = PApp fc f (map pexp args)

    lemLHS = app (ref lemma) [ref motive, ref (sUN "Refl"), ref proofName]
    lemRHS = ref proofName

    -- Endless supply of names prefix0, prefix1, ...
    freshNames prefix =
        map (\k -> sMN k (prefix ++ show k)) (iterate (+ 1) 0)

    -- The type constructor applied to the shared parameter names and to
    -- index names drawn from the supply with the given prefix.
    tconAt idxPrefix =
        PApp fc (ref tcon) (mkArgs ps (freshNames "p") (freshNames idxPrefix))

    -- Explicit arguments for the constructor: parameters take the next
    -- parameter name, indices the next index name. Implicit positions
    -- consume a name from the matching supply but contribute no argument.
    mkArgs (Param : rest) (p : params) idxs
         = pexp (ref p) : mkArgs rest params idxs
    mkArgs (Index : rest) params (i : idxs)
         = pexp (ref i) : mkArgs rest params idxs
    mkArgs (ImplicitParam : rest) (_ : params) idxs
         = mkArgs rest params idxs
    mkArgs (ImplicitIndex : rest) params (_ : idxs)
         = mkArgs rest params idxs
    mkArgs _ _ _ = []

    -- Wrap the final term in one binder per explicit index, walking the Pi
    -- binders of the constructor's type in step with the classification.
    -- Parameters add no binder; implicit indices only consume a name.
    bindIdxs ist (kind : rest) (Bind n (Pi _ _ argTy _) sc) names tm
         = case (kind, names) of
                (Param, _) -> continue names
                (ImplicitParam, _) -> continue names
                (Index, i : more) ->
                    PPi forall_imp i fc (delab ist argTy) (continue more)
                (ImplicitIndex, _ : more) -> continue more
                _ -> tm
      where
        continue remaining =
            bindIdxs ist rest (instantiate (P Bound n argTy) sc) remaining tm
    bindIdxs _ _ _ _ tm = tm