{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
module Data.ProtoLens.Message (
Message(..),
Tag(..),
allFields,
FieldDescriptor(..),
fieldDescriptorName,
isRequired,
FieldAccessor(..),
WireDefault(..),
Packing(..),
FieldTypeDescriptor(..),
ScalarField(..),
MessageOrGroup(..),
FieldDefault(..),
MessageEnum(..),
Default(..),
build,
Registry,
register,
lookupRegistered,
SomeMessageType(..),
matchAnyMessage,
AnyMessageDescriptor(..),
maybeLens,
reverseRepeatedFields,
FieldSet,
TaggedValue(..),
discardUnknownFields,
) where
import qualified Data.ByteString as B
import Data.Default.Class
import Data.Int
import qualified Data.Map as Map
import Data.Map (Map)
import Data.Maybe (fromMaybe)
import Data.Proxy (Proxy(..))
import qualified Data.Text as T
import Data.Word
import Lens.Family2 (Lens', over, set)
import Lens.Family2.Unchecked (lens)
import qualified Data.Semigroup as Semigroup
import Data.ProtoLens.Encoding.Wire
( Tag(..)
, TaggedValue(..)
)
-- | The class of protocol buffer message types.
--
-- The 'Default' superclass supplies the starting value used by 'build'.
class Default msg => Message msg where
    -- | The name of the message type. 'matchAnyMessage' compares it against
    -- @"google.protobuf.Any"@, and 'register' uses it as the registry key.
    messageName :: Proxy msg -> T.Text
    -- | The descriptors of this message's fields, keyed by their 'Tag'.
    fieldsByTag :: Map Tag (FieldDescriptor msg)
    -- | The descriptors of this message's fields, keyed by the name stored
    -- in each 'FieldDescriptor'.
    --
    -- The default builds the map from 'allFields'. If two descriptors carry
    -- the same name, 'Map.fromList' keeps the later one, i.e. the one with
    -- the higher tag.
    fieldsByTextFormatName :: Map String (FieldDescriptor msg)
    fieldsByTextFormatName =
        Map.fromList [(n, f) | f@(FieldDescriptor n _ _) <- allFields]
    -- | A lens onto the message's 'FieldSet' of unknown fields.
    -- NOTE(review): presumably fields seen on the wire that match no
    -- descriptor; confirm against the decoder.
    unknownFields :: Lens' msg FieldSet
-- | Every field descriptor of a message, in ascending 'Tag' order
-- (the key order of 'fieldsByTag').
allFields :: Message msg => [FieldDescriptor msg]
allFields = snd <$> Map.toAscList fieldsByTag
-- | A list of raw 'TaggedValue's; this is the type viewed by 'unknownFields'.
type FieldSet = [TaggedValue]
-- | Describes one field of a message of type @msg@.
--
-- The field's own value type (@value@) does not appear in the result type,
-- so it is existentially hidden. A descriptor bundles three things: the
-- field's name, a description of the value's type, and an accessor for
-- reaching the value inside @msg@.
data FieldDescriptor msg where
    FieldDescriptor :: String
                    -> FieldTypeDescriptor value -> FieldAccessor msg value
                    -> FieldDescriptor msg
-- | The name stored as the first component of a 'FieldDescriptor'.
fieldDescriptorName :: FieldDescriptor msg -> String
fieldDescriptorName descriptor = case descriptor of
    FieldDescriptor fieldName _ _ -> fieldName
-- | 'True' exactly when the field is a 'PlainField' whose wire default is
-- 'Required'. All other accessors, and plain fields with an 'Optional'
-- wire default, give 'False'.
isRequired :: FieldDescriptor msg -> Bool
isRequired (FieldDescriptor _ _ accessor) = case accessor of
    PlainField Required _ -> True
    _ -> False
-- | How a field's value(s) of type @value@ are stored inside a message of
-- type @msg@.
data FieldAccessor msg value where
    -- | A single value reached through a plain lens, together with its
    -- 'WireDefault'.
    PlainField :: WireDefault value -> Lens' msg value
               -> FieldAccessor msg value
    -- | A value reached through a lens onto @'Maybe' value@.
    OptionalField :: Lens' msg (Maybe value) -> FieldAccessor msg value
    -- | A list of values, together with the field's 'Packing'.
    RepeatedField :: Packing -> Lens' msg [value] -> FieldAccessor msg value
    -- | A map-valued field. The accessor is indexed by the entry message
    -- type @entry@, which provides lenses onto its key and value; @msg@
    -- itself stores a @'Map' key value@. Both @key@ and @value@ are
    -- existential here, since neither appears in the result type.
    MapField :: (Ord key, Message entry) => Lens' entry key -> Lens' entry value
             -> Lens' msg (Map key value) -> FieldAccessor msg entry
-- | The default-value behavior carried by a 'PlainField'.
data WireDefault value where
    -- | A required field ('isRequired' reports 'True' for it).
    Required :: WireDefault value
    -- | A field with a 'fieldDefault'. Matching this constructor brings
    -- 'FieldDefault' and 'Eq' for @value@ into scope.
    -- NOTE(review): presumably so encoders can compare against the
    -- default; confirm.
    Optional :: (FieldDefault value, Eq value) => WireDefault value
-- | A default value for a field's type.
class FieldDefault value where
    fieldDefault :: value

-- The instances below use False, zero, or the empty string/bytes.
instance FieldDefault Bool where
    fieldDefault = False

instance FieldDefault Int32 where
    fieldDefault = 0

instance FieldDefault Int64 where
    fieldDefault = 0

instance FieldDefault Word32 where
    fieldDefault = 0

instance FieldDefault Word64 where
    fieldDefault = 0

instance FieldDefault Float where
    fieldDefault = 0

instance FieldDefault Double where
    fieldDefault = 0

instance FieldDefault B.ByteString where
    fieldDefault = B.empty

instance FieldDefault T.Text where
    fieldDefault = T.empty
-- | The packing mode attached to a 'RepeatedField'.
-- NOTE(review): presumably selects protobuf's packed vs. unpacked encoding
-- for repeated scalars; confirm with the encoder.
data Packing = Packed | Unpacked
-- | The type of a field's value: either a nested message (or group), or a
-- scalar described by 'ScalarField'.
data FieldTypeDescriptor value where
    -- | A nested message; matching brings the 'Message' instance into scope.
    MessageField :: Message value => MessageOrGroup -> FieldTypeDescriptor value
    -- | A scalar value.
    ScalarField :: ScalarField value -> FieldTypeDescriptor value

deriving instance Show (FieldTypeDescriptor value)
-- | Whether a 'MessageField' was declared as a message or as a group.
data MessageOrGroup = MessageType | GroupType
    deriving Show
-- | The scalar field kinds, each indexed by the Haskell type it uses.
--
-- Several kinds share one Haskell type. For example, 'Int32Field',
-- 'SInt32Field' and 'SFixed32Field' all carry an 'Int32'.
data ScalarField t where
    -- | An enum; its type must be a 'MessageEnum'.
    EnumField :: MessageEnum value => ScalarField value
    Int32Field :: ScalarField Int32
    Int64Field :: ScalarField Int64
    UInt32Field :: ScalarField Word32
    UInt64Field :: ScalarField Word64
    SInt32Field :: ScalarField Int32
    SInt64Field :: ScalarField Int64
    Fixed32Field :: ScalarField Word32
    Fixed64Field :: ScalarField Word64
    SFixed32Field :: ScalarField Int32
    SFixed64Field :: ScalarField Int64
    FloatField :: ScalarField Float
    DoubleField :: ScalarField Double
    BoolField :: ScalarField Bool
    StringField :: ScalarField T.Text
    BytesField :: ScalarField B.ByteString

deriving instance Show (ScalarField value)
-- | Recognizes a descriptor for the @google.protobuf.Any@ message.
--
-- Returns 'Just' only when all of the following hold:
--
-- * the descriptor is a 'MessageField' whose 'messageName' is
--   @"google.protobuf.Any"@;
-- * tag 1 is a 'StringField' 'PlainField' with an 'Optional' wire default;
-- * tag 2 is a 'BytesField' 'PlainField' with an 'Optional' wire default.
--
-- The result holds the lenses of those two fields. Any other descriptor
-- yields 'Nothing'.
matchAnyMessage :: forall value . FieldTypeDescriptor value -> Maybe (AnyMessageDescriptor value)
matchAnyMessage (ScalarField _) = Nothing
matchAnyMessage (MessageField _)
    | messageName (Proxy @value) /= "google.protobuf.Any" = Nothing
    | otherwise =
        -- The 'Message value' instance is only in scope inside this
        -- equation, so the field map is looked up here.
        case (Map.lookup 1 (fieldsByTag @value), Map.lookup 2 (fieldsByTag @value)) of
            ( Just (FieldDescriptor _ (ScalarField StringField) (PlainField Optional urlLens))
              , Just (FieldDescriptor _ (ScalarField BytesField) (PlainField Optional bytesLens))
              ) -> Just (AnyMessageDescriptor urlLens bytesLens)
            _ -> Nothing
-- | Lenses onto the two fields of a @google.protobuf.Any@ message, as
-- returned by 'matchAnyMessage'.
data AnyMessageDescriptor msg
    = AnyMessageDescriptor
        { anyTypeUrlLens :: Lens' msg T.Text
          -- ^ The 'T.Text' field that 'matchAnyMessage' finds at tag 1.
        , anyValueLens :: Lens' msg B.ByteString
          -- ^ The 'B.ByteString' field that 'matchAnyMessage' finds at tag 2.
        }
-- | Extra conversions for enum types, on top of 'Enum' and 'Bounded'.
class (Enum a, Bounded a) => MessageEnum a where
    -- | Converts an 'Int' to an enum value, or 'Nothing'.
    -- NOTE(review): presumably 'Nothing' for unrecognized numbers; confirm.
    maybeToEnum :: Int -> Maybe a
    -- | Renders an enum value as a 'String'.
    -- NOTE(review): presumably the declared constant name; confirm.
    showEnum :: a -> String
    -- | Parses a 'String' into an enum value, or 'Nothing'.
    readEnum :: String -> Maybe a
-- | Builds a value by applying a modification to 'def'.
--
-- For a hypothetical lens @someField@, @build (set someField x)@ equals
-- @set someField x def@.
build :: Default a => (a -> a) -> a
build modify = modify def
-- | A lens that views a 'Maybe' as a plain value.
--
-- Viewing yields the given fallback when the 'Maybe' is 'Nothing'. Setting
-- always stores 'Just' the new value and discards the previous contents.
maybeLens :: b -> Lens' (Maybe b) b
maybeLens fallback = lens getter setter
  where
    getter = fromMaybe fallback
    setter _ new = Just new
-- | Reverses the list behind every 'RepeatedField' in the given
-- descriptors. All other fields of the message are left untouched.
--
-- NOTE(review): presumably called after a decoder that prepends repeated
-- elements, to restore wire order; confirm against callers.
reverseRepeatedFields :: forall k msg . Map k (FieldDescriptor msg) -> msg -> msg
reverseRepeatedFields descriptors message = Map.foldl' reverseOne message descriptors
  where
    -- Reverses one repeated field; any other accessor is a no-op.
    reverseOne :: msg -> FieldDescriptor msg -> msg
    reverseOne acc (FieldDescriptor _ _ accessor) = case accessor of
        RepeatedField _ elems -> over elems reverse acc
        _ -> acc
-- | A collection of message types keyed by 'messageName' (see 'register').
--
-- Registries combine with '<>' / 'mconcat'. The instances are those of the
-- underlying 'Map.Map', whose union is left-biased: when both sides hold
-- the same name, the left-hand entry wins.
newtype Registry = Registry (Map.Map T.Text SomeMessageType)
    deriving (Semigroup.Semigroup, Monoid)
-- | A single-entry 'Registry' that maps the message's 'messageName' to its
-- type.
register :: forall msg . Message msg => Proxy msg -> Registry
register proxy = Registry (Map.singleton key (SomeMessageType proxy))
  where
    -- Built from a fresh proxy, so the caller's argument is never forced.
    key = messageName (Proxy @msg)
-- | Looks up a message type by name.
--
-- Everything up to and including the last @'/'@ is ignored. For example,
-- @"type.googleapis.com/google.protobuf.Any"@ is looked up as
-- @"google.protobuf.Any"@. A name with no @'/'@ is used unchanged.
lookupRegistered :: T.Text -> Registry -> Maybe SomeMessageType
lookupRegistered typeUrl (Registry table) =
    let (_, bareName) = T.breakOnEnd "/" typeUrl
    in Map.lookup bareName table
-- | A message type whose concrete type is hidden. Pattern matching on the
-- constructor recovers the 'Proxy' and the 'Message' instance.
data SomeMessageType where
    SomeMessageType :: Message msg => Proxy msg -> SomeMessageType
-- | Replaces the message's 'unknownFields' with the empty list.
--
-- Only the message's own field set is cleared; nested messages are not
-- visited.
discardUnknownFields :: Message msg => msg -> msg
discardUnknownFields = over unknownFields (const [])