module Wingman.CaseSplit
( mkFirstAgda
, iterateSplit
, splitToDecl
) where
import Data.Bool (bool)
import Data.Data
import Data.Generics
import Data.Set (Set)
import qualified Data.Set as S
import Development.IDE.GHC.Compat
import GHC.Exts (IsString (fromString))
import GHC.SourceGen (funBindsWithFixity, match, wildP)
import Wingman.GHC
import Wingman.Types
------------------------------------------------------------------------------
-- | Build an 'AgdaMatch' from a list of left-hand-side patterns and a body.
--
-- If the body matches the 'Lambda' pattern, the lambda's own patterns are
-- appended to @pats@ and the function recurses on the lambda's body, so a
-- chain of nested lambdas is flattened into a single match. Any other body
-- is wrapped unchanged.
mkFirstAgda :: [Pat GhcPs] -> HsExpr GhcPs -> AgdaMatch
mkFirstAgda pats (Lambda pats' body) = mkFirstAgda (pats <> pats') body
mkFirstAgda pats body                = AgdaMatch pats body
------------------------------------------------------------------------------
-- | Split an 'AgdaMatch' whose body is a case expression over a variable
-- into one match per case alternative.
--
-- The split only happens when the scrutinised variable occurs in the
-- match's patterns (checked with 'containsVar'). Each alternative's pattern
-- is then substituted for that variable via 'rewriteVarPat', and the
-- alternative's body becomes the body of the new match. Every other input
-- is returned unchanged as a singleton list.
agdaSplit :: AgdaMatch -> [AgdaMatch]
agdaSplit (AgdaMatch pats (Case (HsVar _ (L _ var)) matches))
  -- Only split when the scrutinee is one of the patterns bound by this match.
  | containsVar var pats
  = [ AgdaMatch (rewriteVarPat var pat pats) (unLoc body)
    | (pat, body) <- matches
    ]
agdaSplit x = [x]
------------------------------------------------------------------------------
-- | Replace bound variable patterns that the body never mentions with
-- wildcard patterns.
--
-- When the body still contains a hole ('containsHole'), the patterns are
-- left untouched; otherwise they are rewritten by 'wildifyT' using the set
-- of names occurring in the body ('allOccNames').
wildify :: AgdaMatch -> AgdaMatch
wildify (AgdaMatch pats body) =
  let make_wild :: [Pat GhcPs] -> [Pat GhcPs]
      -- 'bool' picks 'id' for False (body has a hole) and the rewrite for True.
      make_wild = bool id (wildifyT (allOccNames body)) $ not $ containsHole body
   in AgdaMatch (make_wild pats) body
------------------------------------------------------------------------------
-- | Worker for 'wildify': traverse any 'Data' value and turn every 'VarPat'
-- whose name is not in the given set into a wildcard pattern ('wildP').
--
-- Names are compared by their 'occNameString', so the set of 'OccName's is
-- first mapped to a set of 'String's by the view pattern.
wildifyT :: Data a => Set OccName -> a -> a
wildifyT (S.map occNameString -> used) = everywhere $ mkT $ \case
  VarPat _ (L _ var)
    | S.notMember (occNameString $ occName var) used -> wildP
  (x :: Pat GhcPs)                                   -> x
------------------------------------------------------------------------------
-- | Check whether the given 'RdrName' is bound somewhere inside a value.
--
-- A name counts as bound when it appears either
--
--   * as a 'VarPat', or
--   * as the label of an 'HsRecField' whose third field is 'True'
--     (NOTE(review): that 'Bool' is presumably GHC's record-pun flag, i.e. a
--     punned field such as @Foo { x }@ -- confirm against the GHC API).
--
-- Names are compared with 'eqRdrName'.
containsVar :: Data a => RdrName -> a -> Bool
containsVar name = everything (||) $
  mkQ False (\case
    VarPat _ (L _ var) -> eqRdrName name var
    (_ :: Pat GhcPs)   -> False
    )
  `extQ` \case
    HsRecField lbl _ True ->
      eqRdrName name $ unLoc $ rdrNameFieldOcc $ unLoc lbl
    (_ :: HsRecField' (FieldOcc GhcPs) (PatCompat GhcPs)) -> False
------------------------------------------------------------------------------
-- | Substitute a pattern for every binding of the given 'RdrName' inside a
-- value.
--
-- Two kinds of binding site are rewritten:
--
--   * a 'VarPat' with that name is replaced by @rep@;
--   * an 'HsRecField' whose third field is 'True' and whose label has that
--     name is rebuilt with @rep@ as its argument and the flag set to 'False'
--     (NOTE(review): presumably this turns a punned field into an explicit
--     @label = rep@ field -- confirm against the GHC API).
--
-- Everything else is left as it is. Names are compared with 'eqRdrName'.
rewriteVarPat :: Data a => RdrName -> Pat GhcPs -> a -> a
rewriteVarPat name rep = everywhere $
  mkT (\case
    VarPat _ (L _ var) | eqRdrName name var -> rep
    (x :: Pat GhcPs)                        -> x
    )
  `extT` \case
    HsRecField lbl _ True
      | eqRdrName name $ unLoc $ rdrNameFieldOcc $ unLoc lbl
          -> HsRecField lbl (toPatCompat rep) False
    (x :: HsRecField' (FieldOcc GhcPs) (PatCompat GhcPs)) -> x
------------------------------------------------------------------------------
-- | Assemble a function declaration from a list of 'AgdaMatch'es.
--
-- Each 'AgdaMatch' becomes one equation (built with 'match') of a binding
-- created by 'funBindsWithFixity'; the result is wrapped with 'noLoc'. The
-- fixity is also passed through 'traceX' under the label @"fixity"@.
splitToDecl
    :: Maybe LexicalFixity
    -> OccName  -- ^ The name given to the generated binding
    -> [AgdaMatch]
    -> LHsDecl GhcPs
splitToDecl fixity name ams =
  traceX "fixity" fixity $
    noLoc $
      funBindsWithFixity fixity (fromString . occNameString . occName $ name) $ do
        AgdaMatch pats body <- ams
        pure $ match pats body
------------------------------------------------------------------------------
-- | Apply 'agdaSplit' repeatedly to a match and then 'wildify' the results.
--
-- One split can expose further cases to split, so 'agdaSplit' is applied to
-- every match in the list for a fixed number of rounds ('splitRounds').
-- Matches that cannot be split further pass through each round unchanged.
iterateSplit :: AgdaMatch -> [AgdaMatch]
iterateSplit am =
  let iterated :: [[AgdaMatch]]
      iterated = iterate (agdaSplit =<<) $ pure am
      -- 'iterate' yields an infinite list, so this index is always in range.
   in fmap wildify . (!! splitRounds) $ iterated
  where
    -- Number of splitting rounds performed before giving up.
    splitRounds :: Int
    splitRounds = 5