{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
module Data.Avro.Deriving
(
DeriveOptions(..)
, FieldStrictness(..)
, FieldUnpackedness(..)
, NamespaceBehavior(..)
, defaultDeriveOptions
, mkPrefixedFieldName
, mkAsIsFieldName
, mkLazyField
, mkStrictPrimitiveField
, makeSchema
, makeSchemaFrom
, makeSchemaFromByteString
, deriveAvroWithOptions
, deriveAvroWithOptions'
, deriveAvroFromByteString
, deriveAvro
, deriveAvro'
, r
)
where
import Control.Monad (join)
import Control.Monad.Identity (Identity)
import Data.Aeson (eitherDecode)
import qualified Data.Aeson as J
import Data.Avro hiding (decode, encode)
import Data.Avro.Encoding.ToAvro (ToAvro (..))
import Data.Avro.Internal.EncodeRaw (putI)
import Data.Avro.Schema.Schema as S
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Char (isAlphaNum)
import qualified Data.Foldable as Foldable
import Data.Int
import Data.List.NonEmpty (NonEmpty ((:|)))
import qualified Data.List.NonEmpty as NE
import Data.Map (Map)
import Data.Maybe (fromMaybe)
import Data.Semigroup ((<>))
import qualified Data.Text as Text
import Data.Time (Day, DiffTime, UTCTime)
import Data.UUID (UUID)
import Text.RawString.QQ (r)
import qualified Data.Avro.Encoding.FromAvro as AV
import GHC.Generics (Generic)
import Language.Haskell.TH as TH hiding (notStrict)
import Language.Haskell.TH.Lib as TH hiding (notStrict)
import Language.Haskell.TH.Syntax
import Data.Avro.Deriving.NormSchema
import Data.Avro.EitherN
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString.Lazy.Char8 as LBSC8
import qualified Data.HashMap.Strict as HM
import qualified Data.Set as S
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Vector as V
import Data.Avro.Deriving.Lift ()
import Language.Haskell.TH.Syntax (lift)
-- | Controls how a schema's namespace contributes to the names of the
-- generated Haskell entities (interpreted by 'renderName').
data NamespaceBehavior =
-- | Use only the base name of the type; the namespace is dropped.
IgnoreNamespaces
-- | Join the namespace components and the base name with a @'@ separator.
| HandleNamespaces
-- | Compute the rendered name with a user-supplied function. It is called
-- with the base name first and the list of namespace components second.
| Custom (T.Text -> [T.Text] -> T.Text)
-- | Whether a generated record field is declared strict or lazy.
-- 'mkField' maps 'StrictField' to a strict bang and 'LazyField' to
-- 'notStrict'.
data FieldStrictness = StrictField | LazyField
deriving Generic
-- | Whether a strict generated field is additionally marked as unpacked.
-- Only consulted for strict fields (see 'strict' and 'mkField').
data FieldUnpackedness = UnpackedField | NonUnpackedField
deriving Generic
-- | Options that steer the Template Haskell code generation in this module.
data DeriveOptions = DeriveOptions
{
-- | Produces the Haskell field name. 'mkField' calls it with the rendered
-- name of the enclosing record type and the field being generated.
fieldNameBuilder :: Text -> Field -> T.Text
-- | Chooses the strictness and unpackedness of a generated field, given
-- the enclosing type name and the field.
, fieldRepresentation :: TypeName -> Field -> (FieldStrictness, FieldUnpackedness)
-- | How namespaces are reflected in generated names.
, namespaceBehavior :: NamespaceBehavior
} deriving Generic
-- | Default deriving options: field names are prefixed with the record name
-- ('mkPrefixedFieldName'), every field is lazy ('mkLazyField'), and
-- namespaces are ignored ('IgnoreNamespaces').
defaultDeriveOptions :: DeriveOptions
defaultDeriveOptions = DeriveOptions
  { fieldNameBuilder    = mkPrefixedFieldName
  , fieldRepresentation = mkLazyField
  , namespaceBehavior   = IgnoreNamespaces
  }
-- | Build a field name by gluing the prefix (first character lower-cased)
-- to the Avro field name (first character upper-cased), then dropping
-- every character that 'sanitiseName' rejects.
mkPrefixedFieldName :: Text -> Field -> T.Text
mkPrefixedFieldName prefix fld = sanitiseName (loweredPrefix <> capitalisedField)
  where
    loweredPrefix    = updateFirst T.toLower prefix
    capitalisedField = updateFirst T.toUpper (fldName fld)
-- | Field representation that makes every field lazy and not unpacked,
-- regardless of the enclosing type or the field itself.
mkLazyField :: TypeName -> Field -> (FieldStrictness, FieldUnpackedness)
mkLazyField _typeName _field = (LazyField, NonUnpackedField)
-- | Field representation that makes primitive fields strict.
--
-- * @null@ and @boolean@ fields are strict but not unpacked;
-- * @int@, @long@, @float@ and @double@ fields are strict and unpacked;
-- * every other field stays lazy and not unpacked.
--
-- The enclosing type name is ignored.
mkStrictPrimitiveField :: TypeName -> Field -> (FieldStrictness, FieldUnpackedness)
mkStrictPrimitiveField _ field =
  case S.fldType field of
    S.Null    -> (StrictField, NonUnpackedField)
    S.Boolean -> (StrictField, NonUnpackedField)
    S.Int _   -> (StrictField, UnpackedField)
    S.Long _  -> (StrictField, UnpackedField)
    S.Float   -> (StrictField, UnpackedField)
    S.Double  -> (StrictField, UnpackedField)
    _         -> (LazyField, NonUnpackedField)
-- | Build a field name from the Avro field name alone: the first character
-- is lower-cased and characters rejected by 'sanitiseName' are removed.
-- The prefix argument is ignored.
mkAsIsFieldName :: Text -> Field -> Text
mkAsIsFieldName _ fld = sanitiseName (updateFirst T.toLower (fldName fld))
-- | Read the schema stored at the given path and generate declarations for
-- it using the supplied options (see "deriveAvroWithOptions'").
deriveAvroWithOptions :: DeriveOptions -> FilePath -> Q [Dec]
deriveAvroWithOptions o p = deriveAvroWithOptions' o =<< readSchema p
-- | Generate declarations for every schema returned by 'extractDerivables'.
--
-- The result lists, in this order: the data types, the 'HasAvroSchema'
-- instances (with their schema values), the 'AV.FromAvro' instances and the
-- 'ToAvro' instances.
deriveAvroWithOptions' :: DeriveOptions -> Schema -> Q [Dec]
deriveAvroWithOptions' o s = do
  typeDecs   <- generate (genType o)
  schemaDecs <- generate (genHasAvroSchema nsBehavior)
  fromDecs   <- generate (genFromValue nsBehavior)
  toDecs     <- generate (genToAvro o)
  pure (typeDecs <> schemaDecs <> fromDecs <> toDecs)
  where
    derivables = extractDerivables s
    nsBehavior = namespaceBehavior o
    -- Run one generator over every derivable schema and flatten the result.
    generate :: (Schema -> Q [Dec]) -> Q [Dec]
    generate gen = concat <$> traverse gen derivables
-- | 'deriveAvroWithOptions' specialised to 'defaultDeriveOptions'.
deriveAvro :: FilePath -> Q [Dec]
deriveAvro path = deriveAvroWithOptions defaultDeriveOptions path
-- | "deriveAvroWithOptions'" specialised to 'defaultDeriveOptions'.
deriveAvro' :: Schema -> Q [Dec]
deriveAvro' sch = deriveAvroWithOptions' defaultDeriveOptions sch
-- | Decode a schema from a lazy bytestring holding its JSON form and
-- generate declarations for it with 'defaultDeriveOptions'. A decoding
-- failure aborts the splice via 'fail'.
deriveAvroFromByteString :: LBS.ByteString -> Q [Dec]
deriveAvroFromByteString bs =
  either decodeFailed (deriveAvroWithOptions' defaultDeriveOptions) (eitherDecode bs)
  where
    decodeFailed err = fail $ "Unable to generate Avro from bytestring: " <> err
-- | Read the schema stored at the given path and lift it into an
-- expression.
makeSchema :: FilePath -> Q Exp
makeSchema p = lift =<< readSchema p
-- | Decode a schema from a lazy bytestring holding its JSON form and lift
-- it into an expression. A decoding failure aborts the splice via 'fail'.
makeSchemaFromByteString :: LBS.ByteString -> Q Exp
makeSchemaFromByteString bs = either decodeFailed lift (eitherDecode @Schema bs)
  where
    decodeFailed err = fail $ "Unable to generate Avro Schema from bytestring: " <> err
-- | Read the schema stored at the given path, look up the named
-- sub-definition in it with 'subdefinition', and lift that into an
-- expression. Fails the splice when no such definition is found.
makeSchemaFrom :: FilePath -> Text -> Q Exp
makeSchemaFrom p name =
  readSchema p >>= \s -> maybe missing lift (subdefinition s name)
  where
    missing = fail $ "No such entity '" <> T.unpack name <> "' defined in " <> p
-- | Load a schema from a file inside the 'Q' monad.
--
-- The path is first registered with 'qAddDependentFile'. A decoding failure
-- aborts the splice via 'fail'.
readSchema :: FilePath -> Q Schema
readSchema p = do
  qAddDependentFile p
  runIO (decodeSchema p) >>= either decodeFailed pure
  where
    decodeFailed err = fail $ "Unable to generate AVRO for " <> p <> ": " <> err
-- | Build the 'Left' error reported when a value does not have the shape
-- expected for the entity named by the second argument.
badValueNew :: Show v => v -> String -> Either String a
badValueNew v t =
  Left (concat ["Unexpected value for '", t, "': ", show v])
-- | Generate the 'AV.FromAvro' instance for a named schema (enum, record or
-- fixed). Any other schema yields no declarations.
--
-- In every generated instance, a value that does not match the expected
-- shape is turned into a 'Left' by 'badValueNew', tagged with the rendered
-- full name of the type.
genFromValue :: NamespaceBehavior -> Schema -> Q [Dec]
-- Enum: the index carried by 'AV.Enum' is converted with 'toEnum'.
genFromValue namespaceBehavior (S.Enum n _ _ _ ) =
[d| instance AV.FromAvro $(conT $ mkDataTypeName namespaceBehavior n) where
fromAvro (AV.Enum _ i _) = $([| pure . toEnum|]) i
fromAvro value = $( [|\v -> badValueNew v $(mkTextLit $ S.renderFullname n)|] ) value
|]
-- Record: decoding of the field values is delegated to the expression built
-- by 'genFromAvroNewFieldsExp'.
genFromValue namespaceBehavior (S.Record n _ _ fs) =
[d| instance AV.FromAvro $(conT $ mkDataTypeName namespaceBehavior n) where
fromAvro (AV.Record _ r) =
$(genFromAvroNewFieldsExp (mkDataTypeName namespaceBehavior n) fs) r
fromAvro value = $( [|\v -> badValueNew v $(mkTextLit $ S.renderFullname n)|] ) value
|]
-- Fixed: the payload is accepted only when its length equals the size
-- declared in the schema; otherwise the guard fails and the fallback
-- equation reports the value as unexpected.
genFromValue namespaceBehavior (S.Fixed n _ s _) =
[d| instance AV.FromAvro $(conT $ mkDataTypeName namespaceBehavior n) where
fromAvro (AV.Fixed _ v)
| BS.length v == s = pure $ $(conE (mkDataTypeName namespaceBehavior n)) v
fromAvro value = $( [|\v -> badValueNew v $(mkTextLit $ S.renderFullname n)|] ) value
|]
-- Schemas that do not introduce a named type get no instance.
genFromValue _ _ = pure []
-- | Build a lambda that takes the vector of a record's field values and
-- applies the given constructor to them applicatively: the generated
-- expression has the shape
-- @pure Con \<*\> fromAvro (v ! 0) \<*\> fromAvro (v ! 1) ...@, with one
-- positional lookup per entry of the field list.
genFromAvroNewFieldsExp :: Name -> [Field] -> Q Exp
genFromAvroNewFieldsExp n xs =
  [| \values ->
      $( let start       = [| pure $(conE n) |]
             step acc ix = [| $acc <*> AV.fromAvro (values V.! ix) |]
             positions   = take (length xs) [(0 :: Int) ..]
         in foldl step start positions
       )
   |]
-- | Generate the top-level schema value for a named schema (via
-- 'schemaDef') followed by a 'HasAvroSchema' instance whose @schema@
-- refers to that value.
genHasAvroSchema :: NamespaceBehavior -> Schema -> Q [Dec]
genHasAvroSchema namespaceBehavior s =
  (<>) <$> schemaDef valueName s <*> instanceDecs
  where
    typeName  = name s
    valueName = mkSchemaValueName namespaceBehavior typeName
    instanceDecs =
      [d| instance HasAvroSchema $(conT $ mkDataTypeName namespaceBehavior typeName) where
            schema = pure $(varE valueName)
        |]
-- | Create @n@ fresh names whose base is the given prefix followed by the
-- 1-based position (@base1@, @base2@, ...).
newNames :: String
         -> Int
         -> Q [Name]
newNames base n = traverse fresh [1 .. n]
  where
    fresh i = newName (base ++ show i)
-- | Generate the 'ToAvro' instance for a named schema (enum, record or
-- fixed). Any other schema yields no declarations.
genToAvro :: DeriveOptions -> Schema -> Q [Dec]
-- Enum: the generated 'toAvro' ignores its first argument and writes
-- 'fromEnum' of the value with 'putI'. The schema value name passed to the
-- local helper is not referenced by the generated instance.
genToAvro opts s@(S.Enum n _ _ _) =
encodeAvroInstance (mkSchemaValueName (namespaceBehavior opts) n)
where
encodeAvroInstance sname =
[d| instance ToAvro $(conT $ mkDataTypeName (namespaceBehavior opts) n) where
toAvro = $([| \_ x -> putI (fromEnum x) |])
|]
-- Record: the generated 'toAvro' is a lambda that ignores its first
-- argument, pattern-matches the record constructor against fresh variables
-- (one per field), and concatenates the encoding of every field, each
-- encoded with 'toAvro' against that field's own schema type.
genToAvro opts s@(S.Record n _ _ fs) =
encodeAvroInstance (mkSchemaValueName (namespaceBehavior opts) n)
where
encodeAvroInstance sname =
[d| instance ToAvro $(conT $ mkDataTypeName (namespaceBehavior opts) n) where
toAvro = $(encodeAvroFieldsExp sname)
|]
encodeAvroFieldsExp sname = do
names <- newNames "p_" (length fs)
wn <- varP <$> newName "_"
let con = conP (mkDataTypeName (namespaceBehavior opts) n) (varP <$> names)
lamE [wn, con]
[| mconcat $( let build (fld, n) = [| toAvro (fldType fld) $(varE n) |]
in listE $ build <$> (zip fs names)
)
|]
-- Fixed: the generated 'toAvro' ignores its first argument, unwraps the
-- constructor, and encodes the wrapped value against the generated
-- top-level schema value of the type.
genToAvro opts s@(S.Fixed n _ _ _) =
encodeAvroInstance (mkSchemaValueName (namespaceBehavior opts) n)
where
encodeAvroInstance sname =
[d| instance ToAvro $(conT $ mkDataTypeName (namespaceBehavior opts) n) where
toAvro = $(do
x <- newName "x"
wc <- newName "_"
lamE [varP wc, conP (mkDataTypeName (namespaceBehavior opts) n) [varP x]] [| toAvro $(varE sname) $(varE x) |])
|]
-- Schemas that do not introduce a named type get no instance.
genToAvro _ _ = pure []
-- | Generate a top-level @Schema@ value, bound to the given name, whose
-- definition is the given schema embedded in the quoted declaration.
--
-- The declarations are quoted with the placeholder binder @x@, which
-- 'setName' then replaces with the requested name in both the signature
-- and the binding.
schemaDef :: Name -> Schema -> Q [Dec]
schemaDef sname sch = setName sname $
[d|
x :: Schema
x = sch
|]
-- | Rename the binder of every type signature and every simple variable
-- binding in the produced declarations to the given name. Declarations of
-- any other shape are left untouched.
setName :: Name -> Q [Dec] -> Q [Dec]
setName newBinder decs = map rename <$> decs
  where
    rename dec =
      case dec of
        SigD _ ty              -> SigD newBinder ty
        ValD (VarP _) body wh  -> ValD (VarP newBinder) body wh
        other                  -> other
-- | Generate the Haskell type declaration for a named schema:
--
-- * a record becomes a single-constructor data type with one field per
--   schema field (see 'mkField');
-- * an enum becomes a data type with one nullary constructor per symbol,
--   named via 'mkAdtCtorName';
-- * a fixed becomes a newtype (see 'genNewtype').
--
-- Any other schema yields no declarations.
genType :: DeriveOptions -> Schema -> Q [Dec]
genType opts sch =
  case sch of
    S.Record n _ _ fs -> do
      flds <- traverse (mkField opts n) fs
      singleton <$> genDataType (typeNameFor n) flds
    S.Enum n _ _ vs ->
      singleton <$> genEnum (typeNameFor n) (ctorNameFor n <$> V.toList vs)
    S.Fixed n _ _ _ ->
      singleton <$> genNewtype (typeNameFor n)
    _ ->
      pure []
  where
    nsBehavior   = namespaceBehavior opts
    typeNameFor  = mkDataTypeName nsBehavior
    ctorNameFor  = mkAdtCtorName nsBehavior
    singleton d  = [d]
-- | Translate an Avro schema into the Haskell type used for a generated
-- record field.
--
-- * Primitive schemas map to the corresponding Haskell types; @long@ and
--   @int@ schemas carrying a logical type map to 'Decimal', 'DiffTime',
--   'UTCTime' or 'Day' as listed below.
-- * Named schemas (record, enum, fixed, or a reference by name) map to the
--   type constructor produced by 'mkDataTypeName'.
-- * Maps become @Map Text a@ and arrays become @[a]@.
-- * Unions: a single branch becomes 'Identity'; a two-branch union with
--   @null@ becomes 'Maybe'; other unions of two to ten branches become
--   'Either' / @EitherN@.
--
-- Unsupported schemas, empty unions and unions with more than ten branches
-- abort the splice through 'fail', so the problem is reported as an
-- ordinary Template Haskell error (consistent with the other failure paths
-- in this module) instead of an exception thrown from 'error'.
mkFieldTypeName :: NamespaceBehavior -> S.Schema -> Q TH.Type
mkFieldTypeName namespaceBehavior = \case
    S.Boolean                   -> [t| Bool |]
    S.Long (Just (DecimalL (Decimal p s)))
                                -> [t| Decimal $(litT $ numTyLit p) $(litT $ numTyLit s) |]
    S.Long (Just TimeMicros)
                                -> [t| DiffTime |]
    S.Long (Just TimestampMicros)
                                -> [t| UTCTime |]
    S.Long (Just TimestampMillis)
                                -> [t| UTCTime |]
    S.Long _                    -> [t| Int64 |]
    S.Int (Just Date)           -> [t| Day |]
    S.Int (Just TimeMillis)
                                -> [t| DiffTime |]
    S.Int _                     -> [t| Int32 |]
    S.Float                     -> [t| Float |]
    S.Double                    -> [t| Double |]
    S.Bytes _                   -> [t| ByteString |]
    S.String Nothing            -> [t| Text |]
    S.String (Just UUID)        -> [t| UUID |]
    S.Union branches            -> union (Foldable.toList branches)
    S.Record n _ _ _            -> [t| $(conT $ mkDataTypeName namespaceBehavior n) |]
    S.Map x                     -> [t| Map Text $(go x) |]
    S.Array x                   -> [t| [$(go x)] |]
    S.NamedType n               -> [t| $(conT $ mkDataTypeName namespaceBehavior n)|]
    S.Fixed n _ _ _             -> [t| $(conT $ mkDataTypeName namespaceBehavior n)|]
    S.Enum n _ _ _              -> [t| $(conT $ mkDataTypeName namespaceBehavior n)|]
    -- Report through the Q monad rather than throwing from pure code.
    t                           -> fail $ "Avro type is not supported: " <> show t
  where
    go = mkFieldTypeName namespaceBehavior
    union = \case
      []              ->
        fail "Empty union types are not supported"
      [x]             -> [t| Identity $(go x) |]
      [Null, x]       -> [t| Maybe $(go x) |]
      [x, Null]       -> [t| Maybe $(go x) |]
      [x, y]          -> [t| Either $(go x) $(go y) |]
      [a, b, c]       -> [t| Either3 $(go a) $(go b) $(go c) |]
      [a, b, c, d]    -> [t| Either4 $(go a) $(go b) $(go c) $(go d) |]
      [a, b, c, d, e] -> [t| Either5 $(go a) $(go b) $(go c) $(go d) $(go e) |]
      [a, b, c, d, e, f] -> [t| Either6 $(go a) $(go b) $(go c) $(go d) $(go e) $(go f) |]
      [a, b, c, d, e, f, g] -> [t| Either7 $(go a) $(go b) $(go c) $(go d) $(go e) $(go f) $(go g)|]
      [a, b, c, d, e, f, g, h] -> [t| Either8 $(go a) $(go b) $(go c) $(go d) $(go e) $(go f) $(go g) $(go h)|]
      [a, b, c, d, e, f, g, h, i] -> [t| Either9 $(go a) $(go b) $(go c) $(go d) $(go e) $(go f) $(go g) $(go h) $(go i)|]
      [a, b, c, d, e, f, g, h, i, j] -> [t| Either10 $(go a) $(go b) $(go c) $(go d) $(go e) $(go f) $(go g) $(go h) $(go i) $(go j)|]
      ls              ->
        fail $ "Unions with more than 10 elements are not yet supported: Union has " <> (show . length) ls <> " elements"
-- | Apply a transformation to the first character of a text (passed to the
-- transformation as a text of length at most one) and leave the remainder
-- unchanged.
updateFirst :: (Text -> Text) -> Text -> Text
updateFirst f t = T.append (f firstChar) remainder
  where
    firstChar = T.take 1 t
    remainder = T.drop 1 t
-- | Read a file and decode its contents as a JSON schema, returning the
-- decoder's error message on failure.
decodeSchema :: FilePath -> IO (Either String Schema)
decodeSchema p = do
  contents <- LBS.readFile p
  pure (eitherDecode contents)
-- | Name of the constructor generated for an enum symbol: the data type
-- name of the enum followed by the symbol rendered as a type-style name.
mkAdtCtorName :: NamespaceBehavior -> TypeName -> Text -> Name
mkAdtCtorName namespaceBehavior prefix nm = concatNames typePart symbolPart
  where
    typePart   = mkDataTypeName namespaceBehavior prefix
    symbolPart = mkDataTypeName' nm
-- | Build a new name from the base names of two names, concatenated in
-- order.
concatNames :: Name -> Name -> Name
concatNames a b = mkName (nameBase a ++ nameBase b)
-- | Keep only the characters allowed in generated identifiers:
-- alphanumerics, the apostrophe and the underscore. Everything else is
-- removed.
sanitiseName :: Text -> Text
sanitiseName = T.filter allowed
  where
    allowed c = c == '_' || c == '\'' || isAlphaNum c
-- | Render a type name to text according to the namespace behaviour:
--
-- * 'IgnoreNamespaces': the base name only;
-- * 'HandleNamespaces': the namespace components followed by the base
--   name, joined with @'@;
-- * 'Custom': whatever the supplied function returns for the base name and
--   the namespace components.
renderName :: NamespaceBehavior
           -> TypeName
           -> Text
renderName IgnoreNamespaces (TN name _)         = name
renderName HandleNamespaces (TN name namespace) = Text.intercalate "'" (namespace <> [name])
renderName (Custom f)       (TN name namespace) = f name namespace
-- | Name of the generated top-level schema value for a type: the rendered
-- type name prefixed with @schema'@.
mkSchemaValueName :: NamespaceBehavior -> TypeName -> Name
mkSchemaValueName namespaceBehavior typeName = mkTextName (T.append "schema'" rendered)
  where
    rendered = renderName namespaceBehavior typeName
-- | Name of the generated Haskell data type for an Avro type name: the
-- name is rendered with 'renderName' and then normalised by
-- "mkDataTypeName'".
mkDataTypeName :: NamespaceBehavior -> TypeName -> Name
mkDataTypeName namespaceBehavior typeName =
  mkDataTypeName' (renderName namespaceBehavior typeName)
-- | Turn a text into a type-style name: keep only the part after the last
-- @.@, upper-case its first character, and drop the characters that
-- 'sanitiseName' rejects.
mkDataTypeName' :: Text -> Name
mkDataTypeName' fullName = mkTextName (sanitiseName capitalised)
  where
    lastSegment = T.takeWhileEnd (/= '.') fullName
    capitalised = updateFirst T.toUpper lastSegment
-- | Build the Template Haskell description of one generated record field:
-- its name (from 'fieldNameBuilder', given the rendered record name as the
-- prefix), its strictness annotation (from 'fieldRepresentation') and its
-- type (from 'mkFieldTypeName').
mkField :: DeriveOptions -> TypeName -> Field -> Q VarStrictType
mkField opts typeName field =
  assemble <$> mkFieldTypeName nsBehavior (fldType field)
  where
    nsBehavior  = namespaceBehavior opts
    haskellName = mkTextName (fieldNameBuilder opts (renderName nsBehavior typeName) field)
    bang =
      case fieldRepresentation opts typeName field of
        (StrictField, unpackedness) -> strict unpackedness
        (LazyField, _)              -> notStrict
    assemble ty = (haskellName, bang, ty)
-- | Build the newtype declaration used for an Avro @fixed@ type: a record
-- newtype over 'ByteString' whose accessor is the type name prefixed with
-- @un@, deriving 'Eq', 'Show' and 'Generic'. The three branches only
-- differ in the shape of the declaration constructor across
-- template-haskell versions.
genNewtype :: Name -> Q Dec
#if MIN_VERSION_template_haskell(2,12,0)
genNewtype dn = do
  derived <- traverse conT [''Eq, ''Show, ''Generic]
  payload <- [t|ByteString|]
  let accessor = mkName ("un" ++ nameBase dn)
      wrapper  = RecC dn [(accessor, notStrict, payload)]
  pure $ NewtypeD [] dn [] Nothing wrapper [DerivClause Nothing derived]
#elif MIN_VERSION_template_haskell(2,11,0)
genNewtype dn = do
  derived <- traverse conT [''Eq, ''Show, ''Generic]
  payload <- [t|ByteString|]
  let accessor = mkName ("un" ++ nameBase dn)
      wrapper  = RecC dn [(accessor, notStrict, payload)]
  pure $ NewtypeD [] dn [] Nothing wrapper derived
#else
genNewtype dn = do
  payload <- [t|ByteString|]
  let accessor = mkName ("un" ++ nameBase dn)
      wrapper  = RecC dn [(accessor, notStrict, payload)]
  pure $ NewtypeD [] dn [] wrapper [''Eq, ''Show, ''Generic]
#endif
-- | Build the data declaration used for an Avro @enum@: one nullary
-- constructor per given name, deriving 'Eq', 'Show', 'Ord', 'Enum',
-- 'Bounded' and 'Generic'. The three branches only differ in the shape of
-- the declaration constructor across template-haskell versions.
genEnum :: Name -> [Name] -> Q Dec
#if MIN_VERSION_template_haskell(2,12,0)
genEnum dn vs = do
  ders <- sequenceA [[t|Eq|], [t|Show|], [t|Ord|], [t|Enum|], [t|Bounded|], [t|Generic|]]
  pure $ DataD [] dn [] Nothing ((\n -> NormalC n []) <$> vs) [DerivClause Nothing ders]
#elif MIN_VERSION_template_haskell(2,11,0)
genEnum dn vs = do
  ders <- sequenceA [[t|Eq|], [t|Show|], [t|Ord|], [t|Enum|], [t|Bounded|], [t|Generic|]]
  pure $ DataD [] dn [] Nothing ((\n -> NormalC n []) <$> vs) ders
#else
genEnum dn vs = do
  -- The pattern must bind all six quoted classes: binding only five against
  -- the six-element list makes the match fail on every call and would drop
  -- 'Bounded' from the derived classes.
  [ConT eq, ConT sh, ConT ord', ConT en, ConT bo, ConT gen] <- sequenceA [[t|Eq|], [t|Show|], [t|Ord|], [t|Enum|], [t|Bounded|], [t|Generic|]]
  pure $ DataD [] dn [] ((\n -> NormalC n []) <$> vs) [eq, sh, ord', en, bo, gen]
#endif
-- | Build the data declaration used for an Avro @record@: a single record
-- constructor, named like the type, carrying the given fields and deriving
-- 'Eq', 'Show' and 'Generic'. The three branches only differ in the shape
-- of the declaration constructor across template-haskell versions.
genDataType :: Name -> [VarStrictType] -> Q Dec
#if MIN_VERSION_template_haskell(2,12,0)
genDataType dn flds = do
  derived <- traverse conT [''Eq, ''Show, ''Generic]
  let recordCtor = RecC dn flds
  pure $ DataD [] dn [] Nothing [recordCtor] [DerivClause Nothing derived]
#elif MIN_VERSION_template_haskell(2,11,0)
genDataType dn flds = do
  derived <- traverse conT [''Eq, ''Show, ''Generic]
  let recordCtor = RecC dn flds
  pure $ DataD [] dn [] Nothing [recordCtor] derived
#else
genDataType dn flds =
  pure $ DataD [] dn [] [RecC dn flds] [''Eq, ''Show, ''Generic]
#endif
-- | Strictness annotation used for lazy generated fields: no unpack pragma
-- and no strictness mark. The definition depends on the template-haskell
-- version because the representation of strictness annotations changed in
-- 2.11.
notStrict :: Strict
#if MIN_VERSION_template_haskell(2,11,0)
notStrict = Bang SourceNoUnpack NoSourceStrictness
#else
notStrict = NotStrict
#endif
-- | Strictness annotation used for strict generated fields, with or
-- without an unpack pragma depending on the requested unpackedness. The
-- alternatives depend on the template-haskell version because the
-- representation of strictness annotations changed in 2.11.
strict :: FieldUnpackedness -> Strict
strict unpackedness =
  case unpackedness of
#if MIN_VERSION_template_haskell(2,11,0)
    UnpackedField    -> Bang SourceUnpack SourceStrict
    NonUnpackedField -> Bang SourceNoUnpack SourceStrict
#else
    UnpackedField    -> Unpacked
    NonUnpackedField -> IsStrict
#endif
-- | Make a Template Haskell name from a text.
mkTextName :: Text -> Name
mkTextName txt = mkName (T.unpack txt)
-- | Build a string-literal expression from a 'String'.
mkLit :: String -> ExpQ
mkLit str = litE (StringL str)
-- | Build a string-literal expression from a 'Text'.
mkTextLit :: Text -> ExpQ
mkTextLit txt = litE (StringL (T.unpack txt))