#ifndef HAVE_OVERLOADED_LABELS
#endif
#if 0
#if HAVE_MONAD_FAIL && MIN_VERSION_template_haskell(2,11,0)
#define _FAIL_IN_MONAD
#else
#define _FAIL_IN_MONAD , fail
#endif
#endif
#define _FAIL_IN_MONAD , fail
module Data.OverloadedRecords.TH.Internal
(
overloadedRecord
, overloadedRecords
, overloadedRecordFor
, overloadedRecordsFor
, DeriveOverloadedRecordsParams(..)
#ifndef HAVE_OVERLOADED_LABELS
, fieldDerivation
#endif
, FieldDerivation
, OverloadedField(..)
, defaultFieldDerivation
, defaultMakeFieldName
, field
, simpleField
, fieldGetter
, fieldSetter
, simpleFieldSetter
, DeriveFieldParams(..)
, deriveForConstructor
, deriveForField
, newNames
, strTyLitT
, varEs
, varPs
, wildPs
)
where
import Prelude (Num((-)), fromIntegral)
import Control.Applicative (Applicative((<*>)))
import Control.Arrow (Arrow((***)))
import Control.Monad (Monad((>>=) _FAIL_IN_MONAD, return), replicateM)
#if 0
#if HAVE_MONAD_FAIL && MIN_VERSION_template_haskell(2,11,0)
import Control.Monad.Fail (MonadFail(fail))
#endif
#endif
import Data.Bool (Bool(False), otherwise)
import qualified Data.Char as Char (toLower)
import Data.Eq (Eq((==)))
import Data.Foldable (concat, foldl)
import Data.Function ((.), ($), flip)
import Data.Functor (Functor(fmap), (<$>))
import qualified Data.List as List
( drop
, isPrefixOf
, length
, map
, replicate
, zip
, lookup
, unzip
)
import Data.Maybe (Maybe(Just, Nothing), fromMaybe, catMaybes)
import Data.Monoid ((<>))
import Data.String (String)
import Data.Traversable (mapM, sequence)
import Data.Tuple (fst)
import Data.Typeable (Typeable)
import Data.Word (Word)
import GHC.Generics (Generic)
#ifndef HAVE_OVERLOADED_LABELS
import GHC.Exts (Proxy#, proxy#)
#endif
import Text.Show (Show(show))
import Language.Haskell.TH
( Con
( ForallC
, InfixC
, NormalC
, RecC
#if MIN_VERSION_template_haskell(2,11,0)
, GadtC
, RecGadtC
#endif
)
, Dec(DataD, NewtypeD)
, DecsQ
, ExpQ
, Info(TyConI)
, Name
, Pat(ConP, VarP, WildP)
, PatQ
, Q
, Strict
, Type
, TypeQ
, TyVarBndr(KindedTV, PlainTV)
, appE
, appT
, conE
, conP
, conT
, lamE
, litT
, nameBase
, newName
, recUpdE
, reify
, strTyLit
, varE
, varP
, varT
, wildP
)
import Data.Default.Class (Default(def))
#ifndef HAVE_OVERLOADED_LABELS
import Data.OverloadedLabels (IsLabel(fromLabel))
#endif
import Data.OverloadedRecords
( FieldType
, HasField(getField)
, ModifyField(modifyField, setField)
, ModifyRec(getRecField, modifyRecField, setRecField)
, Rec
, UpdateType
)
#ifndef HAVE_OVERLOADED_LABELS
-- | Overloaded label @fieldDerivation@, defined as 'fromLabel' applied to a
-- proxy for the type-level string @\"fieldDerivation\"@. Only compiled when
-- @HAVE_OVERLOADED_LABELS@ is not defined.
fieldDerivation :: IsLabel "fieldDerivation" a => a
fieldDerivation = fromLabel (proxy# :: Proxy# "fieldDerivation")
#endif
-- | Parameters that customise how overloaded record instances are derived.
-- The 'Default' instance in this module provides the baseline value.
data DeriveOverloadedRecordsParams = DeriveOverloadedRecordsParams
    { _strictFields :: Bool
    -- ^ NOTE(review): this flag is set to 'False' by 'def' but is not
    -- consulted by any code visible in this module -- confirm intended use.
    , _fieldDerivation :: FieldDerivation
    -- ^ Decides, per constructor argument, whether instances are generated
    -- and under which label. See 'FieldDerivation'.
    }
  deriving (Generic, Typeable)
-- | The @fieldDerivation@ field of 'DeriveOverloadedRecordsParams' has type
-- 'FieldDerivation'.
type instance FieldType "fieldDerivation" DeriveOverloadedRecordsParams =
    FieldDerivation

-- | Getter for @fieldDerivation@; reads the '_fieldDerivation' record field.
instance
    HasField "fieldDerivation" DeriveOverloadedRecordsParams FieldDerivation
  where
    getField _proxy = _fieldDerivation

-- | Updating @fieldDerivation@ with a 'FieldDerivation' keeps the record type
-- unchanged.
type instance
    UpdateType "fieldDerivation" DeriveOverloadedRecordsParams FieldDerivation =
        DeriveOverloadedRecordsParams

-- | Setter for @fieldDerivation@; a plain record update of
-- '_fieldDerivation'.
instance
    ModifyField "fieldDerivation" DeriveOverloadedRecordsParams
        DeriveOverloadedRecordsParams FieldDerivation FieldDerivation
  where
    setField _proxy s b = s{_fieldDerivation = b}
-- | Describes what should be derived for one field. In both constructors the
-- 'String' is the label under which the field is exposed.
data OverloadedField
    = GetterOnlyField String (Maybe ExpQ)
    -- ^ Derive only a getter. A 'Just' supplies a custom getter expression;
    -- with 'Nothing' the derived getter is used.
    | GetterAndSetterField String (Maybe (ExpQ, ExpQ))
    -- ^ Derive getter and setter. A 'Just' supplies custom (getter, setter)
    -- expressions; with 'Nothing' the derived ones are used.
  deriving (Generic, Typeable)
-- | Function consulted for every constructor argument to decide whether, and
-- under which label, overloaded record instances are generated. A 'Nothing'
-- result means no instances are generated for that argument.
type FieldDerivation
    =  String
    -- ^ Base name of the type the field belongs to.
    -> String
    -- ^ Base name of the constructor the field belongs to.
    -> Word
    -- ^ Position of the argument within the constructor, counted from zero.
    -> Maybe String
    -- ^ Base name of the record accessor, if the argument has one.
    -> Maybe OverloadedField
    -- ^ What to derive, or 'Nothing' to skip the argument.
-- | Turn a record accessor name into an overloaded label by stripping a
-- known prefix and lower-casing the first remaining character.
--
-- Prefixes are tried in this order: a single underscore, the type name with
-- its first character lower-cased, the constructor name with its first
-- character lower-cased. Arguments without an accessor, and accessors that
-- match none of the prefixes, yield 'Nothing'.
defaultMakeFieldName
    :: String
    -- ^ Name of the type.
    -> String
    -- ^ Name of the constructor.
    -> Word
    -- ^ Position of the argument; ignored.
    -> Maybe String
    -- ^ Accessor name, if any.
    -> Maybe String
defaultMakeFieldName typeName constructorName _fieldPosition = \case
    Just accessor -> labelFor accessor
    Nothing -> Nothing
  where
    -- Pick the first candidate prefix that the accessor starts with.
    labelFor :: String -> Maybe String
    labelFor accessor = case matchingPrefixes of
        prefix : _ -> Just . lowerFirst $ List.drop (List.length prefix) accessor
        [] -> Nothing
      where
        matchingPrefixes =
            [ prefix
            | prefix <- ["_", typePrefix, conPrefix]
            , prefix `List.isPrefixOf` accessor
            ]

    lowerFirst :: String -> String
    lowerFirst = \case
        c : rest -> Char.toLower c : rest
        [] -> []

    typePrefix, conPrefix :: String
    typePrefix = lowerFirst typeName
    conPrefix = lowerFirst constructorName
-- | Default 'FieldDerivation': the label is computed by
-- 'defaultMakeFieldName', and when there is one, both getter and setter are
-- requested with no custom expressions.
defaultFieldDerivation :: FieldDerivation
defaultFieldDerivation tyName conName position accessor =
    fmap (`GetterAndSetterField` Nothing)
        . defaultMakeFieldName tyName conName position
        $ accessor
-- | Defaults: '_strictFields' is 'False' and '_fieldDerivation' is
-- 'defaultFieldDerivation'.
instance Default DeriveOverloadedRecordsParams where
    -- Positional construction; argument order follows the record declaration
    -- (_strictFields, then _fieldDerivation).
    def = DeriveOverloadedRecordsParams False defaultFieldDerivation
-- | Derive overloaded record instances for the type with the given name.
--
-- The name is reified; only @newtype@ and @data@ declarations with an empty
-- context are accepted. Anything else makes the 'Q' computation 'fail'.
overloadedRecord
    :: DeriveOverloadedRecordsParams
    -- ^ Parameters controlling the derivation.
    -> Name
    -- ^ Name of the type to derive instances for.
    -> DecsQ
overloadedRecord params name = reify name >>= \case
    TyConI dec -> fromDeclaration dec
    info -> notDerivable info
  where
    fromDeclaration :: Dec -> DecsQ
    fromDeclaration = \case
#if MIN_VERSION_template_haskell(2,11,0)
        NewtypeD [] tyName tyVars _kindSignature constructor _deriving ->
#else
        NewtypeD [] tyName tyVars constructor _deriving ->
#endif
            fst $ deriveForConstructor params [] tyName tyVars constructor
#if MIN_VERSION_template_haskell(2,11,0)
        DataD [] tyName tyVars _kindSignature constructors _deriving ->
#else
        DataD [] tyName tyVars constructors _deriving ->
#endif
            fst $ foldl (step tyName tyVars) (return [], []) constructors
        dec -> notDerivable dec

    -- Process one constructor, threading through the (label, accessor) pairs
    -- already handled by earlier constructors.
    step
        :: Name
        -> [TyVarBndr]
        -> (DecsQ, [(String, String)])
        -> Con
        -> (DecsQ, [(String, String)])
    step tyName tyVars (decs, seen) con =
        ((<>) <$> decs <*> decs', seen <> seen')
      where
        (decs', seen') = deriveForConstructor params seen tyName tyVars con

    notDerivable :: Show a => a -> Q b
    notDerivable x = fail
        $ "`" <> show name <> "' is neither newtype nor data type: " <> show x
-- | Run 'overloadedRecord' with the same parameters for each listed type and
-- concatenate the resulting declarations.
overloadedRecords
    :: DeriveOverloadedRecordsParams
    -- ^ Parameters shared by all the types.
    -> [Name]
    -- ^ Names of the types to derive instances for.
    -> DecsQ
overloadedRecords params names =
    concat <$> mapM (overloadedRecord params) names
-- | Variant of 'overloadedRecord' that starts from the default parameters
-- ('def') and lets the caller adjust them with a function.
overloadedRecordFor
    :: Name
    -- ^ Name of the type to derive instances for.
    -> (DeriveOverloadedRecordsParams -> DeriveOverloadedRecordsParams)
    -- ^ Adjustment applied to the default parameters.
    -> DecsQ
overloadedRecordFor name adjust = overloadedRecord params name
  where
    params = adjust def
-- | Variant of 'overloadedRecords' that starts from the default parameters
-- ('def') and lets the caller adjust them with a function.
overloadedRecordsFor
    :: [Name]
    -- ^ Names of the types to derive instances for.
    -> (DeriveOverloadedRecordsParams -> DeriveOverloadedRecordsParams)
    -- ^ Adjustment applied to the default parameters.
    -> DecsQ
overloadedRecordsFor names adjust = overloadedRecords params names
  where
    params = adjust def
-- | Derive instances for every argument of a single data constructor.
--
-- Returns the generated declarations together with the (label, accessor)
-- pairs newly handled by this constructor. GADT constructors make the
-- declaration computation 'fail'; @forall@-quantified constructors produce
-- nothing.
deriveForConstructor
    :: DeriveOverloadedRecordsParams
    -- ^ Parameters controlling the derivation.
    -> [(String, String)]
    -- ^ (label, accessor) pairs already handled by earlier constructors.
    -> Name
    -- ^ Name of the type the constructor belongs to.
    -> [TyVarBndr]
    -- ^ Type variable binders of that type.
    -> Con
    -- ^ The constructor to process.
    -> (DecsQ, [(String, String)])
deriveForConstructor params seen name typeVars con = case con of
    NormalC constrName args -> positional constrName args
    RecC constrName args ->
        deriveFor constrName args $ \(accessor, strict, argType) k ->
            k (Just accessor) strict argType
    InfixC lhs constrName rhs -> positional constrName [lhs, rhs]
#if MIN_VERSION_template_haskell(2,11,0)
    GadtC _ _ _ -> (fail "GADTs aren't yet supported.", [])
    RecGadtC _ _ _ -> (fail "GADTs aren't yet supported.", [])
#endif
    ForallC _typeVariables _context _constructor -> (return [], [])
  where
    -- Constructor arguments that carry no accessor name.
    positional :: Name -> [(Strict, Type)] -> (DecsQ, [(String, String)])
    positional constrName args =
        deriveFor constrName args $ \(strict, argType) k ->
            k Nothing strict argType

    deriveFor
        :: Name
        -> [a]
        -> ( a
            -> (Maybe Name -> Strict -> Type -> (DecsQ, Maybe (String, String)))
            -> (DecsQ, Maybe (String, String))
           )
        -> (DecsQ, [(String, String)])
    deriveFor constrName args unpackArg =
        mergeResults $ flip fmap indexedArgs $ \(position, arg) ->
            unpackArg arg $ \accessor strict argType ->
                deriveForField params seen DeriveFieldParams
                    { typeName = name
                    , typeVariables = List.map binderName typeVars
                    , constructorName = constrName
                    , numberOfArgs = argCount
                    , currentIndex = position
                    , accessorName = accessor
                    , strictness = strict
                    , fieldType = argType
                    }
      where
        -- Arguments paired with their zero-based position.
        indexedArgs = List.zip [(0 :: Word) ..] args

        argCount = fromIntegral $ List.length args

        binderName :: TyVarBndr -> Name
        binderName (PlainTV n) = n
        binderName (KindedTV n _kind) = n

    -- Concatenate all declarations and keep only the labels that were
    -- actually produced.
    mergeResults :: [(Q [a], Maybe b)] -> (Q [a], [b])
    mergeResults = (joinDecs *** catMaybes) . List.unzip
      where
        joinDecs = fmap concat . sequence
-- | Everything 'deriveForField' needs to know about one constructor
-- argument.
data DeriveFieldParams = DeriveFieldParams
    { typeName :: Name
    -- ^ Name of the type the argument belongs to.
    , typeVariables :: [Name]
    -- ^ Names of the type variables of that type.
    , constructorName :: Name
    -- ^ Name of the constructor the argument belongs to.
    , numberOfArgs :: Word
    -- ^ Total number of arguments of the constructor.
    , currentIndex :: Word
    -- ^ Position of this argument within the constructor, counted from zero.
    , accessorName :: Maybe Name
    -- ^ Record accessor name, if the argument has one.
    , strictness :: Strict
    -- ^ Strictness annotation of the argument.
    , fieldType :: Type
    -- ^ Type of the argument.
    }
-- | Derive instances for a single constructor argument.
--
-- The '_fieldDerivation' function from the parameters decides whether
-- anything is generated and under which label. The result is the generated
-- declarations plus, when instances were generated for an argument that has
-- an accessor, the (label, accessor) pair to remember.
--
-- If the label was already used by the same accessor nothing is generated;
-- if it was used by a different accessor the declaration computation
-- 'fail's.
deriveForField
    :: DeriveOverloadedRecordsParams
    -- ^ Parameters controlling the derivation.
    -> [(String, String)]
    -- ^ (label, accessor) pairs that were already handled.
    -> DeriveFieldParams
    -- ^ Description of the constructor argument.
    -> (DecsQ, Maybe (String, String))
deriveForField params seen DeriveFieldParams{..} =
    case possiblyLabel of
        Nothing -> (return [], Nothing)

        Just (GetterOnlyField label customGetterExpr) ->
            ifNotSeenAlreadyThenDo label . deriveGetter' (strTyLitT label)
                $ fromMaybe derivedGetterExpr customGetterExpr

        Just (GetterAndSetterField label customGetterAndSetterExpr) ->
            ifNotSeenAlreadyThenDo label $ (<>)
                <$> deriveGetter' labelType getterExpr
                <*> deriveSetter' labelType setterExpr
          where
            labelType = strTyLitT label

            -- Custom expressions, when supplied, replace both derived ones.
            (getterExpr, setterExpr) =
                fromMaybe (derivedGetterExpr, derivedSetterExpr)
                    customGetterAndSetterExpr
  where
    accessorBase = fmap nameBase accessorName

    ifNotSeenAlreadyThenDo
        :: String
        -> DecsQ
        -> (DecsQ, Maybe (String, String))
    ifNotSeenAlreadyThenDo label action =
        case List.lookup label seen of
            x@(Just from)
              -- Same accessor seen again (e.g. shared between constructors):
              -- instances already exist, so generate nothing.
              | x == accessorBase -> (return [], Nothing)
              | otherwise -> (nameConflictError from, Nothing)
            Nothing -> (action, (,) label <$> accessorBase)

    nameConflictError n =
        fail $ "Two different fields map to the same label \"" <> n <> "\""

    possiblyLabel = _fieldDerivation params (nameBase typeName)
        (nameBase constructorName) currentIndex accessorBase

    deriveGetter' labelType =
        deriveGetter labelType recordType (return fieldType)

    deriveSetter' labelType =
        deriveSetter labelType recordType (return fieldType) newRecordType
            newFieldType

    -- The type constructor applied to all of its type variables.
    recordType = foldl appT (conT typeName) $ List.map varT typeVariables

    -- Derived setters do not change the type of the field or of the record.
    newFieldType = return fieldType
    newRecordType = recordType

    -- Number of constructor arguments to the right of the current one.
    -- 'currentIndex' is zero-based, hence the extra @- 1@. This is 'Word'
    -- arithmetic, so it relies on @currentIndex < numberOfArgs@.
    numVarsOnRight = numberOfArgs - currentIndex - 1

    -- Build @f a1 <> [b] <> f a2@: "stuff before", the current argument,
    -- "stuff after".
    inbetween :: (a -> [b]) -> a -> a -> b -> [b]
    inbetween f a1 a2 b = f a1 <> (b : f a2)

    -- Use the accessor when there is one; otherwise a lambda that matches the
    -- constructor with wildcards everywhere except the current position.
    derivedGetterExpr = case accessorName of
        Just name -> varE name
        Nothing -> do
            a <- newName "a"
            lamE [return . ConP constructorName $ nthArg (VarP a)] (varE a)
      where
        nthArg :: Pat -> [Pat]
        nthArg = inbetween wildPs currentIndex numVarsOnRight

    -- Use a record update when there is an accessor; otherwise take the
    -- value apart and rebuild it with the new value in the current position.
    derivedSetterExpr = case accessorName of
        Just name -> do
            s <- newName "s"
            b <- newName "b"
            lamE [varP s, varP b] $ recUpdE (varE s) [(name, ) <$> varE b]
        Nothing -> do
            varsBefore <- newNames currentIndex "a"
            b <- newName "b"
            varsAfter <- newNames numVarsOnRight "a"
            lamE [constrPattern varsBefore varsAfter, varP b]
                $ constrExpression varsBefore (varE b) varsAfter
      where
        constrPattern before after =
            conP constructorName $ inbetween varPs before after wildP

        constrExpression before b after = foldl appE (conE constructorName)
            $ varEs before <> (b : varEs after)
-- | Generate both getter and setter instances for a field with the given
-- label, using explicitly supplied types and expressions.
field
    :: String
    -- ^ Label of the overloaded field.
    -> TypeQ
    -- ^ Record type.
    -> TypeQ
    -- ^ Field type.
    -> TypeQ
    -- ^ Record type after an update.
    -> TypeQ
    -- ^ Field type after an update.
    -> ExpQ
    -- ^ Getter expression.
    -> ExpQ
    -- ^ Setter expression.
    -> DecsQ
field label recType fldType newRecType newFldType getterExpr setterExpr =
    (<>) <$> getterDecs <*> setterDecs
  where
    labelType = strTyLitT label

    getterDecs = deriveGetter labelType recType fldType getterExpr

    setterDecs =
        deriveSetter labelType recType fldType newRecType newFldType setterExpr
-- | Like 'field', for the case where an update changes neither the record
-- type nor the field type.
simpleField
    :: String
    -- ^ Label of the overloaded field.
    -> TypeQ
    -- ^ Record type.
    -> TypeQ
    -- ^ Field type.
    -> ExpQ
    -- ^ Getter expression.
    -> ExpQ
    -- ^ Setter expression.
    -> DecsQ
simpleField label recType fldType getterExpr setterExpr =
    field label recType fldType recType fldType getterExpr setterExpr
-- | Generate only the getter instances for a field with the given label.
fieldGetter
    :: String
    -- ^ Label of the overloaded field.
    -> TypeQ
    -- ^ Record type.
    -> TypeQ
    -- ^ Field type.
    -> ExpQ
    -- ^ Getter expression.
    -> DecsQ
fieldGetter label = deriveGetter (strTyLitT label)
-- | Generate the getter-side declarations for one field: a 'FieldType'
-- instance, a 'HasField' instance on the record itself whose 'getField' is
-- the supplied expression, and a 'HasField' instance on @Rec cs record@ that
-- delegates to 'getRecField'.
--
-- Arguments, in order: label type, record type, field type, getter
-- expression.
deriveGetter :: TypeQ -> TypeQ -> TypeQ -> ExpQ -> DecsQ
deriveGetter labelT recordT fieldT getterE =
    [d| type instance FieldType $(labelT) $(recordT) = $(fieldT)

        instance HasField $(labelT) $(recordT) $(fieldT) where
            getField _proxy = $(getterE)

        instance
            ( ModifyRec $(labelT) $(fieldT) cs
            ) => HasField $(labelT) (Rec cs $(recordT)) $(fieldT)
          where
            getField = getRecField
      |]
-- | Like 'fieldSetter', for the case where an update changes neither the
-- record type nor the field type.
simpleFieldSetter
    :: String
    -- ^ Label of the overloaded field.
    -> TypeQ
    -- ^ Record type.
    -> TypeQ
    -- ^ Field type.
    -> ExpQ
    -- ^ Setter expression.
    -> DecsQ
simpleFieldSetter label recType fldType setterExpr =
    fieldSetter label recType fldType recType fldType setterExpr
-- | Generate only the setter instances for a field with the given label.
fieldSetter
    :: String
    -- ^ Label of the overloaded field.
    -> TypeQ
    -- ^ Record type.
    -> TypeQ
    -- ^ Field type.
    -> TypeQ
    -- ^ Record type after an update.
    -> TypeQ
    -- ^ Field type after an update.
    -> ExpQ
    -- ^ Setter expression.
    -> DecsQ
fieldSetter label = deriveSetter (strTyLitT label)
-- | Generate the setter-side declarations for one field: an 'UpdateType'
-- instance, a 'ModifyField' instance on the record itself whose 'setField'
-- is the supplied expression, and a 'ModifyField' instance on
-- @Rec cs record@ that delegates to 'setRecField' and 'modifyRecField'.
--
-- Arguments, in order: label type, record type, field type, record type
-- after update, field type after update, setter expression.
deriveSetter :: TypeQ -> TypeQ -> TypeQ -> TypeQ -> TypeQ -> ExpQ -> DecsQ
deriveSetter labelT recordT fieldT newRecordT newFieldT setterE =
    [d| type instance UpdateType $(labelT) $(recordT) $(newFieldT) =
            $(newRecordT)

        instance
            ModifyField $(labelT) $(recordT) $(newRecordT)
                $(fieldT) $(newFieldT)
          where
            setField _proxy = $(setterE)

        instance
            ( ModifyRec $(labelT) $(fieldT) cs
            ) => ModifyField $(labelT) (Rec cs $(recordT))
                (Rec cs $(recordT)) $(fieldT) $(fieldT)
          where
            setField = setRecField
            modifyField = modifyRecField
      |]
-- | A list of the given number of wildcard patterns.
wildPs :: Word -> [Pat]
wildPs = flip List.replicate WildP . fromIntegral
-- | Generate the given number of fresh names, all created with 'newName'
-- from the same base string.
newNames :: Word -> String -> Q [Name]
newNames count base = replicateM (fromIntegral count) (newName base)
-- | Variable patterns for the given names, in order.
varPs :: [Name] -> [PatQ]
varPs names = [varP name | name <- names]
-- | Variable expressions for the given names, in order.
varEs :: [Name] -> [ExpQ]
varEs names = [varE name | name <- names]
-- | Type-level string literal with the given contents.
strTyLitT :: String -> TypeQ
strTyLitT str = litT (strTyLit str)