#ifndef MIN_VERSION_template_haskell
#define MIN_VERSION_template_haskell(x,y,z) 1
#endif
module Data.SafeCopy.Derive
(
deriveSafeCopy
, deriveSafeCopyIndexedType
, deriveSafeCopySimple
, deriveSafeCopySimpleIndexedType
, deriveSafeCopyHappstackData
, deriveSafeCopyHappstackDataIndexedType
) where
import Data.Serialize (getWord8, putWord8, label)
import Data.SafeCopy.SafeCopy
#if MIN_VERSION_template_haskell(2,8,0)
import Language.Haskell.TH hiding (Kind)
#else
import Language.Haskell.TH hiding (Kind(..))
#endif
import Control.Applicative
import Control.Monad
import Data.Char (isAlphaNum)
import Data.Maybe (fromMaybe)
#ifdef __HADDOCK__
import Data.Word (Word8)
#endif
-- | Derive a 'SafeCopy' instance for the type named by the last argument.
--
-- Arguments: the version recorded for the type, the name of the value used
-- for the instance's 'kind' method, and the name of the type itself.
-- Each field is written/read through a binding obtained once per distinct
-- field type (per constructor) from 'getSafePut' / 'getSafeGet'; a one-byte
-- constructor tag is written only when the type has several constructors.
deriveSafeCopy :: Version a -> Name -> Name -> Q [Dec]
deriveSafeCopy versionId kindName tyName =
  internalDeriveSafeCopy Normal versionId kindName tyName

-- | Like 'deriveSafeCopy', but for a data or newtype family: an instance is
-- derived only for the family instance whose type arguments are exactly the
-- given type constructors.
deriveSafeCopyIndexedType :: Version a -> Name -> Name -> [Name] -> Q [Dec]
deriveSafeCopyIndexedType versionId kindName tyName tyIndex =
  internalDeriveSafeCopyIndexedType Normal versionId kindName tyName tyIndex

-- | Like 'deriveSafeCopy', but fields are written/read directly with
-- 'safePut' / 'safeGet'.
deriveSafeCopySimple :: Version a -> Name -> Name -> Q [Dec]
deriveSafeCopySimple versionId kindName tyName =
  internalDeriveSafeCopy Simple versionId kindName tyName

-- | The data/newtype-family counterpart of 'deriveSafeCopySimple'.
deriveSafeCopySimpleIndexedType :: Version a -> Name -> Name -> [Name] -> Q [Dec]
deriveSafeCopySimpleIndexedType versionId kindName tyName tyIndex =
  internalDeriveSafeCopyIndexedType Simple versionId kindName tyName tyIndex

-- | Like 'deriveSafeCopySimple', except that the constructor tag byte is
-- always written, even for types with a single constructor.
deriveSafeCopyHappstackData :: Version a -> Name -> Name -> Q [Dec]
deriveSafeCopyHappstackData versionId kindName tyName =
  internalDeriveSafeCopy HappstackData versionId kindName tyName

-- | The data/newtype-family counterpart of 'deriveSafeCopyHappstackData'.
deriveSafeCopyHappstackDataIndexedType :: Version a -> Name -> Name -> [Name] -> Q [Dec]
deriveSafeCopyHappstackDataIndexedType versionId kindName tyName tyIndex =
  internalDeriveSafeCopyIndexedType HappstackData versionId kindName tyName tyIndex
-- | Which flavour of 'SafeCopy' instance the derivation generates.
data DeriveType
  = Normal        -- ^ fields go through per-type 'getSafePut' / 'getSafeGet' bindings
  | Simple        -- ^ fields go through 'safePut' / 'safeGet' directly
  | HappstackData -- ^ like 'Simple', but the constructor tag is always written

-- | Whether the constructor tag byte must be written even when the type has
-- only one constructor.
forceTag :: DeriveType -> Bool
forceTag deriveType = case deriveType of
  HappstackData -> True
  _             -> False
