{-# LANGUAGE OverloadedLists #-}

module FlatBuffers.Internal.Compiler.TH where

import Control.Monad (join)
import Control.Monad.Except (runExceptT)

import Data.Bits ((.&.))
import Data.Foldable (traverse_)
import Data.Functor ((<&>))
import Data.Int
import Data.List qualified as List
import Data.List.NonEmpty (NonEmpty(..))
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as Map
import Data.Text (Text)
import Data.Text qualified as T
import Data.Word

import FlatBuffers.Internal.Build
import FlatBuffers.Internal.Compiler.NamingConventions qualified as NC
import FlatBuffers.Internal.Compiler.ParserIO qualified as ParserIO
import FlatBuffers.Internal.Compiler.SemanticAnalysis (SymbolTable(..))
import FlatBuffers.Internal.Compiler.SemanticAnalysis qualified as SemanticAnalysis
import FlatBuffers.Internal.Compiler.SyntaxTree qualified as SyntaxTree
import FlatBuffers.Internal.Compiler.ValidSyntaxTree
import FlatBuffers.Internal.FileIdentifier (HasFileIdentifier(..), unsafeFileIdentifier)
import FlatBuffers.Internal.Read
import FlatBuffers.Internal.Types
import FlatBuffers.Internal.Write

import Language.Haskell.TH
import Language.Haskell.TH.Syntax (lift)
import Language.Haskell.TH.Syntax qualified as TH

-- | Helper method to create function types.
-- @ConT ''Int ~> ConT ''String === Int -> String@
(~>) :: Type -> Type -> Type
a ~> b = ArrowT `AppT` a `AppT` b
infixr 1 ~>

-- | Options to control how\/which flatbuffers constructors\/accessor should be generated.
--
-- Options can be set using record syntax on `defaultOptions` with the fields below.
--
-- > defaultOptions { compileAllSchemas = True }
data Options = Options
  { -- | Directories to search for @include@s (same as flatc @-I@ option).
    includeDirectories :: [FilePath]
    -- | Generate code not just for the root schema,
    -- but for all schemas it includes as well
    -- (same as flatc @--gen-all@ option).
  , compileAllSchemas :: Bool
  }
  deriving (Show, Eq)

-- | Default flatbuffers options:
--
-- > Options
-- >   { includeDirectories = []
-- >   , compileAllSchemas = False
-- >   }
defaultOptions :: Options
defaultOptions = Options
  { includeDirectories = []
  , compileAllSchemas = False
  }

-- | Generates constructors and accessors for all data types declared in the given flatbuffers
-- schema whose namespace matches the current module.
--
-- > namespace Data.Game;
-- >
-- > table Monster {}
--
-- > {-# LANGUAGE TemplateHaskell #-}
-- >
-- > module Data.Game where
-- > import FlatBuffers
-- >
-- > $(mkFlatBuffers "schemas/game.fbs" defaultOptions)
--
-- Parse and validation failures abort the splice via 'fail', with the message
-- passed through 'fixMsg'. The root schema and every included file are
-- registered with 'TH.addDependentFile'.
mkFlatBuffers :: FilePath -> Options -> Q [Dec]
mkFlatBuffers rootFilePath opts = do
  currentModule <- T.pack . loc_module <$> location

  parseResult <- runIO $ runExceptT $ ParserIO.parseSchemas rootFilePath (includeDirectories opts)

  schemaFileTree <- either (fail . fixMsg) pure parseResult

  registerFiles schemaFileTree

  symbolTables <- either (fail . fixMsg) pure $ SemanticAnalysis.validateSchemas schemaFileTree

  -- With 'compileAllSchemas', the declarations of the included schemas are
  -- merged into those of the root schema; otherwise only the root is used.
  let symbolTable =
        if compileAllSchemas opts
          then SyntaxTree.fileTreeRoot symbolTables
                <> mconcat (Map.elems $ SyntaxTree.fileTreeForest symbolTables)
          else SyntaxTree.fileTreeRoot symbolTables

  let symbolTable' = filterByCurrentModule currentModule symbolTable

  compileSymbolTable symbolTable'

  where
    registerFiles (SyntaxTree.FileTree rootPath _ includedFiles) = do
      TH.addDependentFile rootPath
      traverse_ TH.addDependentFile $ Map.keys includedFiles

    -- Keeps only the declarations whose namespace equals the given module name.
    filterByCurrentModule modName (SymbolTable enums structs tables unions) =
      SymbolTable
        { allEnums   = Map.filterWithKey (isCurrentModule modName) enums
        , allStructs = Map.filterWithKey (isCurrentModule modName) structs
        , allTables  = Map.filterWithKey (isCurrentModule modName) tables
        , allUnions  = Map.filterWithKey (isCurrentModule modName) unions
        }

    isCurrentModule modName (ns, _) _ = NC.namespace ns == modName

-- | This does two things:
--
-- 1. ghcid stops parsing an error when it finds a line that starts with an alphabetical
--    character or an empty line, so we prepend each line with a space to avoid this.
-- 2. We also remove any trailing \n, otherwise ghcid would stop parsing here and not show
--    the source code location.
fixMsg :: String -> String
fixMsg = List.intercalate "\n" . fmap fixLine . lines
  where
    fixLine line = " " <> line

-- | Generates the declarations for every enum, struct, table and union in the
-- given symbol table, concatenated in that order.
compileSymbolTable :: SemanticAnalysis.ValidDecls -> Q [Dec]
compileSymbolTable symbolTable = do
  enumDecs <- join <$> traverse mkEnum (Map.elems (allEnums symbolTable))
  structDecs <- join <$> traverse mkStruct (Map.elems (allStructs symbolTable))
  tableDecs <- join <$> traverse mkTable (Map.elems (allTables symbolTable))
  unionDecs <- join <$> traverse mkUnion (Map.elems (allUnions symbolTable))
  pure $ enumDecs <> structDecs <> tableDecs <> unionDecs

-- | Generates the declarations for an enum, dispatching on whether it was
-- declared as bit flags.
mkEnum :: EnumDecl -> Q [Dec]
mkEnum enum =
  if enumBitFlags enum
    then mkEnumBitFlags enum
    else mkEnumNormal enum

-- | Generates the declarations for a bit-flags enum: one constant per enum
-- value, a list with all the values, and a function returning the names of
-- the flags set in a given value.
mkEnumBitFlags :: EnumDecl -> Q [Dec]
mkEnumBitFlags enum = do
  nameFun <- mkEnumBitFlagsNames enum enumValNames
  pure $
    mkEnumBitFlagsConstants enum enumValNames
    <> mkEnumBitFlagsAllValls enum enumValNames
    <> nameFun
  where
    enumValNames = mkName . T.unpack . NC.enumBitFlagsConstant enum <$> NE.toList (enumVals enum)

-- | Generates, for each enum value, a type signature and a constant whose body
-- is the value's integer literal. The names are paired positionally with the
-- enum's values.
mkEnumBitFlagsConstants :: EnumDecl -> [Name] -> [Dec]
mkEnumBitFlagsConstants enum enumValNames =
  NE.toList (enumVals enum) `zip` enumValNames >>= \(enumVal, enumValName) ->
    let sig = SigD enumValName (enumTypeToType (enumType enum))
        fun = FunD enumValName [Clause [] (NormalB (intLitE (enumValInt enumVal))) []]
    in  [sig, fun]

-- | Generates a list with all the enum values, e.g.
--
-- > allColors = [colorsRed, colorsGreen, colorsBlue]
mkEnumBitFlagsAllValls :: EnumDecl -> [Name] -> [Dec]
mkEnumBitFlagsAllValls enum enumValNames =
  let name = mkName $ T.unpack $ NC.enumBitFlagsAllFun enum
      sig = SigD name (ListT `AppT` enumTypeToType (enumType enum))
      fun = FunD name [ Clause [] (NormalB body) []]
      body = ListE (VarE <$> enumValNames)
  in  [sig, fun, inlinePragma name]

-- | Generates @colorsNames@.
--
-- The generated function takes a value of the enum's underlying type and
-- returns the identifiers (as 'Text') of every flag whose bits are set in it.
-- It is built as a chain of local bindings @res0@, @res1@, ...: @res0@ is the
-- empty list, and each following binding conses one identifier onto the
-- previous result when @flag .&. input /= 0@. The enum values are processed in
-- reverse, so the resulting list is in declaration order.
mkEnumBitFlagsNames :: EnumDecl -> [Name] -> Q [Dec]
mkEnumBitFlagsNames enum enumValNames = do
  inputName <- newName "c"
  firstRes <- newName "res0"
  firstClause <- [d| $(varP firstRes) = [] |]
  (clauses, lastRes) <- mkClauses namesAndIdentifiers 1 inputName firstRes firstClause

  let fun =
        FunD funName
          [ Clause
              [VarP inputName]
              (NormalB (VarE lastRes))
              (List.reverse clauses)
          ]

  pure
    [ sig
    , fun
    , inlinePragma funName
    ]
  where
    funName = mkName $ T.unpack $ NC.enumBitFlagsNamesFun enum
    sig = SigD funName (enumTypeToType (enumType enum) ~> ListT `AppT` ConT ''Text)

    namesAndIdentifiers :: [(Name, Ident)]
    namesAndIdentifiers = List.reverse (enumValNames `zip` fmap enumValIdent (NE.toList (enumVals enum)))

    -- Accumulates the local bindings newest-first and returns them together
    -- with the name of the last result binding.
    mkClauses :: [(Name, Ident)] -> Int -> Name -> Name -> [Dec] -> Q ([Dec], Name)
    mkClauses [] _ _ previousRes clauses = pure (clauses, previousRes)
    mkClauses ((name, Ident ident) : rest) ix inputName previousRes clauses = do
      res <- newName ("res" <> show ix)
      clause <-
        [d|
          $(varP res) =
            if $(varE name) .&. $(varE inputName) /= 0
              then $(pure (textLitE ident)) : $(varE previousRes)
              else $(varE previousRes)
          |]
      mkClauses rest (ix + 1) inputName res (clause <> clauses)

-- | Generates declarations for a non-bit-flags enum.
mkEnumNormal :: EnumDecl -> Q [Dec]
mkEnumNormal enum = do
  let enumName = mkName' $ NC.dataTypeName enum

  let enumValNames = enumVals enum <&> \enumVal ->
        mkName $ T.unpack $ NC.enumUnionMember enum enumVal

  let enumDec = mkEnumDataDec enumName enumValNames
  let enumValsAndNames = enumVals enum `NE.zip` enumValNames

  toEnumDecs <- mkToEnum enumName enum enumValsAndNames
  fromEnumDecs <- mkFromEnum enumName enum enumValsAndNames
  enumNameDecs <- mkEnumNameFun enumName enum enumValsAndNames

  -- Data declaration first, followed by the conversion and naming functions.
  pure $ enumDec : toEnumDecs <> fromEnumDecs <> enumNameDecs

-- | Generates the data type for an enum: one nullary constructor per given
-- name, deriving 'Eq', 'Show', 'Read', 'Ord' and 'Bounded'.
mkEnumDataDec :: Name -> NonEmpty Name -> Dec
mkEnumDataDec enumName enumValNames =
  DataD [] enumName [] Nothing
    (fmap (\n -> NormalC n []) (NE.toList enumValNames))
    [ DerivClause Nothing
        [ ConT ''Eq
        , ConT ''Show
        , ConT ''Read
        , ConT ''Ord
        , ConT ''Bounded
        ]
    ]

-- | Generates the function that converts a value of the enum's underlying
-- integral type into the enum data type.
--
-- The generated function matches its argument against each enum value's
-- integer literal and returns @Just@ the corresponding constructor, or
-- @Nothing@ when no value matches.
mkToEnum :: Name -> EnumDecl -> NonEmpty (EnumVal, Name) -> Q [Dec]
mkToEnum enumName enum enumValsAndNames = do
  let funName = mkName' $ NC.toEnumFun enum
  argName <- newName "n"
  pure
    [ SigD funName (enumTypeToType (enumType enum) ~> ConT ''Maybe `AppT` ConT enumName)
    , FunD funName
        [ Clause
            [VarP argName]
            (NormalB (CaseE (VarE argName) matches))
            []
        ]
    , inlinePragma funName
    ]
  where
    matches = (mkMatch <$> NE.toList enumValsAndNames) <> [matchWildcard]

    mkMatch (enumVal, conName) =
      Match
        (intLitP (enumValInt enumVal))
        (NormalB (ConE 'Just `AppE` ConE conName))
        []

    -- Fallback for integers that do not correspond to any enum value.
    matchWildcard =
      Match
        WildP
        (NormalB (ConE 'Nothing))
        []

-- | Generates the function that converts a value of the enum data type into
-- the enum's underlying integral type, with one case alternative per
-- constructor.
mkFromEnum :: Name -> EnumDecl -> NonEmpty (EnumVal, Name) -> Q [Dec]
mkFromEnum enumName enum enumValsAndNames = do
  let funName = mkName' $ NC.fromEnumFun enum
  argName <- newName "n"
  pure
    [ SigD funName (ConT enumName ~> enumTypeToType (enumType enum))
    , FunD funName
        [ Clause
            [VarP argName]
            (NormalB (CaseE (VarE argName) (mkMatch <$> NE.toList enumValsAndNames)))
            []
        ]
    , inlinePragma funName
    ]
  where
    mkMatch (enumVal, conName) =
      Match
        (ConP conName [] [])
        (NormalB (intLitE (enumValInt enumVal)))
        []

-- | Generates @colorsName@.
mkEnumNameFun :: Name -> EnumDecl -> NonEmpty (EnumVal, Name) -> Q [Dec]
mkEnumNameFun enumName enum enumValsAndNames = do
  let funName = mkName' $ NC.enumNameFun enum
  argName <- newName "c"
  pure
    [ SigD funName (ConT enumName ~> ConT ''Text)
    , FunD funName
        [ Clause
            [VarP argName]
            (NormalB (CaseE (VarE argName) (mkMatch <$> NE.toList enumValsAndNames)))
            []
        ]
    , inlinePragma funName
    ]
  where
    -- Maps each constructor to the identifier of its enum value, as 'Text'.
    mkMatch (enumVal, conName) =
      Match
        (ConP conName [] [])
        (NormalB (textLitE (unIdent (getIdent enumVal))))
        []

-- | Generates the declarations for a struct: an empty data type, its
-- @IsStruct@ instance, a constructor function and one getter per field.
mkStruct :: StructDecl -> Q [Dec]
mkStruct struct = do
  let structName = mkName' $ NC.dataTypeName struct
  isStructInstance <- mkIsStructInstance structName struct

  let dataDec = DataD [] structName [] Nothing [] []
  (consSig, cons) <- mkStructConstructor structName struct

  let getters = foldMap (mkStructFieldGetter structName struct) (structFields struct)

  pure $
    dataDec :
    isStructInstance
    <> [ consSig, cons ]
    <> getters

-- | Generates the @IsStruct@ instance for a struct, lifting the struct's
-- alignment and size into the instance methods.
mkIsStructInstance :: Name -> StructDecl -> Q [Dec]
mkIsStructInstance structName struct =
  [d|
    instance IsStruct $(conT structName) where
      structAlignmentOf = $(lift . unAlignment . structAlignment $ struct)
      structSizeOf = $(lift . unInlineSize . structSize $ struct)
    |]

-- | Generates the signature and definition of a struct's constructor function.
--
-- The function takes one argument per field and returns a @WriteStruct@. Its
-- body combines the per-field write expressions (and any padding) with '<>'.
mkStructConstructor :: Name -> StructDecl -> Q (Dec, Dec)
mkStructConstructor structName struct = do
  argsInfo <- traverse mkStructConstructorArg (structFields struct)
  let (argTypes, pats, exps) = nonEmptyUnzip3 argsInfo

  let retType = AppT (ConT ''WriteStruct) (ConT structName)
  let sigType = foldr (~>) retType argTypes

  let consName = mkName' $ NC.dataTypeConstructor struct
  let consSig = SigD consName sigType

  -- `join exps` is a `NonEmpty`, so this `foldr1` is total.
  let exp = foldr1 (\e acc -> InfixE (Just e) (VarE '(<>)) (Just acc)) (join exps)

  let body = NormalB $ ConE 'WriteStruct `AppE` exp

  let cons = FunD consName [ Clause (NE.toList pats) body [] ]

  pure (consSig, cons)

-- | For a single struct field, produces the argument type, the argument
-- pattern, and the expressions that write the field. A second expression,
-- writing padding, is added when the field's padding is non-zero.
mkStructConstructorArg :: StructField -> Q (Type, Pat, NonEmpty Exp)
mkStructConstructorArg sf = do
  argName <- newName' $ NC.arg sf
  let argPat = VarP argName
  let argRef = VarE argName
  let argType = structFieldTypeToWriteType (structFieldType sf)

  let mkWriteExp sft =
        case sft of
          SInt8 -> VarE 'buildInt8
          SInt16 -> VarE 'buildInt16
          SInt32 -> VarE 'buildInt32
          SInt64 -> VarE 'buildInt64
          SWord8 -> VarE 'buildWord8
          SWord16 -> VarE 'buildWord16
          SWord32 -> VarE 'buildWord32
          SWord64 -> VarE 'buildWord64
          SFloat -> VarE 'buildFloat
          SDouble -> VarE 'buildDouble
          SBool -> VarE 'buildBool
          -- Enums are written as their underlying integral type.
          SEnum _ enumType -> mkWriteExp (enumTypeToStructFieldType enumType)
          SStruct _ -> VarE 'buildStruct

  let exp = mkWriteExp (structFieldType sf) `AppE` argRef

  let exps =
        if structFieldPadding sf == 0
          then [ exp ]
          else
            [ exp
            , VarE 'buildPadding `AppE` intLitE (structFieldPadding sf)
            ]

  pure (argType, argPat, exps)

-- | Generates the signature and definition of a struct field getter.
--
-- Getters for nested structs return the nested @Struct@ directly; all other
-- getters return their result wrapped in @Either ReadError@.
mkStructFieldGetter :: Name -> StructDecl -> StructField -> [Dec]
mkStructFieldGetter structName struct sf =
  [sig, fun]
  where
    funName = mkName (T.unpack (NC.getter struct sf))
    fieldOffsetExp = intLitE (structFieldOffset sf)

    retType = structFieldTypeToReadType (structFieldType sf)

    sig =
      SigD funName $
        case structFieldType sf of
          SStruct _ ->
            ConT ''Struct `AppT` ConT structName ~> retType
          _ ->
            ConT ''Struct `AppT` ConT structName ~> ConT ''Either `AppT` ConT ''ReadError `AppT` retType

    fun = FunD funName [ Clause [] (NormalB body) [] ]

    body = app
      [ VarE 'readStructField
      , mkReadExp (structFieldType sf)
      , fieldOffsetExp
      ]

    mkReadExp sft =
      case sft of
        SInt8 -> VarE 'readInt8
        SInt16 -> VarE 'readInt16
        SInt32 -> VarE 'readInt32
        SInt64 -> VarE 'readInt64
        SWord8 -> VarE 'readWord8
        SWord16 -> VarE 'readWord16
        SWord32 -> VarE 'readWord32
        SWord64 -> VarE 'readWord64
        SFloat -> VarE 'readFloat
        SDouble -> VarE 'readDouble
        SBool -> VarE 'readBool
        -- Enums are read as their underlying integral type.
        SEnum _ enumType -> mkReadExp $ enumTypeToStructFieldType enumType
        SStruct _ -> VarE 'readStruct

-- | Generates the declarations for a table: an empty data type, a constructor
-- function, an optional @HasFileIdentifier@ instance, and one getter per
-- non-deprecated field.
mkTable :: TableDecl -> Q [Dec]
mkTable table = do
  let tableName = mkName' $ NC.dataTypeName table
  (consSig, cons) <- mkTableConstructor tableName table

  let fileIdentifierDec = mkTableFileIdentifier tableName (tableIsRoot table)
  let getters = foldMap (mkTableFieldGetter tableName table) (tableFields table)

  pure $
    [ DataD [] tableName [] Nothing [] []
    , consSig
    , cons
    ]
    <> fileIdentifierDec
    <> getters

-- | Generates a @HasFileIdentifier@ instance for a table, but only when the
-- table is a root table that declares a file identifier.
mkTableFileIdentifier :: Name -> IsRoot -> [Dec]
mkTableFileIdentifier tableName isRoot =
  case isRoot of
    NotRoot -> []
    IsRoot Nothing -> []
    IsRoot (Just fileIdentifier) ->
      [ InstanceD Nothing [] (ConT ''HasFileIdentifier `AppT` ConT tableName)
          [ FunD 'getFileIdentifier
              [ Clause [] (NormalB $ VarE 'unsafeFileIdentifier `AppE` textLitE fileIdentifier) []
              ]
          ]
      ]

-- | Generates the signature and definition of a table's constructor function.
--
-- The function takes one argument per non-deprecated field and returns a
-- @WriteTable@; its body applies @writeTable@ to the list of field expressions.
mkTableConstructor :: Name -> TableDecl -> Q (Dec, Dec)
mkTableConstructor tableName table = do
  (argTypes, pats, exps) <- mconcat <$> traverse mkTableContructorArg (tableFields table)

  let retType = AppT (ConT ''WriteTable) (ConT tableName)
  let sigType = foldr (~>) retType argTypes

  let consName = mkName' $ NC.dataTypeConstructor table
  let consSig = SigD consName sigType

  let body = NormalB $ AppE (VarE 'writeTable) (ListE exps)

  let cons = FunD consName [ Clause pats body [] ]

  pure (consSig, cons)

-- | For a single table field, produces the argument types, argument patterns
-- and field expressions that it contributes to the table constructor.
--
-- Deprecated fields contribute no argument, only @deprecated@ placeholder
-- expressions: two for unions and vectors of unions, one otherwise.
-- Non-deprecated unions and vectors of unions likewise contribute two
-- expressions for their single argument.
mkTableContructorArg :: TableField -> Q ([Type], [Pat], [Exp])
mkTableContructorArg tf =
  if tableFieldDeprecated tf
    then
      case tableFieldType tf of
        TUnion _ _ -> pure ([], [], [VarE 'deprecated, VarE 'deprecated])
        TVector _ (VUnion _) -> pure ([], [], [VarE 'deprecated, VarE 'deprecated])
        _ -> pure ([], [], [VarE 'deprecated])
    else do
      argName <- newName' $ NC.arg tf
      let argPat = VarP argName
      let argRef = VarE argName
      let argType = tableFieldTypeToWriteType (tableFieldType tf)
      let exps = mkExps argRef (tableFieldType tf)

      pure ([argType], [argPat], exps)
  where
    expForScalar :: Exp -> Exp -> Exp -> Exp
    expForScalar defaultValExp writeExp varExp =
      VarE 'optionalDef `AppE` defaultValExp `AppE` writeExp `AppE` varExp

    expForNonScalar :: Required -> Exp -> Exp -> Exp
    expForNonScalar Req exp argRef = exp `AppE` argRef
    expForNonScalar Opt exp argRef = VarE 'optional `AppE` exp `AppE` argRef

    mkExps :: Exp -> TableFieldType -> [Exp]
    mkExps argRef tfType =
      case tfType of
        TInt8 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeInt8TableField) argRef
        TInt16 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeInt16TableField) argRef
        TInt32 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeInt32TableField) argRef
        TInt64 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeInt64TableField) argRef
        TWord8 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeWord8TableField) argRef
        TWord16 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeWord16TableField) argRef
        TWord32 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeWord32TableField) argRef
        TWord64 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeWord64TableField) argRef
        TFloat (DefaultVal n) -> pure $ expForScalar (realLitE n) (VarE 'writeFloatTableField) argRef
        TDouble (DefaultVal n) -> pure $ expForScalar (realLitE n) (VarE 'writeDoubleTableField) argRef
        TBool (DefaultVal b) -> pure $ expForScalar (if b then ConE 'True else ConE 'False) (VarE 'writeBoolTableField) argRef
        TString req -> pure $ expForNonScalar req (VarE 'writeTextTableField) argRef
        -- Enums are written as their underlying integral type.
        TEnum _ enumType dflt -> mkExps argRef (enumTypeToTableFieldType enumType dflt)
        TStruct _ req -> pure $ expForNonScalar req (VarE 'writeStructTableField) argRef
        TTable _ req -> pure $ expForNonScalar req (VarE 'writeTableTableField) argRef
        TUnion _ req ->
          [ expForNonScalar req (VarE 'writeUnionTypeTableField) argRef
          , expForNonScalar req (VarE 'writeUnionValueTableField) argRef
          ]
        TVector req vecElemType -> mkExpForVector argRef req vecElemType

    mkExpForVector :: Exp -> Required -> VectorElementType -> [Exp]
    mkExpForVector argRef req vecElemType =
      case vecElemType of
        VInt8 -> [ expForNonScalar req (VarE 'writeVectorInt8TableField) argRef ]
        VInt16 -> [ expForNonScalar req (VarE 'writeVectorInt16TableField) argRef ]
        VInt32 -> [ expForNonScalar req (VarE 'writeVectorInt32TableField) argRef ]
        VInt64 -> [ expForNonScalar req (VarE 'writeVectorInt64TableField) argRef ]
        VWord8 -> [ expForNonScalar req (VarE 'writeVectorWord8TableField) argRef ]
        VWord16 -> [ expForNonScalar req (VarE 'writeVectorWord16TableField) argRef ]
        VWord32 -> [ expForNonScalar req (VarE 'writeVectorWord32TableField) argRef ]
        VWord64 -> [ expForNonScalar req (VarE 'writeVectorWord64TableField) argRef ]
        VFloat -> [ expForNonScalar req (VarE 'writeVectorFloatTableField) argRef ]
        VDouble -> [ expForNonScalar req (VarE 'writeVectorDoubleTableField) argRef ]
        VBool -> [ expForNonScalar req (VarE 'writeVectorBoolTableField) argRef ]
        VString -> [ expForNonScalar req (VarE 'writeVectorTextTableField) argRef ]
        VEnum _ enumType -> mkExpForVector argRef req (enumTypeToVectorElementType enumType)
        VStruct _ -> [ expForNonScalar req (VarE 'writeVectorStructTableField) argRef ]
        VTable _ -> [ expForNonScalar req (VarE 'writeVectorTableTableField) argRef ]
        VUnion _ ->
          [ expForNonScalar req (VarE 'writeUnionTypesVectorTableField) argRef
          , expForNonScalar req (VarE 'writeUnionValuesVectorTableField) argRef
          ]

-- | Generates the signature and definition of a table field getter, or nothing
-- if the field is deprecated.
--
-- Every getter returns its result wrapped in @Either ReadError@. Scalar
-- fields are read with their default value; for required non-scalar fields
-- the field's identifier is passed to the read function as well.
mkTableFieldGetter :: Name -> TableDecl -> TableField -> [Dec]
mkTableFieldGetter tableName table tf =
  if tableFieldDeprecated tf
    then []
    else [sig, mkFun (tableFieldType tf)]
  where
    funName = mkName (T.unpack (NC.getter table tf))
    fieldIndex = intLitE (tableFieldId tf)

    sig =
      SigD funName $
        ConT ''Table `AppT` ConT tableName ~> ConT ''Either `AppT` ConT ''ReadError `AppT` tableFieldTypeToReadType (tableFieldType tf)

    mkFun :: TableFieldType -> Dec
    mkFun tft =
      case tft of
        TWord8 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readWord8))
        TWord16 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readWord16))
        TWord32 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readWord32))
        TWord64 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readWord64))
        TInt8 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readInt8))
        TInt16 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readInt16))
        TInt32 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readInt32))
        TInt64 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readInt64))
        TFloat (DefaultVal n) -> mkFunWithBody (bodyForScalar (realLitE n) (VarE 'readFloat))
        TDouble (DefaultVal n) -> mkFunWithBody (bodyForScalar (realLitE n) (VarE 'readDouble))
        TBool (DefaultVal b) -> mkFunWithBody (bodyForScalar (if b then ConE 'True else ConE 'False) (VarE 'readBool))
        TString req -> mkFunWithBody (bodyForNonScalar req (VarE 'readText))
        -- Enums are read as their underlying integral type.
        TEnum _ enumType dflt -> mkFun $ enumTypeToTableFieldType enumType dflt
        TStruct _ req -> mkFunWithBody (bodyForNonScalar req (compose [ConE 'Right, VarE 'readStruct]))
        TTable _ req -> mkFunWithBody (bodyForNonScalar req (VarE 'readTable))
        TUnion (TypeRef ns ident) req -> do
          let readUnionFunName = VarE . mkName . T.unpack . NC.withModulePrefix ns $ NC.readUnionFun ident
          mkFunWithBody $ app
            case req of
              Req ->
                [ VarE 'readTableFieldUnionReq
                , readUnionFunName
                , fieldIndex
                , stringLitE . unIdent . getIdent $ tf
                ]
              Opt ->
                [ VarE 'readTableFieldUnionOpt
                , readUnionFunName
                , fieldIndex
                ]
        TVector req vecElemType -> mkFunForVector req vecElemType

    mkFunForVector :: Required -> VectorElementType -> Dec
    mkFunForVector req vecElemType =
      case vecElemType of
        VInt8 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorInt8
        VInt16 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorInt16
        VInt32 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorInt32
        VInt64 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorInt64
        VWord8 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorWord8
        VWord16 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorWord16
        VWord32 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorWord32
        VWord64 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorWord64
        VFloat -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorFloat
        VDouble -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorDouble
        VBool -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorBool
        VString -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorText
        VEnum _ enumType -> mkFunForVector req (enumTypeToVectorElementType enumType)
        VStruct _ -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorStruct
        VTable _ -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readTableVector
        VUnion (TypeRef ns ident) ->
          mkFunWithBody $
            case req of
              Opt ->
                app
                  [ VarE 'readTableFieldUnionVectorOpt
                  , VarE . mkName . T.unpack . NC.withModulePrefix ns $ NC.readUnionFun ident
                  , fieldIndex
                  ]
              Req ->
                app
                  [ VarE 'readTableFieldUnionVectorReq
                  , VarE . mkName . T.unpack . NC.withModulePrefix ns $ NC.readUnionFun ident
                  , fieldIndex
                  , stringLitE . unIdent . getIdent $ tf
                  ]

    mkFunWithBody :: Exp -> Dec
    mkFunWithBody body = FunD funName [ Clause [] (NormalB body) [] ]

    bodyForNonScalar req readExp =
      case req of
        Req ->
          app
            [ VarE 'readTableFieldReq
            , readExp
            , fieldIndex
            , stringLitE . unIdent . getIdent $ tf
            ]
        Opt ->
          app
            [ VarE 'readTableFieldOpt
            , readExp
            , fieldIndex
            ]

    bodyForScalar defaultValExp readExp =
      app
        [ VarE 'readTableFieldWithDef
        , readExp
        , fieldIndex
        , defaultValExp
        ]

-- | Generates the declarations for a union: its data type, one constructor
-- function per union member, and the function used to read the union.
mkUnion :: UnionDecl -> Q [Dec]
mkUnion union = do
  let unionName = mkName' $ NC.dataTypeName union
  let unionValNames = unionVals union <&> \unionVal ->
        mkName $ T.unpack $ NC.enumUnionMember union unionVal

  unionConstructors <- mkUnionConstructors unionName union
  readFun <- mkReadUnionFun unionName unionValNames union

  pure $
    mkUnionDataDec unionName (unionVals union `NE.zip` unionValNames)
    : unionConstructors
    <> readFun

-- | Generates the data type for a union: one constructor per member, each
-- holding a strict @Table@ of the member's table type.
mkUnionDataDec :: Name -> NonEmpty (UnionVal, Name) -> Dec
mkUnionDataDec unionName unionValsAndNames =
  DataD [] unionName [] Nothing
    (NE.toList $ fmap mkCons unionValsAndNames)
    []
  where
    mkCons (unionVal, unionValName) =
      NormalC unionValName [(bang, ConT ''Table `AppT` typeRefToType (unionValTableRef unionVal))]

    bang = Bang NoSourceUnpackedness SourceStrict

-- | Generates, for each union member, a function that turns a @WriteTable@ of
-- the member's table type into a @WriteUnion@ of this union. Members are
-- numbered from 1 in declaration order, and that number is passed to
-- @writeUnion@.
mkUnionConstructors :: Name -> UnionDecl -> Q [Dec]
mkUnionConstructors unionName union =
  fmap join . traverse mkUnionConstructor $ NE.toList (unionVals union) `zip` [1..]
  where
    mkUnionConstructor :: (UnionVal, Integer) -> Q [Dec]
    mkUnionConstructor (unionVal, ix) = do
      let constructorName = mkName' $ NC.unionConstructor union unionVal
      pure
        [ SigD constructorName $
            ConT ''WriteTable `AppT` typeRefToType (unionValTableRef unionVal)
              ~> ConT ''WriteUnion `AppT` ConT unionName
        , FunD constructorName
            [ Clause [] (NormalB $ VarE 'writeUnion `AppE` intLitE ix) []
            ]
        ]

-- | Generates the function that reads a union value.
--
-- The generated function takes the union type tag (a @Positive Word8@) and a
-- @PositionInfo@. It matches the tag against the members' numbers (starting at
-- 1) and reads the corresponding table, wrapping it in @Union@; any other tag
-- yields @UnionUnknown@ applied to that tag.
mkReadUnionFun :: Name -> NonEmpty Name -> UnionDecl -> Q [Dec]
mkReadUnionFun unionName unionValNames union = do
  nArg <- newName "n"
  posArg <- newName "pos"
  wildcard <- newName "n'"

  let funName = mkName $ T.unpack $ NC.readUnionFun union
  let sig =
        SigD funName $
          ConT ''Positive `AppT` ConT ''Word8
            ~> ConT ''PositionInfo
            ~> ConT ''Either `AppT` ConT ''ReadError `AppT` (ConT ''Union `AppT` ConT unionName)

  let mkMatch :: Name -> Integer -> Match
      mkMatch unionValName ix =
        Match
          (intLitP ix)
          (NormalB $
            InfixE
              (Just (compose [ConE 'Union, ConE unionValName]))
              (VarE '(<$>))
              (Just (VarE 'readTable' `AppE` VarE posArg))
          )
          []

  let matchWildcard =
        Match
          (VarP wildcard)
          (NormalB $
            InfixE
              (Just (VarE 'pure))
              (VarE '($!))
              (Just (ConE 'UnionUnknown `AppE` VarE wildcard))
          )
          []

  let matches = (uncurry mkMatch <$> NE.toList unionValNames `zip` [1..]) <> [matchWildcard]

  let funBody = NormalB $ CaseE (VarE 'getPositive `AppE` VarE nArg) matches

  let fun = FunD funName [ Clause [VarP nArg, VarP posArg] funBody [] ]

  pure [sig, fun]

-- | Maps an enum's underlying type to the corresponding Haskell integral type.
enumTypeToType :: EnumType -> Type
enumTypeToType et =
  case et of
    EInt8 -> ConT ''Int8
    EInt16 -> ConT ''Int16
    EInt32 -> ConT ''Int32
    EInt64 -> ConT ''Int64
    EWord8 -> ConT ''Word8
    EWord16 -> ConT ''Word16
    EWord32 -> ConT ''Word32
    EWord64 -> ConT ''Word64

-- | Maps an enum's underlying type, together with the field's default value,
-- to the integral table field type of the same width and signedness.
enumTypeToTableFieldType :: Integral a => EnumType -> DefaultVal a -> TableFieldType
enumTypeToTableFieldType et dflt =
  case et of
    EInt8 -> TInt8 (fromIntegral dflt)
    EInt16 -> TInt16 (fromIntegral dflt)
    EInt32 -> TInt32 (fromIntegral dflt)
    EInt64 -> TInt64 (fromIntegral dflt)
    EWord8 -> TWord8 (fromIntegral dflt)
    EWord16 -> TWord16 (fromIntegral dflt)
    EWord32 -> TWord32 (fromIntegral dflt)
    EWord64 -> TWord64 (fromIntegral dflt)

-- | Maps an enum's underlying type to the integral struct field type of the
-- same width and signedness.
enumTypeToStructFieldType :: EnumType -> StructFieldType
enumTypeToStructFieldType et =
  case et of
    EInt8 -> SInt8
    EInt16 -> SInt16
    EInt32 -> SInt32
    EInt64 -> SInt64
    EWord8 -> SWord8
    EWord16 -> SWord16
    EWord32 -> SWord32
    EWord64 -> SWord64

-- | Maps an enum's underlying type to the integral vector element type of the
-- same width and signedness.
enumTypeToVectorElementType :: EnumType -> VectorElementType
enumTypeToVectorElementType et =
  case et of
    EInt8 -> VInt8
    EInt16 -> VInt16
    EInt32 -> VInt32
    EInt64 -> VInt64
    EWord8 -> VWord8
    EWord16 -> VWord16
    EWord32 -> VWord32
    EWord64 -> VWord64

-- | The type of the constructor argument used to write a struct field.
-- Nested structs are taken as @WriteStruct@.
structFieldTypeToWriteType :: StructFieldType -> Type
structFieldTypeToWriteType sft =
  case sft of
    SInt8 -> ConT ''Int8
    SInt16 -> ConT ''Int16
    SInt32 -> ConT ''Int32
    SInt64 -> ConT ''Int64
    SWord8 -> ConT ''Word8
    SWord16 -> ConT ''Word16
    SWord32 -> ConT ''Word32
    SWord64 -> ConT ''Word64
    SFloat -> ConT ''Float
    SDouble -> ConT ''Double
    SBool -> ConT ''Bool
    SEnum _ enumType -> enumTypeToType enumType
    SStruct (namespace, structDecl) ->
      ConT ''WriteStruct `AppT` typeRefToType (TypeRef namespace (getIdent structDecl))

-- | The type returned when reading a struct field. Nested structs are
-- returned as @Struct@.
structFieldTypeToReadType :: StructFieldType -> Type
structFieldTypeToReadType sft =
  case sft of
    SInt8 -> ConT ''Int8
    SInt16 -> ConT ''Int16
    SInt32 -> ConT ''Int32
    SInt64 -> ConT ''Int64
    SWord8 -> ConT ''Word8
    SWord16 -> ConT ''Word16
    SWord32 -> ConT ''Word32
    SWord64 -> ConT ''Word64
    SFloat -> ConT ''Float
    SDouble -> ConT ''Double
    SBool -> ConT ''Bool
    SEnum _ enumType -> enumTypeToType enumType
    SStruct (namespace, structDecl) ->
      ConT ''Struct `AppT` typeRefToType (TypeRef namespace (getIdent structDecl))

-- | The type of the constructor argument used to write a table field.
-- Scalars and enums are always wrapped in 'Maybe'; other types are wrapped
-- in 'Maybe' only when the field is optional.
tableFieldTypeToWriteType :: TableFieldType -> Type
tableFieldTypeToWriteType tft =
  case tft of
    TInt8 _ -> ConT ''Maybe `AppT` ConT ''Int8
    TInt16 _ -> ConT ''Maybe `AppT` ConT ''Int16
    TInt32 _ -> ConT ''Maybe `AppT` ConT ''Int32
    TInt64 _ -> ConT ''Maybe `AppT` ConT ''Int64
    TWord8 _ -> ConT ''Maybe `AppT` ConT ''Word8
    TWord16 _ -> ConT ''Maybe `AppT` ConT ''Word16
    TWord32 _ -> ConT ''Maybe `AppT` ConT ''Word32
    TWord64 _ -> ConT ''Maybe `AppT` ConT ''Word64
    TFloat _ -> ConT ''Maybe `AppT` ConT ''Float
    TDouble _ -> ConT ''Maybe `AppT` ConT ''Double
    TBool _ -> ConT ''Maybe `AppT` ConT ''Bool
    TString req -> requiredType req (ConT ''Text)
    TEnum _ enumType _ -> ConT ''Maybe `AppT` enumTypeToType enumType
    TStruct typeRef req -> requiredType req (ConT ''WriteStruct `AppT` typeRefToType typeRef)
    TTable typeRef req -> requiredType req (ConT ''WriteTable `AppT` typeRefToType typeRef)
    TUnion typeRef req -> requiredType req (ConT ''WriteUnion `AppT` typeRefToType typeRef)
    TVector req vecElemType -> requiredType req (vectorElementTypeToWriteType vecElemType)

-- | The type returned when reading a table field. Scalars and enums are
-- returned bare; other types are wrapped in 'Maybe' when the field is
-- optional.
tableFieldTypeToReadType :: TableFieldType -> Type
tableFieldTypeToReadType tft =
  case tft of
    TInt8 _ -> ConT ''Int8
    TInt16 _ -> ConT ''Int16
    TInt32 _ -> ConT ''Int32
    TInt64 _ -> ConT ''Int64
    TWord8 _ -> ConT ''Word8
    TWord16 _ -> ConT ''Word16
    TWord32 _ -> ConT ''Word32
    TWord64 _ -> ConT ''Word64
    TFloat _ -> ConT ''Float
    TDouble _ -> ConT ''Double
    TBool _ -> ConT ''Bool
    TString req -> requiredType req (ConT ''Text)
    TEnum _ enumType _ -> enumTypeToType enumType
    TStruct typeRef req -> requiredType req (ConT ''Struct `AppT` typeRefToType typeRef)
    TTable typeRef req -> requiredType req (ConT ''Table `AppT` typeRefToType typeRef)
    TUnion typeRef req -> requiredType req (ConT ''Union `AppT` typeRefToType typeRef)
    TVector req vecElemType -> requiredType req (vectorElementTypeToReadType vecElemType)

-- | The @WriteVector@ type used to write a vector with the given element type.
vectorElementTypeToWriteType :: VectorElementType -> Type
vectorElementTypeToWriteType vet =
  case vet of
    VInt8 -> ConT ''WriteVector `AppT` ConT ''Int8
    VInt16 -> ConT ''WriteVector `AppT` ConT ''Int16
    VInt32 -> ConT ''WriteVector `AppT` ConT ''Int32
    VInt64 -> ConT ''WriteVector `AppT` ConT ''Int64
    VWord8 -> ConT ''WriteVector `AppT` ConT ''Word8
    VWord16 -> ConT ''WriteVector `AppT` ConT ''Word16
    VWord32 -> ConT ''WriteVector `AppT` ConT ''Word32
    VWord64 -> ConT ''WriteVector `AppT` ConT ''Word64
    VFloat -> ConT ''WriteVector `AppT` ConT ''Float
    VDouble -> ConT ''WriteVector `AppT` ConT ''Double
    VBool -> ConT ''WriteVector `AppT` ConT ''Bool
    VString -> ConT ''WriteVector `AppT` ConT ''Text
    VEnum _ enumType -> ConT ''WriteVector `AppT` enumTypeToType enumType
    VStruct typeRef -> ConT ''WriteVector `AppT` (ConT ''WriteStruct `AppT` typeRefToType typeRef)
    VTable typeRef -> ConT ''WriteVector `AppT` (ConT ''WriteTable `AppT` typeRefToType typeRef)
    VUnion typeRef -> ConT ''WriteVector `AppT` (ConT ''WriteUnion `AppT` typeRefToType typeRef)

-- | The @Vector@ type returned when reading a vector with the given element type.
vectorElementTypeToReadType :: VectorElementType -> Type
vectorElementTypeToReadType vet =
  case vet of
    VInt8 -> ConT ''Vector `AppT` ConT ''Int8
    VInt16 -> ConT ''Vector `AppT` ConT ''Int16
    VInt32 -> ConT ''Vector `AppT` ConT ''Int32
    VInt64 -> ConT ''Vector `AppT` ConT ''Int64
    VWord8 -> ConT ''Vector `AppT` ConT ''Word8
    VWord16 -> ConT ''Vector `AppT` ConT ''Word16
    VWord32 -> ConT ''Vector `AppT` ConT ''Word32
    VWord64 -> ConT ''Vector `AppT` ConT ''Word64
    VFloat -> ConT ''Vector `AppT` ConT ''Float
    VDouble -> ConT ''Vector `AppT` ConT ''Double
    VBool -> ConT ''Vector `AppT` ConT ''Bool
    VString -> ConT ''Vector `AppT` ConT ''Text
    VEnum _ enumType -> ConT ''Vector `AppT` enumTypeToType enumType
    VStruct typeRef -> ConT ''Vector `AppT` (ConT ''Struct `AppT` typeRefToType typeRef)
    VTable typeRef -> ConT ''Vector `AppT` (ConT ''Table `AppT` typeRefToType typeRef)
    VUnion typeRef -> ConT ''Vector `AppT` (ConT ''Union `AppT` typeRefToType typeRef)

-- | Turns a reference to a schema type into the (module-prefixed) Haskell
-- type constructor generated for it.
typeRefToType :: TypeRef -> Type
typeRefToType (TypeRef ns ident) =
  ConT . mkName' . NC.withModulePrefix ns . NC.dataTypeName $ ident

-- | Wraps a type in 'Maybe' when the field is optional; leaves it untouched
-- when the field is required.
requiredType :: Required -> Type -> Type
requiredType Req t = t
requiredType Opt t = AppT (ConT ''Maybe) t

-- | 'mkName' for 'Text'.
mkName' :: Text -> Name
mkName' = mkName . T.unpack

-- | 'newName' for 'Text'.
newName' :: Text -> Q Name
newName' = newName . T.unpack

-- | An integer literal pattern.
intLitP :: Integral i => i -> Pat
intLitP = LitP . IntegerL . toInteger

-- | An integer literal expression.
intLitE :: Integral i => i -> Exp
intLitE = LitE . IntegerL . toInteger

-- | A rational literal expression.
realLitE :: Real i => i -> Exp
realLitE = LitE . RationalL . toRational

-- | An expression that evaluates to the given 'Text': 'T.pack' applied to a
-- string literal.
textLitE :: Text -> Exp
textLitE t = VarE 'T.pack `AppE` LitE (StringL (T.unpack t))

-- | A plain string literal expression with the contents of the given 'Text'.
stringLitE :: Text -> Exp
stringLitE t = LitE (StringL (T.unpack t))

-- | An @INLINE@ pragma for the function with the given name.
inlinePragma :: Name -> Dec
inlinePragma funName = PragmaD $ InlineP funName Inline FunLike AllPhases

-- | Applies a function to multiple arguments. Assumes the list is not empty.
app :: [Exp] -> Exp
app = foldl1 AppE

-- | Composes the given expressions with @(.)@, right-nested.
-- Assumes the list is not empty.
compose :: [Exp] -> Exp
compose = foldr1 (\e1 e2 -> InfixE (Just e1) (VarE '(.)) (Just e2))

-- | 'unzip3' for non-empty lists of triples.
nonEmptyUnzip3 :: NonEmpty (a,b,c) -> (NonEmpty a, NonEmpty b, NonEmpty c)
nonEmptyUnzip3 xs =
  ( (\(x, _, _) -> x) <$> xs
  , (\(_, x, _) -> x) <$> xs
  , (\(_, _, x) -> x) <$> xs
  )