#ifndef HAVE_OVERLOADED_LABELS
#endif
#if 0
#if HAVE_MONAD_FAIL && MIN_VERSION_template_haskell(2,11,0)
#define _FAIL_IN_MONAD
#else
#define _FAIL_IN_MONAD , fail
#endif
#endif
#define _FAIL_IN_MONAD , fail
module Data.OverloadedRecords.TH
(
overloadedRecord
, overloadedRecords
, overloadedRecordFor
, overloadedRecordsFor
, DeriveOverloadedRecordsParams
#ifndef HAVE_OVERLOADED_LABELS
, fieldDerivation
#endif
, FieldDerivation
, OverloadedField(..)
, defaultFieldDerivation
, defaultMakeFieldName
, field
, simpleField
, fieldGetter
, fieldSetter
, simpleFieldSetter
)
where
import Prelude (Num((-)), fromIntegral)
import Control.Applicative (Applicative((<*>)))
import Control.Monad (Monad((>>=) _FAIL_IN_MONAD, return), replicateM)
#if 0
#if HAVE_MONAD_FAIL && MIN_VERSION_template_haskell(2,11,0)
import Control.Monad.Fail (MonadFail(fail))
#endif
#endif
import Data.Bool (Bool(False), otherwise)
import qualified Data.Char as Char (toLower)
import Data.Foldable (concat, foldl)
import Data.Function ((.), ($))
import Data.Functor (Functor(fmap), (<$>))
import qualified Data.List as List
( drop
, isPrefixOf
, length
, map
, replicate
, zip
)
import Data.Maybe (Maybe(Just, Nothing), fromMaybe)
import Data.Monoid ((<>))
import Data.String (String)
import Data.Traversable (forM, mapM)
import Data.Typeable (Typeable)
import Data.Word (Word)
import GHC.Generics (Generic)
#ifndef HAVE_OVERLOADED_LABELS
import GHC.Exts (Proxy#, proxy#)
#endif
import Text.Show (Show(show))
import Language.Haskell.TH
( Con(ForallC, InfixC, NormalC, RecC)
, Dec(DataD, NewtypeD)
, DecsQ
, ExpQ
, Info(TyConI)
, Name
, Pat(ConP, VarP, WildP)
, PatQ
, Q
, Strict
, Type
, TypeQ
, TyVarBndr(KindedTV, PlainTV)
, appE
, appT
, conE
, conP
, conT
, lamE
, litT
, nameBase
, newName
, recUpdE
, reify
, strTyLit
, varE
, varP
, varT
, wildP
)
import Data.Default.Class (Default(def))
#ifndef HAVE_OVERLOADED_LABELS
import Data.OverloadedLabels (IsLabel(fromLabel))
#endif
import Data.OverloadedRecords
( FieldType
, HasField(getField)
, SetField(setField)
, UpdateType
)
#ifndef HAVE_OVERLOADED_LABELS
-- | Label value for @\"fieldDerivation\"@, built with 'fromLabel'.
--
-- Only defined when @HAVE_OVERLOADED_LABELS@ is not set; presumably a
-- stand-in for the @#fieldDerivation@ label syntax on such compilers —
-- TODO confirm.
fieldDerivation :: IsLabel "fieldDerivation" a => a
fieldDerivation = fromLabel (proxy# :: Proxy# "fieldDerivation")
#endif
-- | Parameters controlling which overloaded-record instances are derived.
--
-- The constructor is not exported. Start from 'def' and adjust the
-- @\"fieldDerivation\"@ field through its 'HasField' / 'SetField' instances.
data DeriveOverloadedRecordsParams = DeriveOverloadedRecordsParams
    { _strictFields :: Bool
    -- ^ NOTE(review): 'def' sets this to 'False', but nothing else in this
    -- module reads it; confirm its intended effect or remove it.
    , _fieldDerivation :: FieldDerivation
    -- ^ Decides, for each constructor field, whether and how a label is
    -- derived.
    }
  deriving (Generic, Typeable)
-- Make @\"fieldDerivation\"@ an overloaded record field of
-- 'DeriveOverloadedRecordsParams'. Both the getter and the setter operate on
-- the '_fieldDerivation' slot, and updates never change the record type.
type instance FieldType "fieldDerivation" DeriveOverloadedRecordsParams =
    FieldDerivation

-- | Read the field derivation function.
instance
    HasField "fieldDerivation" DeriveOverloadedRecordsParams FieldDerivation
  where
    getField _ params = _fieldDerivation params

type instance
    UpdateType "fieldDerivation" DeriveOverloadedRecordsParams FieldDerivation =
        DeriveOverloadedRecordsParams

-- | Replace the field derivation function.
instance
    SetField "fieldDerivation" DeriveOverloadedRecordsParams FieldDerivation
  where
    setField _ params newDerivation =
        params { _fieldDerivation = newDerivation }
-- | How a single constructor field is exposed as an overloaded record field.
--
-- The 'String' is the label. If the 'Maybe' is 'Nothing', the getter (and
-- setter) expressions are derived from the field's accessor or its
-- position in the constructor. A 'Just' supplies custom expressions instead.
data OverloadedField
    = GetterOnlyField String (Maybe ExpQ)
    -- ^ Only 'HasField' / 'FieldType' instances are generated, with an
    -- optional custom getter expression.
    | GetterAndSetterField String (Maybe (ExpQ, ExpQ))
    -- ^ Getter and 'SetField' / 'UpdateType' instances are generated, with
    -- an optional custom @(getter, setter)@ expression pair.
  deriving (Generic, Typeable)
-- | Decides, for one constructor field, whether a label is derived.
--
-- The arguments are, in order:
--
-- * the type name;
-- * the constructor name;
-- * the 0-based position of the field among the constructor's arguments;
-- * the record accessor name ('Nothing' for non-record fields).
--
-- All names are unqualified ('nameBase'). Returning 'Nothing' means no
-- instances are generated for the field.
type FieldDerivation
    = String
    -> String
    -> Word
    -> Maybe String
    -> Maybe OverloadedField
-- | Default way of turning a record accessor name into a label.
--
-- Only record fields (those with an accessor name) are considered. The
-- first of these prefixes that the accessor starts with is removed:
--
-- * @_@
-- * the type name with its first character lower-cased
-- * the constructor name with its first character lower-cased
--
-- The first character of the remainder is then lower-cased. Accessors that
-- match none of the prefixes, and positional fields, yield 'Nothing'.
--
-- >>> defaultMakeFieldName "Foo" "Foo" 0 (Just "fooBar")
-- Just "bar"
defaultMakeFieldName
    :: String
    -- ^ Type name.
    -> String
    -- ^ Constructor name.
    -> Word
    -- ^ Field position (ignored).
    -> Maybe String
    -- ^ Accessor name, if the field is a record field.
    -> Maybe String
defaultMakeFieldName typeName constructorName _fieldPosition possibleName =
    possibleName >>= stripFirstKnownPrefix
  where
    -- Tried in this order; the first match wins.
    knownPrefixes :: [String]
    knownPrefixes = ["_", lowerFirst typeName, lowerFirst constructorName]

    stripFirstKnownPrefix :: String -> Maybe String
    stripFirstKnownPrefix accessor = go knownPrefixes
      where
        go [] = Nothing
        go (prefix : remaining)
          | prefix `List.isPrefixOf` accessor =
              Just . lowerFirst $ List.drop (List.length prefix) accessor
          | otherwise = go remaining

    lowerFirst :: String -> String
    lowerFirst (c : cs) = Char.toLower c : cs
    lowerFirst [] = []
-- | Default 'FieldDerivation'.
--
-- A field gets both a getter and a setter whenever 'defaultMakeFieldName'
-- produces a label for it. The derived (not custom) getter and setter
-- expressions are used. All other fields are skipped.
defaultFieldDerivation :: FieldDerivation
defaultFieldDerivation typeName constructorName fieldPosition accessor =
    toField
        <$> defaultMakeFieldName typeName constructorName fieldPosition accessor
  where
    -- 'Nothing': no custom getter/setter, derive them from the field.
    toField label = GetterAndSetterField label Nothing
-- | Default parameters: labels are derived with 'defaultFieldDerivation',
-- and '_strictFields' is 'False'.
instance Default DeriveOverloadedRecordsParams where
    def = DeriveOverloadedRecordsParams
        { _fieldDerivation = defaultFieldDerivation
        , _strictFields = False
        }
-- | Derive overloaded-record instances for the data type or newtype with the
-- given name.
--
-- Each constructor is processed with 'deriveForConstructor'. The splice
-- fails in two cases:
--
-- * the name does not reify to a type constructor;
-- * the declaration is neither a @newtype@ nor a @data@ declaration with an
--   empty datatype context.
overloadedRecord
    :: DeriveOverloadedRecordsParams
    -> Name
    -> DecsQ
overloadedRecord params name = do
    info <- reify name
    case info of
        TyConI dec -> fromDeclaration dec
        other -> unsupported other
  where
    fromDeclaration :: Dec -> DecsQ
    fromDeclaration dec = case dec of
#if MIN_VERSION_template_haskell(2,11,0)
        NewtypeD [] declName tyVars _kindSignature constructor _deriving ->
#else
        NewtypeD [] declName tyVars constructor _deriving ->
#endif
            deriveForConstructor params declName tyVars constructor
#if MIN_VERSION_template_haskell(2,11,0)
        DataD [] declName tyVars _kindSignature constructors _deriving ->
#else
        DataD [] declName tyVars constructors _deriving ->
#endif
            concat
                <$> forM constructors (deriveForConstructor params declName tyVars)
        other -> unsupported other

    -- Abort the splice, showing what was found instead.
    unsupported :: Show a => a -> Q b
    unsupported x = fail
        $ "`" <> show name <> "' is neither newtype nor data type: " <> show x
-- | Derive overloaded-record instances for every listed type using the same
-- parameters, and concatenate all generated declarations.
overloadedRecords
    :: DeriveOverloadedRecordsParams
    -> [Name]
    -> DecsQ
overloadedRecords params recordNames =
    concat <$> mapM (overloadedRecord params) recordNames
-- | Like 'overloadedRecord', but the parameters are obtained by applying the
-- given modifier to the 'def'ault parameters.
overloadedRecordFor
    :: Name
    -> (DeriveOverloadedRecordsParams -> DeriveOverloadedRecordsParams)
    -> DecsQ
overloadedRecordFor recordName modifyParams =
    overloadedRecord params recordName
  where
    params = modifyParams def
-- | Like 'overloadedRecords', but the parameters are obtained by applying the
-- given modifier to the 'def'ault parameters.
overloadedRecordsFor
    :: [Name]
    -> (DeriveOverloadedRecordsParams -> DeriveOverloadedRecordsParams)
    -> DecsQ
overloadedRecordsFor recordNames modifyParams =
    overloadedRecords params recordNames
  where
    params = modifyParams def
-- | Derive overloaded-record declarations for every argument of one data
-- constructor of the type @name@, whose type variables are @typeVars@.
--
-- Normal, record and infix constructors are supported. Existentially
-- quantified ('ForallC') constructors produce no declarations. On
-- template-haskell >= 2.11, GADT-syntax constructors (@GadtC@ / @RecGadtC@)
-- also produce none.
deriveForConstructor
    :: DeriveOverloadedRecordsParams
    -> Name
    -> [TyVarBndr]
    -> Con
    -> DecsQ
deriveForConstructor params name typeVars = \case
    NormalC constructorName args ->
        deriveFor constructorName args $ \(strict, argType) f ->
            f Nothing strict argType
    RecC constructorName args ->
        deriveFor constructorName args $ \(accessor, strict, argType) f ->
            f (Just accessor) strict argType
    InfixC arg0 constructorName arg1 ->
        deriveFor constructorName [arg0, arg1] $ \(strict, argType) f ->
            f Nothing strict argType
    ForallC _typeVariables _context _constructor -> return []
#if MIN_VERSION_template_haskell(2,11,0)
    -- Only GadtC and RecGadtC remain here. Without this alternative the case
    -- is non-exhaustive and the splice dies with a pattern-match failure on
    -- GADT-syntax declarations; skip them like ForallC instead.
    _gadtConstructor -> return []
#endif
  where
    -- Derive declarations for every argument of constructor @constrName@.
    -- The continuation-passing @f@ unpacks one argument, whose shape differs
    -- between NormalC, RecC and InfixC, into accessor, strictness and type.
    deriveFor
        :: Name
        -> [a]
        -> (a -> (Maybe Name -> Strict -> Type -> DecsQ) -> DecsQ)
        -> DecsQ
    deriveFor constrName args f =
        fmap concat . forM (withIndexes args) $ \(idx, arg) ->
            f arg $ \accessor strict fieldType' ->
                deriveForField params DeriveFieldParams
                    { typeName = name
                    , typeVariables = List.map getTypeName typeVars
                    , constructorName = constrName
                    , numberOfArgs = fromIntegral $ List.length args
                    , currentIndex = idx
                    , accessorName = accessor
                    , strictness = strict
                    , fieldType = fieldType'
                    }
      where
        getTypeName :: TyVarBndr -> Name
        getTypeName = \case
            PlainTV n -> n
            KindedTV n _kind -> n

        -- Pair each constructor argument with its 0-based position.
        withIndexes = List.zip [(0 :: Word) ..]
-- | Everything 'deriveForField' needs to know about one constructor field.
data DeriveFieldParams = DeriveFieldParams
    { typeName :: Name
    -- ^ Name of the data type or newtype.
    , typeVariables :: [Name]
    -- ^ Type variables of the type, in the order TH reports them.
    , constructorName :: Name
    -- ^ Constructor the field belongs to.
    , numberOfArgs :: Word
    -- ^ Total number of arguments of that constructor.
    , currentIndex :: Word
    -- ^ 0-based position of this field among the constructor arguments.
    , accessorName :: Maybe Name
    -- ^ Record accessor, or 'Nothing' for normal/infix constructor fields.
    , strictness :: Strict
    -- ^ Strictness annotation of the field (not used by 'deriveForField').
    , fieldType :: Type
    -- ^ Type of the field.
    }
-- | Derive the overloaded-record declarations for a single constructor field.
--
-- The field derivation function from the parameters receives the
-- unqualified type name, constructor name, field position and accessor
-- name. It decides whether to generate nothing, only a getter, or a getter
-- and a setter. Custom getter/setter expressions it returns override the
-- derived ones. Derived setters never change the field or record type.
deriveForField
    :: DeriveOverloadedRecordsParams
    -> DeriveFieldParams
    -> DecsQ
deriveForField params DeriveFieldParams{..} =
    case possiblyLabel of
        Nothing -> return []

        Just (GetterOnlyField label customGetterExpr) ->
            deriveGetter' (strTyLitT label)
                $ fromMaybe derivedGetterExpr customGetterExpr

        Just (GetterAndSetterField label customGetterAndSetterExpr) -> (<>)
            <$> deriveGetter' labelType getterExpr
            <*> deriveSetter' labelType setterExpr
          where
            labelType = strTyLitT label
            (getterExpr, setterExpr) =
                fromMaybe (derivedGetterExpr, derivedSetterExpr)
                    customGetterAndSetterExpr
  where
    possiblyLabel = _fieldDerivation params (nameBase typeName)
        (nameBase constructorName) currentIndex (fmap nameBase accessorName)

    deriveGetter' labelType =
        deriveGetter labelType recordType (return fieldType)

    deriveSetter' labelType =
        deriveSetter labelType recordType (return fieldType) newRecordType
            newFieldType

    -- The record type: the type constructor applied to its type variables.
    recordType = foldl appT (conT typeName) $ List.map varT typeVariables

    -- Derived setters are never type-changing.
    newFieldType = return fieldType
    newRecordType = recordType

    -- Number of constructor arguments to the right of the current one.
    -- deriveForConstructor numbers arguments 0 .. numberOfArgs - 1, so this
    -- cannot underflow.
    numVarsOnRight = numberOfArgs - currentIndex - 1

    -- @inbetween f x y b@ is @f x@, then @b@, then @f y@.
    inbetween :: (a -> [b]) -> a -> a -> b -> [b]
    inbetween f a1 a2 b = f a1 <> (b : f a2)

    -- The record accessor if there is one. Otherwise, a lambda that matches
    -- the constructor and returns the argument at currentIndex.
    derivedGetterExpr = case accessorName of
        Just name -> varE name
        Nothing -> do
            a <- newName "a"
            lamE [return . ConP constructorName $ nthArg (VarP a)] (varE a)
      where
        nthArg :: Pat -> [Pat]
        nthArg = inbetween wildPs currentIndex numVarsOnRight

    -- A record update if there is an accessor. Otherwise, a lambda that
    -- rebuilds the constructor with the argument at currentIndex replaced.
    derivedSetterExpr = case accessorName of
        Just name -> do
            s <- newName "s"
            b <- newName "b"
            lamE [varP s, varP b] $ recUpdE (varE s) [(name, ) <$> varE b]
        Nothing -> do
            varsBefore <- newNames currentIndex "a"
            b <- newName "b"
            varsAfter <- newNames numVarsOnRight "a"
            lamE [constrPattern varsBefore varsAfter, varP b]
                $ constrExpression varsBefore (varE b) varsAfter
      where
        constrPattern before after =
            conP constructorName $ inbetween varPs before after wildP
        constrExpression before b after = foldl appE (conE constructorName)
            $ varEs before <> (b : varEs after)
-- | Generate getter and setter instances for a field with an explicit label.
--
-- The arguments are, in order: the label, the record type, the field type,
-- the record type after an update, the field type after an update, the
-- getter expression and the setter expression. Getter declarations come
-- first in the result.
field
    :: String
    -> TypeQ
    -> TypeQ
    -> TypeQ
    -> TypeQ
    -> ExpQ
    -> ExpQ
    -> DecsQ
field label recType fldType newRecType newFldType getterExpr setterExpr = do
    getterDecs <- deriveGetter labelType recType fldType getterExpr
    setterDecs <-
        deriveSetter labelType recType fldType newRecType newFldType setterExpr
    return (getterDecs <> setterDecs)
  where
    labelType = strTyLitT label
-- | 'field' for the common case where an update changes neither the record
-- type nor the field type.
simpleField
    :: String
    -> TypeQ
    -> TypeQ
    -> ExpQ
    -> ExpQ
    -> DecsQ
simpleField label recType fldType getterExpr setterExpr =
    field label recType fldType recType fldType getterExpr setterExpr
-- | Generate only the getter instances ('FieldType' and 'HasField') for the
-- given label, record type, field type and getter expression.
fieldGetter
    :: String
    -> TypeQ
    -> TypeQ
    -> ExpQ
    -> DecsQ
fieldGetter label = deriveGetter (strTyLitT label)
-- | Build the 'FieldType' type instance and the 'HasField' instance for a
-- label (a type-level string), record type and field type. @getField@ is
-- implemented by the supplied expression.
deriveGetter :: TypeQ -> TypeQ -> TypeQ -> ExpQ -> DecsQ
deriveGetter labelT recordT fieldT getterE =
    [d| type instance FieldType $(labelT) $(recordT) = $(fieldT)

        instance HasField $(labelT) $(recordT) $(fieldT) where
            getField _proxy = $(getterE)
      |]
-- | 'fieldSetter' for the common case where an update changes neither the
-- record type nor the field type.
simpleFieldSetter
    :: String
    -> TypeQ
    -> TypeQ
    -> ExpQ
    -> DecsQ
simpleFieldSetter label recType fldType =
    fieldSetter label recType fldType recType fldType
-- | Generate only the setter instances ('UpdateType' and 'SetField').
--
-- The arguments are, in order: the label, the record type, the field type,
-- the record type after the update, the field type after the update and the
-- setter expression.
fieldSetter
    :: String
    -> TypeQ
    -> TypeQ
    -> TypeQ
    -> TypeQ
    -> ExpQ
    -> DecsQ
fieldSetter label = deriveSetter (strTyLitT label)
-- | Build the 'UpdateType' type instance and the 'SetField' instance for a
-- label (a type-level string). @setField@ is implemented by the supplied
-- expression.
--
-- NOTE(review): 'UpdateType' is indexed by the new field type, but the
-- 'SetField' instance head uses the old field type. Confirm this matches the
-- 'SetField' class for type-changing updates.
deriveSetter :: TypeQ -> TypeQ -> TypeQ -> TypeQ -> TypeQ -> ExpQ -> DecsQ
deriveSetter labelT recordT fieldT newRecordT newFieldT setterE =
    [d| type instance UpdateType $(labelT) $(recordT) $(newFieldT) =
            $(newRecordT)

        instance SetField $(labelT) $(recordT) $(fieldT) where
            setField _proxy = $(setterE)
      |]
-- | @n@ wildcard patterns.
wildPs :: Word -> [Pat]
wildPs n = [WildP | _ <- [1 .. (fromIntegral n :: Int)]]
-- | @n@ fresh names, all with the base name @s@.
newNames :: Word -> String -> Q [Name]
newNames n s = replicateM (fromIntegral n) (newName s)
-- | A variable pattern for each name, in order.
varPs :: [Name] -> [PatQ]
varPs names = [varP n | n <- names]
-- | A variable expression for each name, in order.
varEs :: [Name] -> [ExpQ]
varEs names = [varE n | n <- names]
-- | The type-level string literal (a 'Symbol') for the given string.
strTyLitT :: String -> TypeQ
strTyLitT s = litT (strTyLit s)