{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module FlatBuffers.Internal.Compiler.SemanticAnalysis where
import Control.Monad ( forM_, join, when )
import Control.Monad.Except ( throwError )
import Control.Monad.Reader ( ReaderT, asks, local, runReaderT )
import Control.Monad.State ( MonadState, State, StateT, evalState, evalStateT, get, mapStateT, modify, put )
import Control.Monad.Trans ( lift )
import Data.Bits ( (.&.), (.|.), Bits, FiniteBits, bit, finiteBitSize )
import Data.Coerce ( coerce )
import Data.Foldable ( asum, find, foldlM, traverse_ )
import qualified Data.Foldable as Foldable
import Data.Functor ( ($>), (<&>) )
import Data.Int
import Data.Ix ( inRange )
import qualified Data.List as List
import Data.List.NonEmpty ( NonEmpty((:|)) )
import qualified Data.List.NonEmpty as NE
import Data.Map.Strict ( Map )
import qualified Data.Map.Strict as Map
import Data.Maybe ( catMaybes, fromMaybe, isJust )
import Data.Monoid ( Sum(..) )
import Data.Scientific ( Scientific )
import qualified Data.Scientific as Scientific
import Data.Set ( Set )
import qualified Data.Set as Set
import Data.Text ( Text )
import qualified Data.Text as T
import Data.Traversable ( for )
import Data.Word
import FlatBuffers.Internal.Compiler.Display ( Display(..) )
import FlatBuffers.Internal.Compiler.SyntaxTree ( FileTree(..), HasIdent(..), HasMetadata(..), Ident, Namespace, Schema, TypeRef(..), qualify )
import qualified FlatBuffers.Internal.Compiler.SyntaxTree as ST
import FlatBuffers.Internal.Compiler.ValidSyntaxTree
import FlatBuffers.Internal.Constants
import FlatBuffers.Internal.Types
import Text.Read ( readMaybe )
-- | The validation monad: reads a 'ValidationState' and stops at the first
-- problem with a 'String' error message ('Left').
newtype Validation a = Validation { runValidation :: ReaderT ValidationState (Either String) a }
  deriving newtype (Functor, Applicative, Monad)
-- | Environment read by 'Validation'.
data ValidationState = ValidationState
  { validationStateCurrentContext :: ![Ident]
    -- ^ Identifiers of the declarations currently being validated, innermost
    -- first ('validating' conses onto the front; 'getContext' reverses it).
  , validationStateAllAttributes :: !(Set ST.AttributeDecl)
    -- ^ Every attribute that may appear in metadata: the user-declared ones
    -- plus the built-in ones.
  }
-- | Operations the validation code needs from its monad. The monad tracks
-- which declaration is currently being validated (used to prefix error
-- messages), exposes the set of declared attributes, and can fail.
class Monad m => MonadValidation m where
  -- | Run an action with the identifier of the given item pushed onto the
  -- error context.
  validating :: HasIdent a => a -> m b -> m b
  -- | Run an action with an empty error context.
  resetContext :: m a -> m a
  -- | The current error context.
  getContext :: m [Ident]
  -- | The attributes that may be used in metadata.
  getDeclaredAttributes :: m (Set ST.AttributeDecl)
  -- | Abort validation with an error message.
  throwErrorMsg :: String -> m a
-- | The context is kept innermost-first in the reader environment. Error
-- messages are prefixed with the dotted context path, outermost first, e.g.
-- @[ns.Table.field]: message@.
instance MonadValidation Validation where
  validating a (Validation action) =
    Validation $
      local (\st -> st { validationStateCurrentContext = getIdent a : validationStateCurrentContext st }) action

  resetContext (Validation action) =
    Validation $ local (\st -> st { validationStateCurrentContext = [] }) action

  getContext =
    Validation $ List.reverse <$> asks validationStateCurrentContext

  getDeclaredAttributes =
    Validation $ asks validationStateAllAttributes

  throwErrorMsg msg = do
    idents <- getContext
    let prefix =
          case idents of
            [] -> ""
            _ -> "[" <> List.intercalate "." (T.unpack . unIdent <$> idents) <> "]: "
    Validation . throwError $ prefix <> msg
-- | Lift validation through a state layer. Context changes apply to the
-- wrapped computation; the state is threaded through unchanged.
instance MonadValidation m => MonadValidation (StateT s m) where
  validating a = mapStateT (validating a)
  resetContext = mapStateT resetContext
  getContext = lift getContext
  getDeclaredAttributes = lift getDeclaredAttributes
  throwErrorMsg msg = lift (throwErrorMsg msg)
-- | Per-file tables of top-level declarations, keyed by the namespace each was
-- declared in and its identifier. The type parameters let each validation
-- stage replace one kind of declaration with its validated form.
data SymbolTable enum struct table union = SymbolTable
  { allEnums :: !(Map (Namespace, Ident) enum)
  , allStructs :: !(Map (Namespace, Ident) struct)
  , allTables :: !(Map (Namespace, Ident) table)
  , allUnions :: !(Map (Namespace, Ident) union)
  }
  deriving stock (Eq, Show)
-- | Field-wise (left-biased) map union.
instance Semigroup (SymbolTable e s t u) where
  x <> y = SymbolTable
    { allEnums = allEnums x <> allEnums y
    , allStructs = allStructs x <> allStructs y
    , allTables = allTables x <> allTables y
    , allUnions = allUnions x <> allUnions y
    }
-- | The table with no declarations at all.
instance Monoid (SymbolTable e s t u) where
  mempty = SymbolTable Map.empty Map.empty Map.empty Map.empty
-- Symbol tables at successive validation stages. 'validateSchemas' runs
-- 'validateEnums', 'validateStructs', 'validateTables' and 'validateUnions' in
-- that order. Each stage replaces one kind of syntax-tree declaration
-- (@ST.*Decl@) with its validated counterpart.
type Stage1 = SymbolTable ST.EnumDecl ST.StructDecl ST.TableDecl ST.UnionDecl
type Stage2 = SymbolTable EnumDecl ST.StructDecl ST.TableDecl ST.UnionDecl
type Stage3 = SymbolTable EnumDecl StructDecl ST.TableDecl ST.UnionDecl
type Stage4 = SymbolTable EnumDecl StructDecl TableDecl ST.UnionDecl
type ValidDecls = SymbolTable EnumDecl StructDecl TableDecl UnionDecl
-- | Validate every schema in the file tree.
--
-- The steps are:
--
-- 1. Build per-file symbol tables.
-- 2. Reject duplicate qualified top-level names across all files.
-- 3. Validate enums, structs, tables and unions, in that order (each stage
--    consumes the previous stage's output).
-- 4. Mark the root table declared in the root schema.
--
-- User-declared attributes from all files, plus 'knownAttributes', form the
-- set of attributes that may be used.
validateSchemas :: FileTree Schema -> Either String (FileTree ValidDecls)
validateSchemas schemas =
  runReaderT (runValidation pipeline) initialState
  where
    initialState = ValidationState [] (Set.fromList (userAttributes <> knownAttributes))

    userAttributes =
      concatMap (\schema -> [ attr | ST.DeclA attr <- ST.decls schema ]) schemas

    pipeline = do
      tables <- createSymbolTables schemas
      checkDuplicateIdentifiers (concatMap topLevelIdentifiers tables)
      stage2 <- validateEnums tables
      stage3 <- validateStructs stage2
      stage4 <- validateTables stage3
      valid <- validateUnions stage4
      updateRootTable (fileTreeRoot schemas) valid

    topLevelIdentifiers symbolTable =
      fmap (uncurry qualify) $
        Map.keys (allEnums symbolTable)
          <> Map.keys (allStructs symbolTable)
          <> Map.keys (allTables symbolTable)
          <> Map.keys (allUnions symbolTable)
-- | Build one 'Stage1' symbol table per schema file.
--
-- Each declaration is keyed by the namespace in effect where it appears: the
-- most recent @namespace@ declaration, or the empty namespace before any.
-- Declaring the same qualified name twice within one kind of declaration is
-- an error.
createSymbolTables :: FileTree Schema -> Validation (FileTree Stage1)
createSymbolTables = traverse (fmap snd . foldlM step ("", mempty) . ST.decls)
  where
    step :: (Namespace, Stage1) -> ST.Decl -> Validation (Namespace, Stage1)
    step (ns, table) decl =
      case decl of
        ST.DeclN (ST.NamespaceDecl ns') -> pure (ns', table)
        ST.DeclE e -> (,) ns <$> insertInto allEnums (\m -> table { allEnums = m }) e
        ST.DeclS s -> (,) ns <$> insertInto allStructs (\m -> table { allStructs = m }) s
        ST.DeclT t -> (,) ns <$> insertInto allTables (\m -> table { allTables = m }) t
        ST.DeclU u -> (,) ns <$> insertInto allUnions (\m -> table { allUnions = m }) u
        _ -> pure (ns, table)
      where
        -- Insert a symbol into one of the table's maps under the current
        -- namespace and write the updated map back.
        insertInto :: HasIdent a => (Stage1 -> Map (Namespace, Ident) a) -> (Map (Namespace, Ident) a -> Stage1) -> a -> Validation Stage1
        insertInto getMap setMap symbol = setMap <$> insertSymbol ns symbol (getMap table)

    insertSymbol :: HasIdent a => Namespace -> a -> Map (Namespace, Ident) a -> Validation (Map (Namespace, Ident) a)
    insertSymbol namespace symbol symbols
      | Map.member key symbols = throwErrorMsg $ display (qualify namespace symbol) <> " declared more than once"
      | otherwise = pure (Map.insert key symbol symbols)
      where
        key = (namespace, getIdent symbol)
-- | The root table found while scanning the root schema. It is stored with the
-- namespace it was found in and the schema's file identifier, if any.
data RootInfo = RootInfo
  { rootTableNamespace :: !Namespace
  , rootTable :: !TableDecl
  , rootFileIdent :: !(Maybe Text)
  }
-- | Mark the table named by the root schema's @root_type@ declaration as the
-- root table, attaching the schema's @file_identifier@ (if any).
--
-- Every file's symbol table is checked; a table is marked when both its key
-- namespace and its declaration equal the resolved root. When no root type is
-- declared, the tree is returned unchanged.
updateRootTable :: Schema -> FileTree ValidDecls -> Validation (FileTree ValidDecls)
updateRootTable schema symbolTables = do
  mbRootInfo <- getRootInfo schema symbolTables
  pure $ case mbRootInfo of
    Nothing -> symbolTables
    Just rootInfo -> fmap (markRoot rootInfo) symbolTables
  where
    markRoot :: RootInfo -> ValidDecls -> ValidDecls
    markRoot (RootInfo rootNs rootTbl fileIdent) st =
      st { allTables = Map.mapWithKey mark (allTables st) }
      where
        mark (ns, _) tbl
          | ns == rootNs && tbl == rootTbl = tbl { tableIsRoot = IsRoot fileIdent }
          | otherwise = tbl
-- | Scan the root schema's declarations in order, tracking the current
-- namespace. Keeps the last @root_type@ and the last @file_identifier@ seen.
--
-- Returns 'Nothing' when no root type is declared; any file identifier is
-- then dropped. Fails if a root type does not resolve to a table.
getRootInfo :: Schema -> FileTree ValidDecls -> Validation (Maybe RootInfo)
getRootInfo schema symbolTables = do
  (_, mbRoot, fileIdent) <- foldlM step ("", Nothing, Nothing) (ST.decls schema)
  pure $ fmap (\(ns, tbl) -> RootInfo ns tbl fileIdent) mbRoot
  where
    step :: (Namespace, Maybe (Namespace, TableDecl), Maybe Text) -> ST.Decl -> Validation (Namespace, Maybe (Namespace, TableDecl), Maybe Text)
    step acc@(ns, root, fileIdent) = \case
      ST.DeclN (ST.NamespaceDecl ns') -> pure (ns', root, fileIdent)
      ST.DeclFI (ST.FileIdentifierDecl fi) -> pure (ns, root, Just (coerce fi))
      ST.DeclR (ST.RootDecl typeRef) ->
        findDecl ns symbolTables typeRef >>= \case
          MatchT tableNs tbl -> pure (ns, Just (tableNs, tbl), fileIdent)
          _ -> throwErrorMsg "root type must be a table"
      _ -> pure acc
-- | Attributes that may be used without an @attribute@ declaration. These are
-- the ones this module interprets itself, followed by 'otherKnownAttributes'.
knownAttributes :: [ST.AttributeDecl]
knownAttributes = interpreted <> otherKnownAttributes
  where
    interpreted :: [ST.AttributeDecl]
    interpreted = coerce [idAttr, deprecatedAttr, requiredAttr, forceAlignAttr, bitFlagsAttr]
-- | Explicit table field id.
idAttr :: Text
idAttr = "id"

-- | Marks a table field as deprecated.
deprecatedAttr :: Text
deprecatedAttr = "deprecated"

-- | Marks a non-scalar table field as required.
requiredAttr :: Text
requiredAttr = "required"

-- | Overrides a struct's natural alignment.
forceAlignAttr :: Text
forceAlignAttr = "force_align"

-- | Makes an enum's values bit positions (turned into masks).
bitFlagsAttr :: Text
bitFlagsAttr = "bit_flags"
-- | Other attribute names accepted without a user @attribute@ declaration.
-- Nothing in this module reads their values; they are only whitelisted.
otherKnownAttributes :: [ST.AttributeDecl]
otherKnownAttributes =
  [ "nested_flatbuffer"
  , "flexbuffer"
  , "key"
  , "hash"
  , "original_order"
  , "native_inline"
  , "native_default"
  , "native_custom_alloc"
  , "native_type"
  , "cpp_type"
  , "cpp_ptr_type"
  , "cpp_str_type"
  , "cpp_str_flex_ctor"
  , "shared"
  ]
-- | The result of resolving a 'TypeRef' with 'findDecl'. It records which kind
-- of declaration was found, the namespace it was found in, and the
-- declaration itself.
data Match enum struct table union
  = MatchE !Namespace !enum
  | MatchS !Namespace !struct
  | MatchT !Namespace !table
  | MatchU !Namespace !union
-- | Resolve a type reference as seen from @currentNamespace@.
--
-- The reference's namespace is appended to each enclosing namespace, from the
-- current one outwards to the root. The first candidate that names a
-- declaration in any file wins. Fails, listing the namespaces tried, when
-- nothing matches.
findDecl ::
     MonadValidation m
  => Namespace
  -> FileTree (SymbolTable e s t u)
  -> TypeRef
  -> m (Match e s t u)
findDecl currentNamespace symbolTables typeRef@(TypeRef refNamespace refIdent) =
  maybe notFound pure firstMatch
  where
    candidates = parentNamespaces currentNamespace

    firstMatch = asum (fmap lookupIn candidates)

    lookupIn parentNamespace =
      let ns = parentNamespace <> refNamespace
          key = (ns, refIdent)
          search table =
            asum
              [ MatchE ns <$> Map.lookup key (allEnums table)
              , MatchS ns <$> Map.lookup key (allStructs table)
              , MatchT ns <$> Map.lookup key (allTables table)
              , MatchU ns <$> Map.lookup key (allUnions table)
              ]
      in asum (fmap search symbolTables)

    notFound =
      throwErrorMsg $
        "type "
          <> display typeRef
          <> " does not exist (checked in these namespaces: "
          <> display candidates
          <> ")"
-- | A namespace and all of its ancestors, innermost first and ending with the
-- root (empty) namespace.
parentNamespaces :: ST.Namespace -> NonEmpty ST.Namespace
parentNamespaces (ST.Namespace ns) =
  fmap ST.Namespace (NE.reverse (NE.inits ns))
-- | Validate every enum in every file, replacing them with 'EnumDecl's.
validateEnums :: FileTree Stage1 -> Validation (FileTree Stage2)
validateEnums = traverse validateFile
  where
    validateFile :: Stage1 -> Validation Stage2
    validateFile symbolTable =
      (\validEnums -> symbolTable { allEnums = validEnums })
        <$> Map.traverseWithKey validateEnum (allEnums symbolTable)
-- | Validate a single enum declaration.
--
-- The checks are, in order:
--
-- * no duplicate value names and no undeclared attributes;
-- * the underlying type is integral (and unsigned for @bit_flags@ enums);
-- * implicit values are assigned (previous value + 1, starting at 0);
-- * values are strictly ascending and fit the underlying type.
--
-- For @bit_flags@ enums, the values are then converted from bit positions
-- into bit masks.
validateEnum :: (Namespace, Ident) -> ST.EnumDecl -> Validation EnumDecl
validateEnum (currentNamespace, _) enum =
  validating (qualify currentNamespace enum) $ do
    checkDuplicateFields
    checkUndeclaredAttributes enum
    validEnum
  where
    -- True when the enum carries the bit_flags attribute.
    isBitFlags = hasAttribute bitFlagsAttr (ST.enumMetadata enum)

    validEnum = do
      enumType <- validateEnumType (ST.enumType enum)
      -- Assign implicit values; the state holds the previous value, if any.
      let enumVals = flip evalState Nothing . traverse mapEnumVal $ ST.enumVals enum
      validateOrder enumVals
      -- Bounds are checked on the raw values, i.e. before bit positions are
      -- turned into masks by shiftBitFlags.
      traverse_ (validateBounds enumType) enumVals
      pure EnumDecl
        { enumIdent = getIdent enum
        , enumType = enumType
        , enumBitFlags = isBitFlags
        , enumVals = shiftBitFlags <$> enumVals
        }

    -- Use the explicit literal if there is one; otherwise previous value + 1,
    -- or 0 for the first value.
    mapEnumVal :: ST.EnumVal -> State (Maybe Integer) EnumVal
    mapEnumVal enumVal = do
      thisInt <-
        case ST.enumValLiteral enumVal of
          Just (ST.IntLiteral thisInt) ->
            pure thisInt
          Nothing ->
            get <&> \case
              Just lastInt -> lastInt + 1
              Nothing -> 0
      put (Just thisInt)
      pure (EnumVal (getIdent enumVal) thisInt)

    -- Report the first consecutive pair that is not strictly increasing
    -- (so duplicate values are rejected too).
    validateOrder :: NonEmpty EnumVal -> Validation ()
    validateOrder xs =
      let consecutivePairs = NE.toList xs `zip` NE.tail xs
          outOfOrderPais = filter (\(x, y) -> enumValInt x >= enumValInt y) consecutivePairs
      in
        case outOfOrderPais of
          [] -> pure ()
          (x, y) : _ -> throwErrorMsg $
            "enum values must be specified in ascending order. "
              <> display (enumValIdent y)
              <> " ("
              <> display (enumValInt y)
              <> ") should be greater than "
              <> display (enumValIdent x)
              <> " ("
              <> display (enumValInt x)
              <> ")"

    validateBounds :: EnumType -> EnumVal -> Validation ()
    validateBounds enumType enumVal =
      validating enumVal $
        case enumType of
          EInt8 -> validateBounds' @Int8 enumVal
          EInt16 -> validateBounds' @Int16 enumVal
          EInt32 -> validateBounds' @Int32 enumVal
          EInt64 -> validateBounds' @Int64 enumVal
          EWord8 -> validateBounds' @Word8 enumVal
          EWord16 -> validateBounds' @Word16 enumVal
          EWord32 -> validateBounds' @Word32 enumVal
          EWord64 -> validateBounds' @Word64 enumVal

    -- Check that a value fits the underlying type @a@.
    validateBounds' :: forall a. (FiniteBits a, Integral a, Bounded a) => EnumVal -> Validation ()
    validateBounds' e =
      if inRange (lower, upper) (enumValInt e)
        then pure ()
        else throwErrorMsg $
          "enum value of "
            <> display (enumValInt e)
            <> " does not fit ["
            <> display lower
            <> "; "
            <> display upper
            <> "]"
      where
        -- For bit_flags enums the value is a bit position, so it must lie in
        -- [0, bit width - 1] rather than in the type's numeric range.
        lower = if isBitFlags
          then 0
          else toInteger (minBound @a)
        upper = if isBitFlags
          then toInteger (finiteBitSize @a (undefined :: a) - 1)
          else toInteger (maxBound @a)

    -- Only integral types may underlie an enum; bit_flags enums additionally
    -- require an unsigned type.
    validateEnumType :: ST.Type -> Validation EnumType
    validateEnumType t =
      case t of
        ST.TInt8 -> unlessIsBitFlags EInt8
        ST.TInt16 -> unlessIsBitFlags EInt16
        ST.TInt32 -> unlessIsBitFlags EInt32
        ST.TInt64 -> unlessIsBitFlags EInt64
        ST.TWord8 -> pure EWord8
        ST.TWord16 -> pure EWord16
        ST.TWord32 -> pure EWord32
        ST.TWord64 -> pure EWord64
        _ -> throwErrorMsg "underlying enum type must be integral"
      where
        unlessIsBitFlags x =
          if isBitFlags
            then throwErrorMsg "underlying type of bit_flags enum must be unsigned"
            else pure x

    -- For bit_flags enums, replace each bit position with its bit mask.
    shiftBitFlags :: EnumVal -> EnumVal
    shiftBitFlags e =
      if isBitFlags
        then e { enumValInt = bit (fromIntegral @Integer @Int (enumValInt e)) }
        else e

    checkDuplicateFields :: Validation ()
    checkDuplicateFields =
      checkDuplicateIdentifiers
        (ST.enumVals enum)
-- | A validated table field before its id is assigned. It holds the field's
-- identifier, its type, and whether it is marked @deprecated@.
data TableFieldWithoutId = TableFieldWithoutId !Ident !TableFieldType !Bool
-- | Validate every table in every file, replacing them with 'TableDecl's.
-- Type references may resolve to declarations in any file.
validateTables :: FileTree Stage3 -> Validation (FileTree Stage4)
validateTables symbolTables = traverse validateFile symbolTables
  where
    validateFile :: Stage3 -> Validation Stage4
    validateFile symbolTable = do
      validTables <- Map.traverseWithKey (validateTable symbolTables) (allTables symbolTable)
      pure symbolTable { allTables = validTables }
-- | Validate a single table declaration.
--
-- Checks for duplicate field names and undeclared attributes (on the table and
-- on each field), validates each field's type and default value, and assigns
-- field ids. The result is always 'NotRoot'; the root table is marked later by
-- 'updateRootTable'.
validateTable :: FileTree Stage3 -> (Namespace, Ident) -> ST.TableDecl -> Validation TableDecl
validateTable symbolTables (currentNamespace, _) table =
  validating (qualify currentNamespace table) $ do
    let fields = ST.tableFields table
    let fieldsMetadata = ST.tableFieldMetadata <$> fields
    checkDuplicateFields fields
    checkUndeclaredAttributes table
    validFieldsWithoutIds <- traverse validateTableField fields
    validFields <- assignFieldIds fieldsMetadata validFieldsWithoutIds
    pure TableDecl
      { tableIdent = getIdent table
      , tableIsRoot = NotRoot
      , tableFields = validFields
      }
  where
    checkDuplicateFields :: [ST.TableField] -> Validation ()
    checkDuplicateFields = checkDuplicateIdentifiers

    -- Either no field or every field must have an 'id' attribute.
    -- Implicit ids keep declaration order. Explicit ids are sorted ascending
    -- and then checked for gaps.
    assignFieldIds :: [ST.Metadata] -> [TableFieldWithoutId] -> Validation [TableField]
    assignFieldIds metadata fieldsWithoutIds = do
      ids <- catMaybes <$> traverse (findIntAttr idAttr) metadata
      if null ids
        then pure $ evalState (traverse assignFieldId fieldsWithoutIds) (-1)
        else if length ids == length fieldsWithoutIds
          then do
            let fields = zipWith (\(TableFieldWithoutId ident typ depr) id -> TableField id ident typ depr) fieldsWithoutIds ids
            let sorted = List.sortOn tableFieldId fields
            evalStateT (traverse_ checkFieldId sorted) (-1)
            pure sorted
          else
            throwErrorMsg "either all fields or no fields must have an 'id' attribute"

    -- Implicit ids: previous id + 1, except that union fields and vectors of
    -- unions take previous id + 2. NOTE(review): presumably this leaves room
    -- for the union's companion type field; confirm against the code generator.
    assignFieldId :: TableFieldWithoutId -> State Integer TableField
    assignFieldId (TableFieldWithoutId ident typ depr) = do
      lastId <- get
      let fieldId =
            case typ of
              TUnion _ _ -> lastId + 2
              TVector _ (VUnion _) -> lastId + 2
              _ -> lastId + 1
      put fieldId
      pure (TableField fieldId ident typ depr)

    -- Explicit ids, visited in ascending order, must follow the same scheme
    -- as assignFieldId, starting from 0.
    -- NOTE(review): a duplicated id fails the "+ 1" check, so it is reported
    -- as the next id being missing rather than as a duplicate.
    checkFieldId :: TableField -> StateT Integer Validation ()
    checkFieldId field = do
      lastId <- get
      validating field $ do
        case tableFieldType field of
          TUnion _ _ ->
            when (tableFieldId field /= lastId + 2) $
              throwErrorMsg "the id of a union field must be the last field's id + 2"
          TVector _ (VUnion _) ->
            when (tableFieldId field /= lastId + 2) $
              throwErrorMsg "the id of a vector of unions field must be the last field's id + 2"
          _ ->
            when (tableFieldId field /= lastId + 1) $
              throwErrorMsg $ "field ids must be consecutive from 0; id " <> display (lastId + 1) <> " is missing"
        put (tableFieldId field)

    validateTableField :: ST.TableField -> Validation TableFieldWithoutId
    validateTableField tf =
      validating tf $ do
        checkUndeclaredAttributes tf
        validFieldType <- validateTableFieldType (ST.tableFieldMetadata tf) (ST.tableFieldDefault tf) (ST.tableFieldType tf)
        pure $ TableFieldWithoutId
          (getIdent tf)
          validFieldType
          (hasAttribute deprecatedAttr (ST.tableFieldMetadata tf))

    -- Scalars (numbers, bools, enums) may not be 'required' and may have a
    -- default. Strings, structs, tables, unions and vectors may not have a
    -- default and carry their 'required' flag. Nested vectors are rejected.
    validateTableFieldType :: ST.Metadata -> Maybe ST.DefaultVal -> ST.Type -> Validation TableFieldType
    validateTableFieldType md dflt tableFieldType =
      case tableFieldType of
        ST.TInt8 -> checkNoRequired md >> validateDefaultValAsInt @Int8 dflt <&> TInt8
        ST.TInt16 -> checkNoRequired md >> validateDefaultValAsInt @Int16 dflt <&> TInt16
        ST.TInt32 -> checkNoRequired md >> validateDefaultValAsInt @Int32 dflt <&> TInt32
        ST.TInt64 -> checkNoRequired md >> validateDefaultValAsInt @Int64 dflt <&> TInt64
        ST.TWord8 -> checkNoRequired md >> validateDefaultValAsInt @Word8 dflt <&> TWord8
        ST.TWord16 -> checkNoRequired md >> validateDefaultValAsInt @Word16 dflt <&> TWord16
        ST.TWord32 -> checkNoRequired md >> validateDefaultValAsInt @Word32 dflt <&> TWord32
        ST.TWord64 -> checkNoRequired md >> validateDefaultValAsInt @Word64 dflt <&> TWord64
        ST.TFloat -> checkNoRequired md >> validateDefaultValAsScientific dflt <&> TFloat
        ST.TDouble -> checkNoRequired md >> validateDefaultValAsScientific dflt <&> TDouble
        ST.TBool -> checkNoRequired md >> validateDefaultValAsBool dflt <&> TBool
        ST.TString -> checkNoDefault dflt $> TString (isRequired md)
        ST.TRef typeRef ->
          findDecl currentNamespace symbolTables typeRef >>= \case
            MatchE ns enum -> do
              checkNoRequired md
              validDefault <- validateDefaultAsEnum dflt enum
              pure $ TEnum (TypeRef ns (getIdent enum)) (enumType enum) validDefault
            MatchS ns struct -> checkNoDefault dflt $> TStruct (TypeRef ns (getIdent struct)) (isRequired md)
            MatchT ns table -> checkNoDefault dflt $> TTable (TypeRef ns (getIdent table)) (isRequired md)
            MatchU ns union -> checkNoDefault dflt $> TUnion (TypeRef ns (getIdent union)) (isRequired md)
        ST.TVector vecType ->
          checkNoDefault dflt >> TVector (isRequired md) <$>
            case vecType of
              ST.TInt8 -> pure VInt8
              ST.TInt16 -> pure VInt16
              ST.TInt32 -> pure VInt32
              ST.TInt64 -> pure VInt64
              ST.TWord8 -> pure VWord8
              ST.TWord16 -> pure VWord16
              ST.TWord32 -> pure VWord32
              ST.TWord64 -> pure VWord64
              ST.TFloat -> pure VFloat
              ST.TDouble -> pure VDouble
              ST.TBool -> pure VBool
              ST.TString -> pure VString
              ST.TVector _ -> throwErrorMsg "nested vector types not supported"
              ST.TRef typeRef ->
                findDecl currentNamespace symbolTables typeRef <&> \case
                  MatchE ns enum ->
                    VEnum (TypeRef ns (getIdent enum))
                      (enumType enum)
                  MatchS ns struct ->
                    VStruct (TypeRef ns (getIdent struct))
                  MatchT ns table -> VTable (TypeRef ns (getIdent table))
                  MatchU ns union -> VUnion (TypeRef ns (getIdent union))
-- | Reject the 'required' attribute on a scalar field.
checkNoRequired :: ST.Metadata -> Validation ()
checkNoRequired md
  | hasAttribute requiredAttr md =
      throwErrorMsg "only non-scalar fields (strings, vectors, unions, structs, tables) may be 'required'"
  | otherwise = pure ()
-- | Reject a default value on a non-scalar field.
checkNoDefault :: Maybe ST.DefaultVal -> Validation ()
checkNoDefault = \case
  Nothing -> pure ()
  Just _ ->
    throwErrorMsg
      "default values currently only supported for scalar fields (integers, floating point, bool, enums)"
-- | 'Req' when the field carries the 'required' attribute, 'Opt' otherwise.
isRequired :: ST.Metadata -> Required
isRequired md
  | hasAttribute requiredAttr md = Req
  | otherwise = Opt
-- | Validate the default of an integer field of type @a@ (chosen by type
-- application). When absent it is 0; otherwise it must be an integral number
-- within @a@'s bounds.
validateDefaultValAsInt :: forall a. (Integral a, Bounded a, Display a) => Maybe ST.DefaultVal -> Validation (DefaultVal Integer)
validateDefaultValAsInt = \case
  Nothing -> pure (DefaultVal 0)
  Just (ST.DefaultNum n) -> scientificToInteger @a n "default value must be integral"
  Just _ -> throwErrorMsg "default value must be integral"
-- | Validate the default of a floating-point field. When absent it is 0;
-- otherwise it must be a number, which is kept as written.
validateDefaultValAsScientific :: Maybe ST.DefaultVal -> Validation (DefaultVal Scientific)
validateDefaultValAsScientific = maybe (pure (DefaultVal 0)) fromSyntax
  where
    fromSyntax (ST.DefaultNum n) = pure (DefaultVal n)
    fromSyntax _ = throwErrorMsg "default value must be a number"
-- | Validate the default of a boolean field. When absent it is 'False';
-- otherwise it must be a boolean literal.
validateDefaultValAsBool :: Maybe ST.DefaultVal -> Validation (DefaultVal Bool)
validateDefaultValAsBool = maybe (pure (DefaultVal False)) fromSyntax
  where
    fromSyntax (ST.DefaultBool b) = pure (DefaultVal b)
    fromSyntax _ = throwErrorMsg "default value must be a boolean"
-- | Validate the default value of an enum-typed table field.
--
-- * No default: 0. Unless the enum is @bit_flags@, 0 must be one of its
--   values.
-- * A number: for @bit_flags@ enums, any integer that fits the unsigned
--   underlying type; otherwise it must equal one of the enum's values.
-- * Identifiers: each must name an enum value. @bit_flags@ enums accept
--   several, which are OR-ed together; other enums accept exactly one.
-- * A boolean: rejected.
validateDefaultAsEnum :: Maybe ST.DefaultVal -> EnumDecl -> Validation (DefaultVal Integer)
validateDefaultAsEnum dflt enum =
  case dflt of
    Nothing ->
      if enumBitFlags enum
        then pure 0
        else
          case find (\val -> enumValInt val == 0) (enumVals enum) of
            Just _ -> pure 0
            Nothing -> throwErrorMsg "enum does not have a 0 value; please manually specify a default for this field"
    Just (ST.DefaultNum n) ->
      if enumBitFlags enum
        then
          case enumType enum of
            EWord8 -> scientificToInteger @Word8 n defaultErrorMsg
            EWord16 -> scientificToInteger @Word16 n defaultErrorMsg
            EWord32 -> scientificToInteger @Word32 n defaultErrorMsg
            EWord64 -> scientificToInteger @Word64 n defaultErrorMsg
            -- validateEnum rejects signed underlying types for bit_flags enums.
            _ -> throwErrorMsg "The 'impossible' has happened: bit_flags enum with signed integer"
        else
          -- Only used to test whether n is integral; a floating result is
          -- rejected without being inspected.
          case Scientific.floatingOrInteger @Float n of
            Left _float -> throwErrorMsg defaultErrorMsg
            Right i ->
              case find (\val -> enumValInt val == i) (enumVals enum) of
                Just matchingVal -> pure (DefaultVal (enumValInt matchingVal))
                Nothing -> throwErrorMsg $ "default value of " <> display i <> " is not part of enum " <> display (getIdent enum)
    Just (ST.DefaultRef refs) ->
      if enumBitFlags enum
        then
          foldr1 (.|.) <$> traverse findEnumByRef refs
        else
          case refs of
            ref :| [] -> findEnumByRef ref
            _ -> throwErrorMsg $ "default value must be a single identifier, found "
              <> display (NE.length refs)
              <> ": "
              <> display (fmap (\ref -> "'" <> ref <> "'") refs)
    Just (ST.DefaultBool _) ->
      throwErrorMsg defaultErrorMsg
  where
    -- Message listing the accepted forms. For bit_flags enums with at least
    -- two values it also shows an example combination of the first two.
    defaultErrorMsg =
      if enumBitFlags enum
        then
          case enumVals enum of
            x :| y : _ ->
              "default value must be integral, one of ["
                <> display (getIdent <$> enumVals enum)
                <> "], or a combination of the latter in double quotes (e.g. \""
                <> T.unpack (unIdent (getIdent x))
                <> " "
                <> T.unpack (unIdent (getIdent y))
                <> "\")"
            _ ->
              "default value must be integral or one of: " <> display (getIdent <$> enumVals enum)
        else
          "default value must be integral or one of: " <> display (getIdent <$> enumVals enum)

    -- The (already shifted) integer value of the enum value with this name.
    findEnumByRef :: Text -> Validation (DefaultVal Integer)
    findEnumByRef ref =
      case find (\val -> unIdent (getIdent val) == ref) (enumVals enum) of
        Just matchingVal -> pure (DefaultVal (enumValInt matchingVal))
        Nothing -> throwErrorMsg $ "default value of " <> display ref <> " is not part of enum " <> display (getIdent enum)
-- | Convert a numeric literal to an integer default that must fit the
-- integral type @a@ (chosen by type application).
--
-- Fails with the caller's message when the number is not integral, and with a
-- range message (showing @a@'s bounds) when it does not fit.
scientificToInteger ::
     forall a. (Integral a, Bounded a, Display a)
  => Scientific -> String -> Validation (DefaultVal Integer)
scientificToInteger n notIntegerErrorMsg
  | not (Scientific.isInteger n) = throwErrorMsg notIntegerErrorMsg
  | otherwise =
      maybe outOfRange (pure . DefaultVal . toInteger) (Scientific.toBoundedInteger @a n)
  where
    outOfRange =
      throwErrorMsg $
        "default value does not fit ["
          <> display (minBound @a)
          <> "; "
          <> display (maxBound @a)
          <> "]"
-- | Validate every union in every file, replacing them with 'UnionDecl's.
-- Member references may resolve to tables in any file.
validateUnions :: FileTree Stage4 -> Validation (FileTree ValidDecls)
validateUnions symbolTables = traverse validateFile symbolTables
  where
    validateFile :: Stage4 -> Validation ValidDecls
    validateFile symbolTable =
      (\validUnions -> symbolTable { allUnions = validUnions })
        <$> Map.traverseWithKey (validateUnion symbolTables) (allUnions symbolTable)
-- | Validate a single union declaration.
--
-- Each member must refer to a table. A member is named by its explicit alias,
-- or else by its (possibly namespace-qualified) type name, with dots replaced
-- by underscores. Member names must be unique and must not clash with the
-- implicit @NONE@ member. Undeclared attributes are rejected.
validateUnion :: FileTree Stage4 -> (Namespace, Ident) -> ST.UnionDecl -> Validation UnionDecl
validateUnion symbolTables (currentNamespace, _) union =
  validating (qualify currentNamespace union) $ do
    members <- traverse validateMember (ST.unionVals union)
    checkDuplicateIdentifiers (NE.cons "NONE" (getIdent <$> members))
    checkUndeclaredAttributes union
    pure UnionDecl { unionIdent = getIdent union, unionVals = members }
  where
    validateMember :: ST.UnionVal -> Validation UnionVal
    validateMember uv = do
      let tref = ST.unionValTypeRef uv
          defaultIdent = qualify (typeRefNamespace tref) (typeRefIdent tref)
          memberIdent = coerce (T.replace "." "_" (coerce (fromMaybe defaultIdent (ST.unionValIdent uv))))
      validating memberIdent $ do
        ref <- resolveTable tref
        pure UnionVal { unionValIdent = memberIdent, unionValTableRef = ref }

    resolveTable :: TypeRef -> Validation TypeRef
    resolveTable typeRef =
      findDecl currentNamespace symbolTables typeRef >>= \case
        MatchT ns table -> pure (TypeRef ns (getIdent table))
        _ -> throwErrorMsg "union members may only be tables"
-- | Cache of structs that have already been validated, keyed like the symbol
-- tables. It is threaded as state through 'validateStructs', so a struct
-- nested in several others is validated only once.
type ValidatedStructs = Map (Namespace, Ident) StructDecl
-- | Validate every struct in every file, replacing them with 'StructDecl's.
--
-- Within each file, all structs are first checked for cycles and then
-- validated. A single cache of validated structs is shared across all files.
validateStructs :: FileTree Stage2 -> Validation (FileTree Stage3)
validateStructs symbolTables =
  evalStateT (traverse validateFile symbolTables) Map.empty
  where
    validateFile :: Stage2 -> StateT ValidatedStructs Validation Stage3
    validateFile symbolTable = do
      let structs = allStructs symbolTable
      forM_ (Map.toList structs) $ \((ns, _), struct) ->
        checkStructCycles symbolTables (ns, struct)
      validStructs <- Map.traverseWithKey (\(ns, _) -> validateStruct symbolTables ns) structs
      pure symbolTable { allStructs = validStructs }
-- | Fail if a struct contains itself, directly or through nested struct
-- fields. The error lists the cycle.
--
-- Performs a depth-first walk over struct-typed fields. The path of qualified
-- struct names is kept and the walk stops when a name repeats. The error
-- context is reset at each struct, so errors name the struct and field where
-- the cycle closes.
checkStructCycles :: forall m. MonadValidation m => FileTree Stage2 -> (Namespace, ST.StructDecl) -> m ()
checkStructCycles symbolTables = visit []
  where
    visit :: [Ident] -> (Namespace, ST.StructDecl) -> m ()
    visit path (ns, struct) =
      resetContext . validating here $
        if here `elem` path
          then
            throwErrorMsg $
              "cyclic dependency detected ["
                <> display (T.intercalate " -> " . coerce $ List.dropWhile (/= here) $ List.reverse (here : path))
                <> "] - structs cannot contain themselves, directly or indirectly"
          else traverse_ (visitField (here : path) ns) (ST.structFields struct)
      where
        here = qualify ns struct

    visitField :: [Ident] -> Namespace -> ST.StructField -> m ()
    visitField path ns field =
      validating field $
        case ST.structFieldType field of
          ST.TRef typeRef ->
            findDecl ns symbolTables typeRef >>= \case
              MatchS nestedNs nested -> visit path (nestedNs, nested)
              _ -> pure ()
          _ -> pure ()
-- | A validated struct field before layout, i.e. without offset or padding.
data UnpaddedStructField = UnpaddedStructField
  { unpaddedStructFieldIdent :: !Ident
  , unpaddedStructFieldType :: !StructFieldType
  }
  deriving stock (Show, Eq)
-- | Validate a struct declaration, caching the result in the
-- 'ValidatedStructs' state.
--
-- If the struct is already cached, the cached result is returned. Otherwise
-- this:
--
-- * checks for duplicate field names and undeclared attributes;
-- * validates each field, recursively validating nested structs;
-- * computes the alignment (the largest field alignment, optionally raised by
--   @force_align@);
-- * lays the fields out with padding;
-- * stores the result in the cache.
--
-- The error context is reset first, so errors name this struct rather than
-- whichever declaration triggered the (possibly nested) validation. Assumes
-- struct nesting is acyclic; 'validateStructs' runs 'checkStructCycles'
-- before calling this.
validateStruct ::
     forall m. (MonadState ValidatedStructs m, MonadValidation m)
  => FileTree Stage2
  -> Namespace
  -> ST.StructDecl
  -> m StructDecl
validateStruct symbolTables currentNamespace struct =
  resetContext $
    validating (qualify currentNamespace struct) $ do
      validStructs <- get
      case Map.lookup (currentNamespace, getIdent struct) validStructs of
        Just match -> pure match
        Nothing -> do
          checkDuplicateFields
          checkUndeclaredAttributes struct
          fields <- traverse validateStructField (ST.structFields struct)
          -- By default a struct is aligned like its most strictly aligned field.
          let naturalAlignment = maximum (structFieldAlignment <$> fields)
          forceAlignAttrVal <- getForceAlignAttr
          forceAlign <- traverse (validateForceAlign naturalAlignment) forceAlignAttrVal
          let alignment = fromMaybe naturalAlignment forceAlign
          let (size, paddedFields) = addFieldPadding alignment fields
          let validStruct = StructDecl
                { structIdent = getIdent struct
                , structAlignment = alignment
                , structSize = size
                , structFields = paddedFields
                }
          modify (Map.insert (currentNamespace, getIdent validStruct) validStruct)
          pure validStruct
  where
    invalidStructFieldType = "struct fields may only be integers, floating point, bool, enums, or other structs"

    -- Lay the fields out in declaration order. After each field, add just
    -- enough padding for the next field to start at a multiple of that field's
    -- alignment. After the last field, pad the total size up to a multiple of
    -- the struct's alignment. Returns the total size and the padded fields.
    addFieldPadding :: Alignment -> NonEmpty UnpaddedStructField -> (InlineSize, NonEmpty StructField)
    addFieldPadding structAlignment unpaddedFields =
      (size, NE.fromList (reverse paddedFields))
      where
        -- go accumulates fields in reverse order. Its input is non-empty, so
        -- its output is too, and NE.fromList above cannot fail.
        (size, paddedFields) = go 0 [] (NE.toList unpaddedFields)

        go :: InlineSize -> [StructField] -> [UnpaddedStructField] -> (InlineSize, [StructField])
        go size paddedFields [] = (size, paddedFields)
        go size paddedFields (x : y : tail) =
          let size' = size + structFieldTypeSize (unpaddedStructFieldType x)
              nextFieldsAlignment = fromIntegral @Alignment @InlineSize (structFieldAlignment y)
              paddingNeeded = (size' `roundUpToNearestMultipleOf` nextFieldsAlignment) - size'
              size'' = size' + paddingNeeded
              paddedField = StructField
                { structFieldIdent = unpaddedStructFieldIdent x
                , structFieldPadding = fromIntegral @InlineSize @Word8 paddingNeeded
                , structFieldOffset = coerce size
                , structFieldType = unpaddedStructFieldType x
                }
          in go size'' (paddedField : paddedFields) (y : tail)
        go size paddedFields [x] =
          let size' = size + structFieldTypeSize (unpaddedStructFieldType x)
              structAlignment' = fromIntegral @Alignment @InlineSize structAlignment
              paddingNeeded = (size' `roundUpToNearestMultipleOf` structAlignment') - size'
              size'' = size' + paddingNeeded
              paddedField = StructField
                { structFieldIdent = unpaddedStructFieldIdent x
                , structFieldPadding = fromIntegral @InlineSize @Word8 paddingNeeded
                , structFieldOffset = coerce size
                , structFieldType = unpaddedStructFieldType x
                }
          in (size'', paddedField : paddedFields)

    validateStructField :: ST.StructField -> m UnpaddedStructField
    validateStructField sf =
      validating sf $ do
        checkUnsupportedAttributes sf
        checkUndeclaredAttributes sf
        structFieldType <- validateStructFieldType (ST.structFieldType sf)
        pure $ UnpaddedStructField
          { unpaddedStructFieldIdent = getIdent sf
          , unpaddedStructFieldType = structFieldType
          }

    -- Struct fields may be numbers, bools, enums or other structs (which are
    -- validated, or fetched from the cache, on the spot).
    validateStructFieldType :: ST.Type -> m StructFieldType
    validateStructFieldType structFieldType =
      case structFieldType of
        ST.TInt8 -> pure SInt8
        ST.TInt16 -> pure SInt16
        ST.TInt32 -> pure SInt32
        ST.TInt64 -> pure SInt64
        ST.TWord8 -> pure SWord8
        ST.TWord16 -> pure SWord16
        ST.TWord32 -> pure SWord32
        ST.TWord64 -> pure SWord64
        ST.TFloat -> pure SFloat
        ST.TDouble -> pure SDouble
        ST.TBool -> pure SBool
        ST.TString -> throwErrorMsg invalidStructFieldType
        ST.TVector _ -> throwErrorMsg invalidStructFieldType
        ST.TRef typeRef ->
          findDecl currentNamespace symbolTables typeRef >>= \case
            MatchE enumNamespace enum ->
              pure (SEnum (TypeRef enumNamespace (getIdent enum)) (enumType enum))
            MatchS nestedNamespace nestedStruct -> do
              validNestedStruct <- validateStruct symbolTables nestedNamespace nestedStruct
              pure $ SStruct (nestedNamespace, validNestedStruct)
            _ -> throwErrorMsg invalidStructFieldType

    -- 'deprecated', 'required' and 'id' make no sense on struct fields.
    checkUnsupportedAttributes :: ST.StructField -> m ()
    checkUnsupportedAttributes structField = do
      when (hasAttribute deprecatedAttr (ST.structFieldMetadata structField)) $
        throwErrorMsg "can't deprecate fields in a struct"
      when (hasAttribute requiredAttr (ST.structFieldMetadata structField)) $
        throwErrorMsg "struct fields are already required, the 'required' attribute is redundant"
      when (hasAttribute idAttr (ST.structFieldMetadata structField)) $
        throwErrorMsg "struct fields cannot be reordered using the 'id' attribute"

    getForceAlignAttr :: m (Maybe Integer)
    getForceAlignAttr = findIntAttr forceAlignAttr (ST.structMetadata struct)

    -- force_align must be a power of two between the natural alignment and 16.
    validateForceAlign :: Alignment -> Integer -> m Alignment
    validateForceAlign naturalAlignment forceAlign =
      if isPowerOfTwo forceAlign
        && inRange (fromIntegral @Alignment @Integer naturalAlignment, 16) forceAlign
        then pure (fromIntegral @Integer @Alignment forceAlign)
        else throwErrorMsg $
          "force_align must be a power of two integer ranging from the struct's natural alignment (in this case, "
            <> display naturalAlignment
            <> ") to 16"

    checkDuplicateFields :: m ()
    checkDuplicateFields =
      checkDuplicateIdentifiers
        (ST.structFields struct)
-- | Alignment of a struct field. A primitive is aligned to its own size, an
-- enum to its underlying type's size, and a nested struct to that struct's
-- alignment.
structFieldAlignment :: UnpaddedStructField -> Alignment
structFieldAlignment = alignmentOf . unpaddedStructFieldType
  where
    alignmentOf :: StructFieldType -> Alignment
    alignmentOf = \case
      SInt8 -> int8Size
      SInt16 -> int16Size
      SInt32 -> int32Size
      SInt64 -> int64Size
      SWord8 -> word8Size
      SWord16 -> word16Size
      SWord32 -> word32Size
      SWord64 -> word64Size
      SFloat -> floatSize
      SDouble -> doubleSize
      SBool -> boolSize
      SEnum _ et -> enumAlignment et
      SStruct (_, nested) -> structAlignment nested
-- | An enum is aligned to the size of its underlying integer type.
enumAlignment :: EnumType -> Alignment
enumAlignment e = Alignment (enumSize e)
-- | Size in bytes of an enum's underlying integer type.
enumSize :: EnumType -> Word8
enumSize = \case
  EInt8 -> int8Size
  EInt16 -> int16Size
  EInt32 -> int32Size
  EInt64 -> int64Size
  EWord8 -> word8Size
  EWord16 -> word16Size
  EWord32 -> word32Size
  EWord64 -> word64Size
-- | Size in bytes of a struct field's value, excluding any padding after it.
structFieldTypeSize :: StructFieldType -> InlineSize
structFieldTypeSize = \case
  SInt8 -> int8Size
  SInt16 -> int16Size
  SInt32 -> int32Size
  SInt64 -> int64Size
  SWord8 -> word8Size
  SWord16 -> word16Size
  SWord32 -> word32Size
  SWord64 -> word64Size
  SFloat -> floatSize
  SDouble -> doubleSize
  SBool -> boolSize
  SEnum _ et -> fromIntegral @Word8 @InlineSize (enumSize et)
  SStruct (_, nested) -> structSize nested
-- | Fail if any identifier occurs more than once. The error lists every
-- duplicated identifier, in ascending order.
checkDuplicateIdentifiers :: (MonadValidation m, Foldable f, Functor f, HasIdent a) => f a -> m ()
checkDuplicateIdentifiers xs =
  case Map.keys (Map.filter (> 1) counts) of
    [] -> pure ()
    dups -> throwErrorMsg $ display dups <> " declared more than once"
  where
    counts :: Map Ident Int
    counts = Map.fromListWith (+) [ (ident, 1) | ident <- Foldable.toList (getIdent <$> xs) ]
-- | Fail on the first attribute (in key order) that is neither user-declared
-- nor built in.
checkUndeclaredAttributes :: (MonadValidation m, HasMetadata a) => a -> m ()
checkUndeclaredAttributes a = do
  declared <- getDeclaredAttributes
  let usedAttrs = Map.keys (ST.unMetadata (getMetadata a))
      undeclared = filter (\attr -> coerce attr `Set.notMember` declared) usedAttrs
  case undeclared of
    [] -> pure ()
    attr : _ -> throwErrorMsg $ "user defined attributes must be declared before use: " <> display attr
-- | Whether the metadata contains an attribute with this name, with or
-- without a value.
hasAttribute :: Text -> ST.Metadata -> Bool
hasAttribute name = Map.member name . ST.unMetadata
-- | Look up attribute @name@ and require it to carry an integer.
--
-- Returns 'Nothing' when the attribute is absent. An integer value, or a
-- string value that parses as an integer, is returned. A missing or
-- unparsable value is an error.
findIntAttr :: MonadValidation m => Text -> ST.Metadata -> m (Maybe Integer)
findIntAttr name (ST.Metadata attrs) =
  maybe (pure Nothing) fromValue (Map.lookup name attrs)
  where
    fromValue attrValue =
      case attrValue >>= asInteger of
        Just i -> pure (Just i)
        Nothing -> err

    asInteger (ST.AttrI i) = Just i
    asInteger (ST.AttrS t) = readMaybe @Integer (T.unpack t)

    err =
      throwErrorMsg $
        "expected attribute '"
          <> display name
          <> "' to have an integer value, e.g. '"
          <> display name
          <> ": 123'"
-- | Look up attribute @name@ and require it to carry a string value.
--
-- Returns 'Nothing' when the attribute is absent. It is an error when the
-- attribute is present without a value or with a non-string value.
--
-- Like 'findIntAttr', this is polymorphic in the validation monad, so it also
-- works in code running under 'StateT' (e.g. struct validation). At
-- 'Validation' it behaves exactly as before.
findStringAttr :: MonadValidation m => Text -> ST.Metadata -> m (Maybe Text)
findStringAttr name (ST.Metadata attrs) =
  case Map.lookup name attrs of
    Nothing -> pure Nothing
    Just (Just (ST.AttrS s)) -> pure (Just s)
    Just _ ->
      throwErrorMsg $
        "expected attribute '"
          <> display name
          <> "' to have a string value, e.g. '"
          <> display name
          <> ": \"abc\"'"
-- | Whether @n@ is a strictly positive power of two.
--
-- The sign check matters for fixed-width signed types. In two's complement,
-- 'minBound' has exactly one bit set, so the bit trick alone would report
-- e.g. @(-128 :: Int8)@ as a power of two. ('Eq' is a superclass of 'Bits',
-- so the constraints are unchanged.)
isPowerOfTwo :: (Num a, Bits a) => a -> Bool
isPowerOfTwo n = signum n == 1 && n .&. (n - 1) == 0
-- | Round @x@ up to the nearest multiple of @y@. Values that are already
-- multiples are returned unchanged.
roundUpToNearestMultipleOf :: Integral n => n -> n -> n
roundUpToNearestMultipleOf x y
  | r == 0 = x
  | otherwise = x + (y - r)
  where
    r = x `rem` y