-- | The name bound by a type-variable binder, with or without a kind
-- annotation.
--
-- 'PlainTV' and 'KindedTV' are both constructors of 'TyVarBndr' in every
-- template-haskell release that has the type, so neither case is put behind
-- a CPP guard; guarding 'KindedTV' made this function partial (crashing the
-- derivation on kind-annotated type variables) for template-haskell < 2.10.
tyVarName :: TyVarBndr -> Name
tyVarName (PlainTV n) = n
tyVarName (KindedTV n _) = n
-- | Shared implementation of the non-indexed @deriveSafeCopy*@ functions.
--
-- Reifies @tyName@ and emits one instance for a plain @data@ / @newtype@
-- declaration, or one instance per @data@ / @newtype@ instance when @tyName@
-- names a type family.  Constructor numbers are written with 'putWord8', so
-- every @data@ declaration (including data-family instances) with more than
-- 255 constructors is rejected instead of producing colliding tags.
internalDeriveSafeCopy :: DeriveType -> Version a -> Name -> Name -> Q [Dec]
internalDeriveSafeCopy deriveType versionId kindName tyName = do
  info <- reify tyName
  case info of
    TyConI (DataD context _name tyvars cons _derivs) ->
      withTags cons $ worker context tyvars
    TyConI (NewtypeD context _name tyvars con _derivs) ->
      worker context tyvars [(0, con)]
    FamilyI _ insts -> do
      decs <- forM insts $ \inst ->
        case inst of
          DataInstD context _name ty cons _derivs ->
            -- Previously unchecked: data instances need the same tag limit.
            withTags cons $ worker' (instanceHead ty) context []
          NewtypeInstD context _name ty con _derivs ->
            worker' (instanceHead ty) context [] [(0, con)]
          _ -> fail $ "Can't derive SafeCopy instance for: " ++ show (tyName, inst)
      return $ concat decs
    _ -> fail $ "Can't derive SafeCopy instance for: " ++ show (tyName, info)
  where
    -- The instance head @tyName t1 .. tn@ for a family instance's arguments.
    instanceHead ty = foldl appT (conT tyName) (map return ty)
    -- Number the constructors from 0 (their tag) and hand them to @k@, unless
    -- there are too many for a one-byte tag.
    withTags cons k
      | length cons > 255 =
          fail $ "Can't derive SafeCopy instance for: " ++ show tyName ++
                 ". The datatype must have less than 256 constructors."
      | otherwise = k (zip [0..] cons)
    worker = worker' (conT tyName)
    -- Build the instance; every type parameter gets a SafeCopy constraint,
    -- in addition to the declaration's own context.
    worker' tyBase context tyvars cons =
      let ty = foldl appT tyBase [ varT $ tyVarName var | var <- tyvars ]
#if MIN_VERSION_template_haskell(2,10,0)
          safeCopyClass args = foldl appT (conT ''SafeCopy) args
#else
          safeCopyClass args = classP ''SafeCopy args
