module Idris.CaseSplit(
splitOnLine, replaceSplits
, getClause, getProofClause
, mkWith
, nameMissing
, getUniq, nameRoot
) where
import Idris.AbsSyntax
import Idris.AbsSyntaxTree (Idris, IState, PTerm)
import Idris.ElabDecls
import Idris.Delaborate
import Idris.Parser
import Idris.Parser.Helpers
import Idris.Error
import Idris.Output
import Idris.Elab.Value
import Idris.Elab.Term
import Idris.Core.TT
import Idris.Core.Typecheck
import Idris.Core.Evaluate
import Data.Maybe
import Data.Char
import Data.List (isPrefixOf, isSuffixOf)
import Control.Monad
import Control.Monad.State.Strict
import Text.Parser.Combinators
import Text.Parser.Char(anyChar)
import Text.Trifecta(Result(..), parseString)
import Text.Trifecta.Delta
import Debug.Trace
-- | Case split on the pattern variable @n@ occurring in the clause
-- left-hand side @t'@.
--
-- The result pairs a flag with one list of name updates per candidate
-- clause. The flag is 'True' when at least one candidate elaborated
-- successfully (and only those candidates are returned); it is 'False'
-- when none did, in which case the updates for all failed candidates are
-- returned instead.
split :: Name -> PTerm -> Idris (Bool, [[(Name, PTerm)]])
split n t' = do
    ist <- getIState
    -- Every name mentioned in the clause is made Public before elaboration.
    mapM_ (`setAccessibility` Public) (allNamesIn t')
    let lhs = addImplPat ist t'
    (tm, ty, pats) <- elabValBind (recinfo (fileFC "casesplit")) ETyDecl True lhs
    logElab 4 ("Elaborated:\n" ++ show tm ++ " : " ++ show ty ++ "\n" ++ show pats)
    let t    = mergeUserImpl lhs (delabDirect ist tm)
        ctxt = tt_ctxt ist
    case lookup n pats of
      Nothing -> ifail $ show n ++ " is not a pattern variable"
      Just varTy -> do
        let splits = findPats ist varTy
        logElab 1 ("New patterns " ++ showSep ", " (map showTmImpls splits))
        let candidates = map (\s -> replaceVar ctxt n s t) splits
        logElab 4 ("Working from " ++ showTmImpls t)
        logElab 4 ("Trying " ++ showSep "\n" (map showTmImpls candidates))
        checked <- mapM elabNewPat candidates
        case partitionValid checked of
          Left fails ->
            return (False, map snd (mergeAllPats ist n t fails))
          Right newPats -> do
            logElab 3 ("Original:\n" ++ show t)
            logElab 3 ("Split:\n" ++ showSep "\n" (map show newPats))
            logElab 3 "----"
            let merged = mergeAllPats ist n t newPats
            logElab 1 ("Name updates " ++ showSep "\n"
                         (map (\(p, u) -> show u ++ " " ++ show p) merged))
            return (True, map snd merged)
  where
    -- 'Right' with the successfully elaborated patterns (in order) if there
    -- are any, otherwise 'Left' with the failed ones (in order).
    partitionValid results
      | null good = Left [p | (False, p) <- results]
      | otherwise = Right good
      where
        good = [p | (True, p) <- results]
-- | State threaded through 'mergePat' while a generated pattern is merged
-- with the user's original pattern.
data MergeState = MS { namemap :: [(Name, Name)], -- ^ (name in the new pattern, name it is mapped to)
                       invented :: [(Name, Name)], -- ^ (original name, name invented for it by 'inventName')
                       explicit :: [Name], -- ^ names that invented names must avoid
                       updates :: [(Name, PTerm)] } -- ^ accumulated (name, replacement term) pairs, most recent first
-- | Record that @n@ is to be replaced by @tm@. The term has its namespaces
-- stripped (see 'stripNS') and is pushed onto the front of 'updates'.
addUpdate :: Name -> Idris.AbsSyntaxTree.PTerm -> State MergeState ()
addUpdate n tm = modify record
  where
    record ms = ms { updates = (n, stripNS tm) : updates ms }
-- | Pick a fresh, readable name to stand for @n@.
--
-- If a name has already been invented for @n@ it is reused. Otherwise a
-- name is drawn from a supply built from the name hints for the type @ty@
-- (when one is given), the root of @n@ itself (unless it starts with an
-- underscore), and 'varlist', avoiding every name already present in the
-- name map, previously invented, or listed in 'explicit'. The choice is
-- recorded in 'invented'.
inventName :: Idris.AbsSyntaxTree.IState -> Maybe Name -> Name -> State MergeState Name
inventName ist ty n = do
    ms <- get
    case lookup n (invented ms) of
      Just n' -> return n'
      Nothing -> do
        let n' = uniqueNameFrom supply (taken ms)
        put (ms { invented = (n, n') : invented ms })
        return n'
  where
    hints = maybe [] (getNameHints ist) ty

    supply = case n of
        MN _ root
          | isUnderscored root -> mkSupply (hints ++ varlist)
          | otherwise          -> mkSupply (UN root : hints ++ varlist)
        UN root
          | isUnderscored root -> mkSupply (hints ++ varlist)
        other -> mkSupply (other : hints)

    -- 'thead' is partial, so check for the empty name first. The original
    -- only did this for machine names; user names get the same guard here.
    isUnderscored root = not (tnull root) && thead root == '_'

    taken ms = map snd (namemap ms) ++ map snd (invented ms) ++ explicit ms
-- | Build a name supply from some seed names: the seeds themselves,
-- followed by the result of applying 'nextName' to each of them, followed
-- by applying it again, and so on. The supply is infinite for non-empty
-- input.
--
-- An empty seed list yields an empty supply (the unguarded recursion would
-- otherwise never produce an element nor terminate).
mkSupply :: [Name] -> [Name]
mkSupply [] = []
mkSupply ns = ns ++ mkSupply (map nextName ns)
-- | Single-letter fallback variable names, in order of preference.
varlist :: [Name]
varlist = [sUN [c] | c <- "xyzwstuv"]
-- | Replace every referenced name in a term by its namespace-free root
-- ('nsroot'); all other sub-terms are left as they are.
stripNS :: Idris.AbsSyntaxTree.PTerm -> Idris.AbsSyntaxTree.PTerm
stripNS = mapPT dropNamespace
  where
    dropNamespace tm = case tm of
        PRef fc hls n -> PRef fc hls (nsroot n)
        _             -> tm
-- | Merge each generated pattern with the original pattern @t@, giving for
-- every pattern the merged term together with the name updates collected
-- while merging it.
--
-- Each merge starts from a fresh state whose 'explicit' names are the
-- variables of @t@ other than @cv@, the variable being split.
mergeAllPats :: IState -> Name -> PTerm -> [PTerm] -> [(PTerm, [(Name, PTerm)])]
mergeAllPats ist cv t ps = map mergeOne ps
  where
    initial = MS [] [] (filter (/= cv) (patvars t)) []

    mergeOne p =
        let (merged, final) = runState (mergePat ist t p Nothing) initial
        in (merged, updates final)

    -- Names appearing as references or pattern variables, looking through
    -- applications' arguments only.
    patvars tm = case tm of
        PRef _ _ n    -> [n]
        PApp _ _ as   -> concatMap (patvars . getTm) as
        PPatvar _ n   -> [n]
        _             -> []
-- | Merge the original pattern @orig@ with a newly generated pattern @new@.
--
-- First the name map is seeded by matching @orig@ against @new@ with
-- 'matchClause': every variable of @orig@ that matched a plain reference in
-- @new@ is recorded (a failed match records nothing). The actual merge is
-- then carried out by 'mergePat''.
mergePat :: IState -> PTerm -> PTerm -> Maybe Name -> State MergeState PTerm
mergePat ist orig new t = do
    either (const (return ())) (mapM_ addNameMap) (matchClause ist orig new)
    mergePat' ist orig new t
  where
    addNameMap (n, PRef _ _ n') =
        modify (\ms -> ms { namemap = (n', n) : namemap ms })
    addNameMap _ = return ()
-- Worker for 'mergePat': walks the original and the new pattern together,
-- recording name updates and extending the name map as it goes. The last
-- argument is an optional type name for the position being merged, which is
-- passed on to 'tidy'.

-- Pattern variables on either side are treated as plain references.
mergePat' ist (PPatvar fc n) new t
  = mergePat' ist (PRef fc [] n) new t
mergePat' ist old (PPatvar fc n) t
  = mergePat' ist old (PRef fc [] n) t
-- Reference against reference.
mergePat' ist orig@(PRef fc _ n) new@(PRef _ _ n') t
  -- The new name is a data constructor: the original name is updated to it.
  | isDConName n' (tt_ctxt ist) = do addUpdate n new
                                     return new
  | otherwise
    = do ms <- get
         case lookup n' (namemap ms) of
              -- The new name is already mapped: reuse that mapping and
              -- record it as an update for the original name.
              Just x -> do addUpdate n (PRef fc [] x)
                           return (PRef fc [] x)
              -- Not mapped yet: map it to the original name and keep the
              -- original name.
              Nothing -> do put (ms { namemap = ((n', n) : namemap ms) })
                            return (PRef fc [] n)
-- Application against application: merge argument by argument, pairing each
-- new argument with the type name 'argTys' reports for the new head. The
-- result keeps the new application's location and head.
mergePat' ist (PApp _ _ args) (PApp fc f args') t
  = do newArgs <- zipWithM mergeArg args (zip args' (argTys ist f))
       return (PApp fc f newArgs)
  where mergeArg x (y, t)
          = do tm' <- mergePat' ist (getTm x) (getTm y) t
               case x of
                    -- When the original argument is a PImp, its machine_inf
                    -- flag is carried over to the merged argument.
                    (PImp _ _ _ _ _) ->
                         return (y { machine_inf = machine_inf x,
                                     getTm = tm' })
                    _ -> return (y { getTm = tm' })
-- Original reference against any other new term: tidy the new term and
-- record it as the update for the original name.
mergePat' ist (PRef fc _ n) tm ty = do tm <- tidy ist tm ty
                                       addUpdate n tm
                                       return tm
-- Anything else: take the new term unchanged.
mergePat' ist x y t = return y
-- | Combine the user's term with the delaborated one. As written this keeps
-- the user's term and ignores the second argument entirely.
mergeUserImpl :: PTerm -> PTerm -> PTerm
mergeUserImpl userTm _ = userTm
-- | For a term that is a reference to a name with exactly one type in the
-- context, the head type name of each of its arguments (after normalising
-- the type), padded with an endless run of 'Nothing'. Any other term, or a
-- name without a unique type, gives only 'Nothing's.
--
-- A function-typed argument is reported as the name @->@; a two-argument
-- application of @Delayed@ is looked through to its second argument.
argTys :: IState -> PTerm -> [Maybe Name]
argTys ist (PRef _ _ fn)
  | [ty] <- lookupTy fn ctxt
  = map (tyName . snd) (getArgTys (normalise ctxt [] ty)) ++ repeat Nothing
  | otherwise
  = repeat Nothing
  where
    ctxt = tt_ctxt ist

    tyName (Bind _ (Pi _ _ _) _) = Just (sUN "->")
    tyName t = case unApply t of
        (P _ d _, [_, inner])
          | d == sUN "Delayed" -> tyName inner
        (P _ headName _, _)    -> Just headName
        _                      -> Nothing
argTys _ _ = repeat Nothing
-- | Make the names in a generated term presentable.
--
-- A reference whose name is in the name map is replaced by its mapping; an
-- unmapped 'UN' name is kept; any other unmapped name is replaced by one
-- from 'inventName'. Applications are tidied argument by argument, each
-- with the type name 'argTys' gives for that position. All other terms are
-- returned unchanged.
tidy :: IState -> PTerm -> Maybe Name -> State MergeState PTerm
tidy ist orig@(PRef fc _ n) ty = do
    known <- gets (lookup n . namemap)
    case (known, n) of
      (Just x, _)     -> return (PRef fc [] x)
      (Nothing, UN _) -> return orig
      (Nothing, _)    -> liftM (PRef fc []) (inventName ist ty n)
tidy ist (PApp fc f args) _ =
    liftM (PApp fc f) (zipWithM tidyArg args (argTys ist f))
  where
    tidyArg arg argTy = do
        tm' <- tidy ist (getTm arg) argTy
        return (arg { getTm = tm' })
tidy _ tm _ = return tm
-- | Try to elaborate a candidate pattern as a left-hand side.
--
-- On success the result is 'True' with the delaborated term; if
-- elaboration throws, the error is logged and the result is 'False' with
-- the candidate unchanged.
elabNewPat :: PTerm -> Idris (Bool, PTerm)
elabNewPat t = idrisCatch attempt recover
  where
    attempt = do
        (tm, _) <- elabVal (recinfo (fileFC "casesplit")) ELHS t
        i <- getIState
        return (True, delabDirect i tm)

    recover e = do
        i <- getIState
        logElab 5 $ "Not a valid split:\n" ++ showTmImpls t ++ "\n"
                    ++ pshow i e
        return (False, t)
-- | Candidate patterns for a value of the given type.
--
-- When the head of the type is a name with exactly one entry in the
-- datatype table, there is one pattern per constructor: the constructor
-- applied to a placeholder for each of its recorded arguments. Otherwise
-- the only candidate is a single placeholder.
findPats :: IState -> Type -> [PTerm]
findPats ist t = case unApply t of
    (P _ tyName _, _)
      | [ti] <- lookupCtxt tyName (idris_datatypes ist)
      -> map genPat (con_names ti)
    _ -> [Placeholder]
  where
    genPat con = case lookupCtxt con (idris_implicits ist) of
        [args] -> PApp emptyFC (PRef emptyFC [] con) (map toPlaceholder args)
        _      -> error $ "Can't happen (genPat) " ++ show con

    toPlaceholder arg = arg { getTm = Placeholder }
-- | In the arguments of an application pattern, substitute the term @t@
-- for the variable @n@.
--
-- Other variables become placeholders, references to data constructors are
-- kept, applications headed by a type constructor become placeholders, and
-- remaining applications are processed recursively. A pattern that is not
-- an application is returned unchanged.
replaceVar :: Context -> Name -> PTerm -> PTerm -> PTerm
replaceVar ctxt n t (PApp fc f pats) = PApp fc f (map substArg pats)
  where
    substArg arg = arg { getTm = subst (getTm arg) }

    subst :: PTerm -> PTerm
    subst tm = case tm of
        PPatvar _ v
          | v == n    -> t
          | otherwise -> Placeholder
        PRef _ _ v
          | v == n            -> t
          | isDConName v ctxt -> tm
          | otherwise         -> Placeholder
        PApp _ (PRef _ _ tyCon) _
          | isTConName tyCon ctxt -> Placeholder
        PApp fc' f' args -> PApp fc' f' (map substArg args)
        _                -> tm
replaceVar _ _ _ pat = pat
-- | Case split the variable @n@ in the clause found via 'getInternalApp'
-- for line @l@ of file @fn@. See 'split' for the meaning of the result.
splitOnLine :: Int
            -> Name
            -> FilePath
            -> Idris (Bool, [[(Name, PTerm)]])
splitOnLine l n fn = do
    clause <- getInternalApp fn l
    logElab 3 ("Working with " ++ showTmImpls clause)
    split n clause
-- | Turn one source line into one new line per set of name updates.
--
-- @l@ is the text of the clause being split, @ups@ holds one list of
-- (variable, replacement) updates per new clause, and @impossible@ selects
-- whether the right-hand sides are replaced by @impossible@ or have their
-- holes renamed so they stay unique (via 'getUniq').
replaceSplits :: String -> [[(Name, PTerm)]] -> Bool -> Idris [String]
replaceSplits l ups impossible
    = do ist <- getIState
         updateRHSs 1 (map (rep ist (expandBraces l)) ups)
  where
    -- Apply each update in turn to the line text; a newline is appended
    -- once all updates have been applied.
    rep ist str [] = str ++ "\n"
    rep ist str ((n, tm) : ups)
        = rep ist (updatePat False (show n) (nshow (resugar ist tm)) str) ups

    -- Fix up the right-hand side of every generated line, threading the
    -- counter used for fresh hole names.
    updateRHSs i [] = return []
    updateRHSs i (x : xs)
       | impossible = do xs' <- updateRHSs i xs
                         return (setImpossible False x : xs')
       | otherwise = do (x', i') <- updateRHS (null xs) i x
                        xs' <- updateRHSs i' xs
                        return (x' : xs')

    -- Rename the first hole (@?name@) on the line. "?=" is copied through
    -- untouched. Unless this is the last line, everything between the hole
    -- name and the next newline is dropped.
    updateRHS last i ('?':'=':xs) = do (xs', i') <- updateRHS last i xs
                                       return ("?=" ++ xs', i')
    updateRHS last i ('?':xs)
        = do let (nm, rest_in) = span (not . (\x -> isSpace x || x == ')'
                                                               || x == '(')) xs
             let rest = if last then rest_in else
                           case span (not . (=='\n')) rest_in of
                                (_, restnl) -> restnl
             (nm', i') <- getUniq nm i
             return ('?':nm' ++ rest, i')
    updateRHS last i (x : xs) = do (xs', i') <- updateRHS last i xs
                                   return (x : xs', i')
    updateRHS last i [] = return ("", i)

    -- Replace everything from the first '=' outside braces by
    -- "impossible\n". The flag records whether we are inside braces.
    setImpossible brace ('}':xs) = '}' : setImpossible False xs
    setImpossible brace ('{':xs) = '{' : setImpossible True xs
    setImpossible False ('=':xs) = "impossible\n"
    setImpossible brace (x : xs) = x : setImpossible brace xs
    setImpossible brace [] = ""

    -- Show a term, with special cases for the names Z and S.
    nshow (PRef _ _ (UN z)) | z == txt "Z" = "Z"
    nshow (PApp _ (PRef _ _ (UN s)) [x]) | s == txt "S" =
               "(S " ++ addBrackets (nshow (getTm x)) ++ ")"
    nshow t = show t

    -- Expand "{n}" to "{n = n}" so the variable inside the braces can be
    -- substituted. Expansion stops at "{-", and also at a '{' with no
    -- closing '}' (the original partial let-pattern crashed there); in both
    -- cases the remaining text is returned as it is.
    expandBraces ('{' : '-' : xs) = '{' : '-' : xs
    expandBraces ('{' : xs)
        = case span (/= '}') xs of
               (brace, _ : rest)
                   | '=' `elem` brace ->
                         ('{' : brace ++ "}") ++ expandBraces rest
                   | otherwise ->
                         ('{' : brace ++ " = " ++ brace ++ "}") ++
                         expandBraces rest
               (_, []) -> '{' : xs
    expandBraces (x : xs) = x : expandBraces xs
    expandBraces [] = []

    -- Replace occurrences of the name @n@ by @tm@. The flag says whether
    -- the current position can start a name. After '{', whitespace is
    -- skipped and the following name is not replaced. Nothing from the
    -- first '?' onwards is touched.
    updatePat start n tm [] = []
    updatePat start n tm ('{':rest) =
        let (space, rest') = span isSpace rest in
            '{' : space ++ updatePat False n tm rest'
    updatePat start n tm done@('?':rest) = done
    updatePat True n tm xs@(c:rest) | length xs > length n
        -- The guard makes 'after' non-empty, so the pattern cannot fail.
        = let (before, after@(next:_)) = splitAt (length n) xs in
              if (before == n && not (isAlphaNum next))
                 then addBrackets tm ++ updatePat False n tm after
                 else c : updatePat (not (isAlphaNum c)) n tm rest
    updatePat start n tm (c:rest) = c : updatePat (not ((isAlphaNum c) || c == '_')) n tm rest

    -- Parenthesise text containing a space unless it already starts with
    -- '(' and ends with ')'.
    addBrackets tm | ' ' `elem` tm
                   , not (isPrefixOf "(" tm && isSuffixOf ")" tm)
                       = "(" ++ tm ++ ")"
                   | otherwise = tm
-- | Find a name of the form @root_i@ that has no type in the current
-- context, where @root@ is 'nameRoot' of @nm@ and counting starts at @i@.
-- Returns the name together with the next counter value to try.
getUniq :: (Show t, Num t) => [Char] -> t -> Idris ([Char], t)
getUniq nm i = do
    ist <- getIState
    let candidate = nameRoot [] nm ++ "_" ++ show i
    if null (lookupTy (sUN candidate) (tt_ctxt ist))
       then return (candidate, i + 1)
       else getUniq nm (i + 1)
-- | Strip a trailing all-digit component from an underscore-separated
-- name. @acc@ accumulates the components seen so far; callers start it
-- at @[]@. If what remains of @nm@ consists only of digits (or is empty),
-- it is dropped and the accumulated components are joined with @_@.
nameRoot acc nm
  | all isDigit nm = showSep "_" acc
  | otherwise =
      case break (== '_') nm of
        (before, _ : after) -> nameRoot (acc ++ [before]) after
        _                   -> showSep "_" (acc ++ [nm])
-- | Show a name for use on a left-hand side: names containing any operator
-- character are wrapped in parentheses.
showLHSName :: Name -> String
showLHSName n
  | any (`elem` opChars) shown = "(" ++ shown ++ ")"
  | otherwise                  = shown
  where
    shown = show n
-- | Show a name for use inside a hole name: names containing any operator
-- character are replaced by the fixed text @op@.
showRHSName :: Name -> String
showRHSName n
  | any (`elem` opChars) shown = "op"
  | otherwise                  = shown
  where
    shown = show n
-- | Generate an initial clause for the name @un@.
--
-- If @un@ has exactly one entry in the class table, a skeleton definition
-- is produced for each of its methods. Otherwise the term found via
-- 'getInternalApp' for line @l@ of file @fp@ is used as the type, and a
-- single clause is produced with one argument name per explicit argument
-- and a hole for the right-hand side. The @fn@ argument is not used.
getClause :: Int
          -> Name
          -> Name
          -> FilePath
          -> Idris String
getClause l fn un fp = do
    i <- getIState
    case lookupCtxt un (idris_classes i) of
      [c] -> return (mkClassBodies i (class_methods c))
      _ -> do
        ty_in <- getInternalApp fp l
        let ty = case ty_in of
                   PTyped _ t -> t
                   other      -> other
        -- The state is read again here, after 'getInternalApp' has run.
        ist <- get
        return (showLHSName un ++ " " ++ mkApp ist ty [] ++ "= ?"
                  ++ showRHSName un ++ "_rhs")
  where
    -- One name (followed by a space) per matching explicit argument.
    -- Machine-generated names and names starting with an underscore are
    -- replaced by a fresh name; other binders are skipped.
    mkApp :: IState -> PTerm -> [Name] -> String
    mkApp i (PPi (Exp _ _ False) argName _ ty sc) used
        = show chosen ++ " " ++ mkApp i sc (chosen : used)
      where
        chosen | needsFresh argName = getNameFrom i used ty
               | otherwise          = argName
        needsFresh (MN _ _) = True
        needsFresh (UN u)   = thead u == '_'
        needsFresh _        = False
    mkApp i (PPi _ _ _ _ sc) used = mkApp i sc used
    mkApp _ _ _ = ""

    -- Choose an unused name suited to the argument's type: f/g for
    -- function types, the name hints of the head of the type when there
    -- are any, and x/y/z otherwise.
    getNameFrom i used tm = case tm of
        PPi _ _ _ _ _ -> pick [sUN "f", sUN "g"]
        PApp _ f _    -> getNameFrom i used f
        PRef _ _ f    -> case getNameHints i f of
                           [] -> pick defaults
                           ns -> pick ns
        _             -> pick defaults
      where
        pick names = uniqueNameFrom (mkSupply names) used
        defaults   = [sUN "x", sUN "y", sUN "z"]

    -- One skeleton line per class method, with numbered holes.
    mkClassBodies :: IState -> [(Name, (Bool, FnOpts, PTerm))] -> String
    mkClassBodies i ns = showSep "\n" (zipWith methodClause ns [1..])
      where
        methodClause (n, (_, _, ty)) m =
            " " ++ def (show (nsroot n)) ++ " "
                ++ mkApp i ty []
                ++ "= ?"
                ++ showRHSName un ++ "_rhs_" ++ show m

    -- Parenthesise names that do not start with an alphanumeric character.
    def n@(x:_) | not (isAlphaNum x) = "(" ++ n ++ ")"
    def n = n
-- | Generate an initial proof clause for @fn@: the return type of the term
-- found via 'getInternalApp' for line @l@ of file @fp@, in parentheses,
-- followed by @<==@, the name, and a hole for the right-hand side.
getProofClause :: Int
               -> Name
               -> FilePath
               -> Idris String
getProofClause l fn fp = do
    ty_in <- getInternalApp fp l
    return (goal (untyped ty_in) ++ " = ?" ++ showRHSName fn ++ "_rhs")
  where
    -- Drop an outer type annotation, if there is one.
    untyped (PTyped _ t) = t
    untyped other        = other

    -- Skip all binders to reach the return type.
    goal (PPi _ _ _ _ sc) = goal sc
    goal rt               = "(" ++ show rt ++ ") <== " ++ show fn
-- | Turn a clause into a @with@ block: the clause itself with its
-- right-hand side replaced by @with (_)@, then a new line holding the same
-- left-hand side extended with a @with_pat@ pattern and a fresh hole.
-- Leading @>@ characters on the clause are copied to the new line.
mkWith :: String -> Name -> String
mkWith str n = withLine ++ "\n" ++ newpat str
  where
    withLine = replaceRHS str "with (_)"

    -- Copy the source up to its right-hand side and put the new text
    -- there. The right-hand side starts at "?=", or at an '=' that has no
    -- further '=' after it; if neither occurs the text is appended.
    replaceRHS src new = case src of
        []            -> new
        '?' : '=' : _ -> new
        '=' : rest
          | '=' `notElem` rest -> new
        c : rest      -> c : replaceRHS rest new

    newpat ('>' : patstr) = '>' : newpat patstr
    newpat patstr =
        " " ++
        replaceRHS patstr "| with_pat = ?" ++ showRHSName n ++ "_rhs"
-- | Give readable names to the variables of some generated patterns.
--
-- Each pattern is elaborated where possible (patterns that fail to
-- elaborate are kept as given) and then merged against a base pattern:
-- the first input pattern with every argument replaced by a reference to
-- @_@. Only the merged patterns are returned; the name updates are
-- discarded.
nameMissing :: [PTerm] -> Idris [PTerm]
nameMissing ps = do
    ist <- get
    elaborated <- mapM nm ps
    -- 'head' is only demanded when there is at least one pattern to merge.
    return (map fst (mergeAllPats ist (sUN "_") (base (head ps)) elaborated))
  where
    base (PApp fc f args) = PApp fc f (map (fmap (const wild)) args)
      where
        wild = PRef fc [] (sUN "_")
    base t = t

    nm ptm = do
        (ok, ptm') <- elabNewPat ptm
        return (if ok then ptm' else ptm)