{-# LANGUAGE ScopedTypeVariables #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Singletons.Deriving.Functor
-- Copyright   :  (C) 2018 Ryan Scott
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  Ryan Scott
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Implements deriving of Functor instances
--
----------------------------------------------------------------------------

module Data.Singletons.Deriving.Functor where

import Data.Singletons.Deriving.Infer
import Data.Singletons.Deriving.Util
import Data.Singletons.Names
import Data.Singletons.Syntax
import Data.Singletons.Util
import Language.Haskell.TH.Desugar

-- | Derive a 'Functor' instance for the data type described by @dd@.
--
-- The generated (unannotated) instance defines two methods: @fmap@
-- ('fmapName') and @(<$)@ ('replaceName'). Its context is computed by
-- 'inferConstraintsDef' from @mb_ctxt@ and the instance head type @ty@.
-- NOTE(review): @mb_ctxt@ presumably carries an explicit context from a
-- standalone deriving declaration — confirm against callers.
--
-- For a data type with no constructors, each method gets a single clause
-- that ignores its first argument and scrutinises the second with an empty
-- @case@:
--
-- > fmap _ v = case v of {}
mkFunctorInstance :: forall q. DsMonad q => DerivDesc q
mkFunctorInstance mb_ctxt ty dd@(DataDecl _ _ cons) = do
  -- NOTE(review): the meaning of the 'False' flag is defined by
  -- functorLikeValidityChecks elsewhere; confirm there.
  functorLikeValidityChecks False dd
  -- Fresh names for fmap's function argument and (<$)'s replacement value.
  f <- newUniqueName "_f"
  z <- newUniqueName "_z"
  let -- Builds, per constructor field, the function fmap applies to that
      -- field, based on how the last type variable occurs in its type.
      ft_fmap :: FFoldType (q DExp)
      ft_fmap = FT { ft_triv    = mkSimpleLam pure
                     -- fmap f = \x -> x
                   , ft_var     = pure $ DVarE f
                     -- fmap f = f
                   , ft_ty_app  = \_ g -> DAppE (DVarE fmapName) <$> g
                     -- fmap f = fmap g
                   , ft_forall  = \_ g -> g
                   , ft_bad_app = error "in other argument in ft_fmap"
                   }

      -- Like ft_fmap, but for (<$). A field whose type is the last type
      -- variable itself is tagged 'Immediate', so that an enclosing type
      -- application can emit (z <$) rather than fmap of a constant lambda.
      ft_replace :: FFoldType (q Replacer)
      ft_replace = FT { ft_triv    = fmap Nested $ mkSimpleLam pure
                        -- (p <$) = \x -> x
                      , ft_var     = fmap Immediate $ mkSimpleLam $ \_ -> pure $ DVarE z
                        -- (p <$) = const p
                      , ft_ty_app  = \_ gm -> do
                          g <- gm
                          case g of
                            Nested g'   -> pure . Nested $ DVarE fmapName    `DAppE` g'
                            Immediate _ -> pure . Nested $ DVarE replaceName `DAppE` DVarE z
                        -- Nested g'   ==> fmap g'
                        -- Immediate _ ==> (z <$)
                      , ft_forall  = \_ g -> g
                      , ft_bad_app = error "in other argument in ft_replace"
                      }

      -- Con a1 a2 ... -> Con (f1 a1) (f2 a2) ...
      clause_for_con :: [DPat] -> DCon -> [DExp] -> q DClause
      clause_for_con = mkSimpleConClause $ \con_name ->
        foldExp (DConE con_name) -- Con x1 x2 ...

      -- fmap clause for one constructor; its leading pattern binds f.
      mk_fmap_clause :: DCon -> q DClause
      mk_fmap_clause con = do
        parts <- foldDataConArgs ft_fmap con
        clause_for_con [DVarP f] con =<< sequence parts

      -- (<$) clause for one constructor; its leading pattern binds z.
      mk_replace_clause :: DCon -> q DClause
      mk_replace_clause con = do
        parts <- foldDataConArgs ft_replace con
        clause_for_con [DVarP z] con =<< traverse (fmap replace) parts

      -- All clauses for one method: one per constructor via mk_clause, or,
      -- when there are no constructors, the single clause
      --   _ v = case v of {}
      -- Each call draws its own fresh name for v.
      mk_clauses :: (DCon -> q DClause) -> q [DClause]
      mk_clauses mk_clause = case cons of
        [] -> do v <- newUniqueName "v"
                 pure [DClause [DWildP, DVarP v] (DCaseE (DVarE v) [])]
        _  -> traverse mk_clause cons

  fmap_clauses    <- mk_clauses mk_fmap_clause
  replace_clauses <- mk_clauses mk_replace_clause
  constraints     <- inferConstraintsDef mb_ctxt (DConT functorName) ty cons
  return $ InstDecl { id_cxt     = constraints
                    , id_name    = functorName
                    , id_arg_tys = [ty]
                    , id_sigs    = mempty
                    , id_meths   = [ (fmapName,    UFunction fmap_clauses)
                                   , (replaceName, UFunction replace_clauses)
                                   ] }

-- | The expression implementing @(<$)@ for a single constructor field,
-- tagged with how the last type variable occurs in that field's type.
data Replacer
  = Immediate { replace :: DExp }
    -- ^ The field's type is exactly the last type variable; the wrapped
    -- expression ignores its argument and yields the replacement value.
  | Nested    { replace :: DExp }
    -- ^ The last type variable is either absent from the field's type or
    -- occurs nested inside another type.