#endif
      in (:[]) <$> instanceD
           (cxt $ [safeCopyClass [varT $ tyVarName var] | var <- tyvars] ++ map return context)
           (conT ''SafeCopy `appT` ty)
           [ mkPutCopy deriveType cons
           , mkGetCopy deriveType (show tyName) cons
           , valD (varP 'version) (normalB $ litE $ integerL $ fromIntegral $ unVersion versionId) []
           , valD (varP 'kind) (normalB (varE kindName)) []
           , funD 'errorTypeName [clause [wildP] (normalB $ litE $ StringL (show tyName)) []]
           ]
-- | Shared implementation of the @deriveSafeCopy*IndexedType@ functions.
--
-- @tyName@ must name a data or newtype family.  An instance is generated for
-- each family instance whose argument types equal @conT@ of the names in
-- @tyIndex'@; all other family instances are skipped.  As for ordinary data
-- declarations, a matching data instance with more than 255 constructors is
-- rejected because constructor numbers are written with 'putWord8'.
internalDeriveSafeCopyIndexedType :: DeriveType -> Version a -> Name -> Name -> [Name] -> Q [Dec]
internalDeriveSafeCopyIndexedType deriveType versionId kindName tyName tyIndex' = do
  tyIndex <- mapM conT tyIndex'
  info <- reify tyName
  case info of
    FamilyI _ insts -> do
      decs <- forM insts $ \inst ->
        case inst of
          DataInstD context _name ty cons _derivs
            | ty == tyIndex ->
                withTags cons $ worker' (instanceHead ty) context []
            | otherwise ->
                return []
          NewtypeInstD context _name ty con _derivs
            | ty == tyIndex ->
                worker' (instanceHead ty) context [] [(0, con)]
            | otherwise ->
                return []
          _ -> fail $ "Can't derive SafeCopy instance for: " ++ show (tyName, inst)
      return $ concat decs
    _ -> fail $ "Can't derive SafeCopy instance for: " ++ show (tyName, info)
  where
    -- Human-readable name of the indexed type, e.g. the family plus its index.
    typeNameStr = unwords $ map show (tyName:tyIndex')
    -- The instance head @tyName t1 .. tn@ for a family instance's arguments.
    instanceHead ty = foldl appT (conT tyName) (map return ty)
    -- Number the constructors from 0 (their tag) and hand them to @k@, unless
    -- there are too many for a one-byte tag.
    withTags cons k
      | length cons > 255 =
          fail $ "Can't derive SafeCopy instance for: " ++ typeNameStr ++
                 ". The datatype must have less than 256 constructors."
      | otherwise = k (zip [0..] cons)
    -- Build the instance; every type parameter gets a SafeCopy constraint,
    -- in addition to the instance's own context.
    worker' tyBase context tyvars cons =
      let ty = foldl appT tyBase [ varT $ tyVarName var | var <- tyvars ]
#if MIN_VERSION_template_haskell(2,10,0)
          safeCopyClass args = foldl appT (conT ''SafeCopy) args
#else
          safeCopyClass args = classP ''SafeCopy args
#endif
      in (:[]) <$> instanceD
           (cxt $ [safeCopyClass [varT $ tyVarName var] | var <- tyvars] ++ map return context)
           (conT ''SafeCopy `appT` ty)
           [ mkPutCopy deriveType cons
           , mkGetCopy deriveType typeNameStr cons
           , valD (varP 'version) (normalB $ litE $ integerL $ fromIntegral $ unVersion versionId) []
           , valD (varP 'kind) (normalB (varE kindName)) []
           , funD 'errorTypeName [clause [wildP] (normalB $ litE $ StringL typeNameStr) []]
           ]
-- | Build the 'putCopy' method: one clause per constructor, each body wrapped
-- in 'contain'.
--
-- A clause binds every field to a fresh variable, then (inside a @do@)
-- writes the constructor number with 'putWord8' (only when there are several
-- constructors or 'forceTag' requires it), runs the per-type serialiser
-- bindings (for 'Normal'), writes each field in order and returns @()@.
mkPutCopy :: DeriveType -> [(Integer, Con)] -> DecQ
mkPutCopy deriveType cons = funD 'putCopy (map putClauseFor cons)
  where
    writeTag = length cons > 1 || forceTag deriveType

    -- Which function writes a field of a given type, plus any statements
    -- that must run first to bind those functions.
    serialisers con = case deriveType of
      Normal -> mkSafeFunctions "safePut_" 'getSafePut con
      _      -> return ([], const 'safePut)

    putClauseFor (tag, con) = do
      argNames <- replicateM (conSize con) (newName "arg")
      (bindings, serialiserFor) <- serialisers con
      let pat        = conP (conName con) (map varP argNames)
          tagStmts   = [ noBindS (varE 'putWord8 `appE` litE (IntegerL tag)) | writeTag ]
          fieldStmts = zipWith (\fieldTy arg -> noBindS (varE (serialiserFor fieldTy) `appE` varE arg))
                               (conTypes con) argNames
          doneStmt   = noBindS (varE 'return `appE` tupE [])
          body       = varE 'contain `appE` doE (concat [tagStmts, bindings, fieldStmts, [doneStmt]])
      clause [pat] (normalB body) []
-- | Build the 'getCopy' value: @contain (label "<tyName>:" decoder)@.
--
-- With a single constructor (and no forced tag) the decoder reads the fields
-- directly.  Otherwise it reads a tag with 'getWord8', dispatches on it to
-- the matching constructor, and calls 'fail' with a descriptive message for
-- any unknown tag.  Fields are combined applicatively onto @return Con@.
mkGetCopy :: DeriveType -> String -> [(Integer, Con)] -> DecQ
mkGetCopy deriveType tyName cons =
    valD (varP 'getCopy) (normalB (varE 'contain `appE` labelled)) []
  where
    labelled = varE 'label `appE` litE (stringL (tyName ++ ":")) `appE` decoder

    decoder = case cons of
      [(_, con)] | not (forceTag deriveType) -> decodeFields con
      _ -> decodeTagged

    decodeTagged = do
      tag <- newName "tag"
      let alternatives = [ match (litP (IntegerL n)) (normalB (decodeFields con)) [] | (n, con) <- cons ]
          fallback     = match wildP (normalB (varE 'fail `appE` unknownTag tag)) []
      doE [ bindS (varP tag) (varE 'getWord8)
          , noBindS (caseE (varE tag) (alternatives ++ [fallback]))
          ]

    decodeFields con = do
      (bindings, deserialiserFor) <- case deriveType of
        Normal -> mkSafeFunctions "safeGet_" 'getSafeGet con
        _      -> return ([], const 'safeGet)
      let construct = varE 'return `appE` conE (conName con)
          applyField acc fieldTy =
            infixE (Just acc) (varE '(<*>)) (Just (varE (deserialiserFor fieldTy)))
      doE (bindings ++ [noBindS (foldl applyField construct (conTypes con))])

    -- Expression for: prefix ++ (show tag ++ suffix)
    unknownTag tag =
        infixE (Just (lit prefix)) (varE '(++)) $ Just $
          infixE (Just shownTag) (varE '(++)) (Just (lit suffix))
      where
        lit      = litE . StringL
        shownTag = varE 'show `appE` varE tag
        prefix   = "Could not identify tag \""
        suffix   = concat [ "\" for type "
                          , show tyName
                          , " that has only "
                          , show (length cons)
                          , " constructors. Maybe your data is corrupted?" ]
-- | For each distinct field type of @con@ (compared after 'followSynonyms'),
-- in order of first occurrence, emit a statement @fresh <- baseFun@, where
-- @fresh@ is a new name built from @name@ and 'typeName'.  Also return a
-- function mapping each of @con@'s original field types to the name bound
-- for it.
--
-- NOTE(review): the number and order of these bindings presumably matter to
-- the serialised format produced by the monadic 'getSafePut' / 'getSafeGet';
-- keep them stable.
mkSafeFunctions :: String -> Name -> Con -> Q ([StmtQ], Type -> Name)
mkSafeFunctions name baseFun con = do
    expanded <- mapM followSynonyms fieldTypes
    (stmts, bound) <- bindDistinct [] expanded
    let nameFor typ = fromMaybe missing (lookup typ (zip fieldTypes expanded) >>= flip lookup bound)
    return (stmts, nameFor)
  where
    fieldTypes = conTypes con
    missing = error "mkSafeFunctions: never here"

    -- Walk the expanded types left to right, binding a fresh name the first
    -- time each type is seen; returns the statements and the type/name table.
    bindDistinct seen [] = return ([], seen)
    bindDistinct seen (t:ts)
      | t `elem` map fst seen = bindDistinct seen ts
      | otherwise = do
          funVar <- newName (name ++ typeName t)
          (rest, seen') <- bindDistinct ((t, funVar) : seen) ts
          return (bindS (varP funVar) (varE baseFun) : rest, seen')
-- | Replace type-constructor names by their definitions when 'reify' reports
-- them as a type synonym (or a type variable), recursing through type
-- applications and kind signatures.  Names that cannot be reified are left
-- unchanged.
--
-- NOTE(review): a parameterised synonym is replaced by its right-hand side
-- without substituting the arguments, so the result can mention the
-- synonym's own type variables; mkSafeFunctions only uses it as a
-- de-duplication key.
followSynonyms :: Type -> Q Type
followSynonyms typ = case typ of
    ConT name -> do
      expansion <- recover (return Nothing) (synonymBody <$> reify name)
      maybe (return typ) followSynonyms expansion
    AppT f x -> AppT <$> followSynonyms f <*> followSynonyms x
    SigT inner k -> (`SigT` k) <$> followSynonyms inner
    _ -> return typ
  where
    synonymBody info = case info of
      TyVarI _ rhs -> Just rhs
      TyConI (TySynD _ _ rhs) -> Just rhs
      _ -> Nothing
-- | Number of fields of a constructor.  'ForallC' constructors (explicit
-- quantification or a context) are not supported.
conSize :: Con -> Int
conSize con = case con of
  NormalC _ fields -> length fields
  RecC _ fields    -> length fields
  InfixC{}         -> 2
  ForallC{}        -> error "Found complex constructor. Cannot derive SafeCopy for it."
-- | The name of a constructor.
--
-- 'ForallC' constructors are rejected with the same message 'conSize' gives,
-- instead of the misleading "never here" they used to fall through to.
conName :: Con -> Name
conName (NormalC name _args) = name
conName (RecC name _recs) = name
conName (InfixC _ name _) = name
conName ForallC{} = error "Found complex constructor. Cannot derive SafeCopy for it."
-- | The field types of a constructor, in declaration order.
--
-- 'ForallC' constructors are rejected with the same message 'conSize' gives
-- (the old catch-all also carried the wrong label, "conName: never here").
conTypes :: Con -> [Type]
conTypes (NormalC _name args) = [t | (_, t) <- args]
conTypes (RecC _name args) = [t | (_, _, t) <- args]
conTypes (InfixC (_, t1) _ (_, t2)) = [t1, t2]
conTypes ForallC{} = error "Found complex constructor. Cannot derive SafeCopy for it."
-- | A short, identifier-safe rendering of a type.  mkSafeFunctions uses it as
-- the suffix of the fresh names it generates (e.g. @safePut_MaybeInt@).
--
-- Only identifier characters are kept from names: an operator type
-- constructor such as @:+:@ would otherwise end up inside a generated
-- variable name, which GHC rejects as an illegal variable name when the
-- splice is converted.
typeName :: Type -> String
typeName (VarT name) = identifierChars (nameBase name)
typeName (ConT name) = identifierChars (nameBase name)
typeName (TupleT n) = "Tuple" ++ show n
typeName ArrowT = "Arrow"
typeName ListT = "List"
typeName (AppT t u) = typeName t ++ typeName u
typeName (SigT t _k) = typeName t
typeName _ = "_"

-- | Keep only the characters that may appear inside a Haskell identifier.
identifierChars :: String -> String
identifierChars = filter (\c -> isAlphaNum c || c == '_' || c == '\'')