module Wingman.CaseSplit
  ( mkFirstAgda
  , iterateSplit
  , splitToDecl
  ) where

import           Data.Bool (bool)
import           Data.Data
import           Data.Generics
import           Data.Set (Set)
import qualified Data.Set as S
import           Development.IDE.GHC.Compat
import           GHC.Exts (IsString (fromString))
import           GHC.SourceGen (funBindsWithFixity, match, wildP)
import           Wingman.GHC
import           Wingman.Types



------------------------------------------------------------------------------
-- | Construct an 'AgdaMatch' from patterns in scope (should be the LHS of the
-- match) and a body.
--
-- Whenever the body matches the 'Lambda' pattern, that lambda's patterns are
-- appended after the ones already collected and its own body is examined
-- again. Nested lambdas are therefore flattened into one pattern list, and the
-- first non-'Lambda' expression becomes the match body.
mkFirstAgda :: [Pat GhcPs] -> HsExpr GhcPs -> AgdaMatch
mkFirstAgda scoped expr =
  case expr of
    Lambda lamPats lamBody -> mkFirstAgda (scoped <> lamPats) lamBody
    _                      -> AgdaMatch scoped expr


------------------------------------------------------------------------------
-- | Transform an 'AgdaMatch' whose body is a case over a bound pattern, by
-- splitting it into multiple matches: one for each alternative of the case.
--
-- A split only happens when the scrutinee is a plain variable that
-- 'containsVar' finds among the match's own patterns. Each alternative then
-- yields a new match: the alternative's pattern is substituted for that
-- variable (via 'rewriteVarPat'), and the alternative's RHS becomes the body.
-- Any other match is returned unchanged as a singleton list.
agdaSplit :: AgdaMatch -> [AgdaMatch]
agdaSplit am@(AgdaMatch pats expr) =
  case expr of
    Case (HsVar _ (L _ scrutVar)) alts
      -- Ensure the thing we're destructing is actually a pattern that's been
      -- bound; otherwise fall through to the unchanged case below.
      | containsVar scrutVar pats ->
          -- TODO(sandy): use an at pattern if necessary
          [ AgdaMatch (rewriteVarPat scrutVar altPat pats) (unLoc altBody)
          | (altPat, altBody) <- alts
          ]
    _ -> [am]


------------------------------------------------------------------------------
-- | Replace unused bound patterns with wild patterns.
--
-- "Unused" means the variable's name does not appear in 'allOccNames' of the
-- body. If the body still contains a hole (per 'containsHole'), the patterns
-- are left exactly as they are; presumably this keeps the bindings available
-- for filling that hole later.
wildify :: AgdaMatch -> AgdaMatch
wildify (AgdaMatch pats body) = AgdaMatch (prune pats) body
  where
    -- 'bool' picks its second argument when the condition holds: with a hole
    -- present we keep everything ('id'); otherwise we wildcard unused vars.
    prune :: [Pat GhcPs] -> [Pat GhcPs]
    prune = bool (wildifyT (allOccNames body)) id (containsHole body)


------------------------------------------------------------------------------
-- | Helper function for 'wildify'.
--
-- Traverses the whole structure and replaces every 'VarPat' whose name is
-- not in the given set with 'wildP'. Names are compared by their
-- 'occNameString', so only the textual name matters. All other patterns are
-- left untouched.
wildifyT :: Data a => Set OccName -> a -> a
wildifyT usedOccs = everywhere (mkT eraseUnused)
  where
    -- The set of used names, as plain strings, computed once per call.
    usedNames :: Set String
    usedNames = S.map occNameString usedOccs

    eraseUnused :: Pat GhcPs -> Pat GhcPs
    eraseUnused pat@(VarPat _ (L _ var))
      | occNameString (occName var) `S.notMember` usedNames = wildP
      | otherwise                                           = pat
    eraseUnused pat = pat


------------------------------------------------------------------------------
-- | Determine whether the given 'RdrName' exists as a 'VarPat' inside of @a@.
--
-- The name is also considered present when it labels a punned record field
-- (@hsRecPun = True@). For puns, the code checks the field's label rather
-- than its argument pattern. Names are compared with 'eqRdrName'.
containsVar :: Data a => RdrName -> a -> Bool
containsVar name = everything (||) (mkQ False isBoundVar `extQ` isPunnedField)
  where
    -- A variable pattern binding exactly this name.
    isBoundVar :: Pat GhcPs -> Bool
    isBoundVar (VarPat _ (L _ var)) = eqRdrName name var
    isBoundVar _                    = False

    -- A punned record field whose label is this name.
    isPunnedField :: HsRecField' (FieldOcc GhcPs) (PatCompat GhcPs) -> Bool
    isPunnedField (HsRecField lbl _ True) =
      eqRdrName name . unLoc . rdrNameFieldOcc $ unLoc lbl
    isPunnedField _ = False


------------------------------------------------------------------------------
-- | Replace a 'VarPat' with the given @'Pat' GhcPs@.
--
-- Every 'VarPat' binding the given name (per 'eqRdrName') becomes the
-- replacement pattern. A punned record field labelled with that name is
-- un-punned: its argument is set to the replacement pattern (via
-- 'toPatCompat') and its pun flag is set to 'False'. Everything else is left
-- untouched.
rewriteVarPat :: Data a => RdrName -> Pat GhcPs -> a -> a
rewriteVarPat name rep = everywhere (mkT replaceVar `extT` replacePun)
  where
    replaceVar :: Pat GhcPs -> Pat GhcPs
    replaceVar (VarPat _ (L _ var))
      | eqRdrName name var = rep
    replaceVar pat = pat

    replacePun
        :: HsRecField' (FieldOcc GhcPs) (PatCompat GhcPs)
        -> HsRecField' (FieldOcc GhcPs) (PatCompat GhcPs)
    replacePun (HsRecField lbl _ True)
      | eqRdrName name (unLoc (rdrNameFieldOcc (unLoc lbl)))
      = HsRecField lbl (toPatCompat rep) False
    replacePun fld = fld


------------------------------------------------------------------------------
-- | Construct an 'HsDecl' from a set of 'AgdaMatch'es.
--
-- Produces a single function binding with the given name and fixity. Each
-- 'AgdaMatch' becomes one equation, in list order, with its patterns as the
-- LHS and its body as the RHS. The resulting declaration carries no source
-- location ('noLoc').
splitToDecl
    :: Maybe LexicalFixity
    -> OccName  -- ^ The name of the function
    -> [AgdaMatch]
    -> LHsDecl GhcPs
splitToDecl fixity name ams =
  noLoc $
    funBindsWithFixity fixity (fromString . occNameString . occName $ name) clauses
  where
    -- One 'match' (equation) per split alternative.
    clauses = [ match pats body | AgdaMatch pats body <- ams ]


------------------------------------------------------------------------------
-- | Sometimes 'agdaSplit' exposes another opportunity to do 'agdaSplit'. This
-- function runs it a few times, hoping it will find a fixpoint.
--
-- Concretely, 'agdaSplit' is applied to every match for exactly five rounds,
-- with no early exit. 'wildify' is then applied to each resulting match.
iterateSplit :: AgdaMatch -> [AgdaMatch]
iterateSplit am = wildify <$> splitRounds (5 :: Int) [am]
  where
    -- Run one 'agdaSplit' pass over every match, @k@ times in a row.
    splitRounds :: Int -> [AgdaMatch] -> [AgdaMatch]
    splitRounds k ms
      | k <= 0    = ms
      | otherwise = splitRounds (k - 1) (concatMap agdaSplit ms)