{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeFamilies #-}
module Capnp.Message (
hPutMsg
, hGetMsg
, putMsg
, getMsg
, readMessage
, writeMessage
, maxSegmentSize
, maxSegments
, maxCaps
, encode
, decode
, Message(..)
, ConstMsg
, empty
, getSegment
, getWord
, getCap
, getCapTable
, MutMsg
, newMessage
, alloc
, allocInSeg
, newSegment
, setSegment
, setWord
, setCap
, appendCap
, WriteCtx
, Client
, nullClient
, withCapTable
) where
import {-# SOURCE #-} Capnp.Rpc.Untyped (Client, nullClient)
import Prelude hiding (read)
import Data.Bits (shiftL)
import Control.Monad (void, when, (>=>))
import Control.Monad.Catch (MonadThrow(..))
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.State (evalStateT, get, put)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Writer (execWriterT, tell)
import Data.Bytes.Get (getWord32le, runGetS)
import Data.ByteString.Internal (ByteString(..))
import Data.Maybe (fromJust)
import Data.Primitive (MutVar, newMutVar, readMutVar, writeMutVar)
import Data.Word (Word32, Word64)
import System.Endian (fromLE64, toLE64)
import System.IO (Handle, stdin, stdout)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BB
import qualified Data.Vector as V
import qualified Data.Vector.Generic.Mutable as GMV
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector.Storable as SV
import qualified Data.Vector.Storable.Mutable as SMV
import Capnp.Address (WordAddr(..))
import Capnp.Bits (WordCount(..), hi, lo)
import Capnp.TraversalLimit (LimitT, MonadLimit(invoice), evalLimitT)
import Data.Mutable (Mutable(..))
import Internal.AppendVec (AppendVec)
import qualified Capnp.Errors as E
import qualified Internal.AppendVec as AppendVec
-- | Upper bound on the size of a single segment, in 64-bit words
-- (2^28 words, i.e. 2 GiB). Used as the growth limit by 'grow' and as the
-- cap on new-segment size hints in 'alloc'.
maxSegmentSize :: Int
maxSegmentSize = 2 ^ (28 :: Int)
-- | Upper bound on the number of segments in a message; 'newSegment' passes
-- it as the growth limit for the segment table.
maxSegments :: Int
maxSegments = 2 ^ (10 :: Int)
-- | Upper bound on the number of entries in a message's capability table;
-- 'appendCap' passes it as the growth limit for that table.
maxCaps :: Int
maxCaps = 2 ^ (9 :: Int)
-- | Operations shared by every message representation. @msg@ is the message
-- type and @m@ is the monad its operations run in.
class Monad m => Message m msg where
    -- | The type of one segment of a @msg@.
    data Segment msg

    -- Message-level operations.

    -- | The number of segments in the message.
    numSegs :: msg -> m Int
    -- | The number of entries in the message's capability table.
    numCaps :: msg -> m Int
    -- | Fetch a segment by index. Callers are expected to validate the
    -- index first; 'getSegment' is the checked wrapper.
    internalGetSeg :: msg -> Int -> m (Segment msg)
    -- | Fetch a capability by index. Callers are expected to validate the
    -- index first; 'getCap' is the checked wrapper.
    internalGetCap :: msg -> Int -> m Client

    -- Segment-level operations.

    -- | The length of a segment, in 64-bit words.
    numWords :: Segment msg -> m WordCount
    -- | @'slice' start len seg@ is the @len@-word sub-segment of @seg@
    -- beginning at word @start@.
    slice :: WordCount -> WordCount -> Segment msg -> m (Segment msg)
    -- | Read the 64-bit word at the given index of a segment.
    read :: Segment msg -> WordCount -> m Word64
    -- | Build a segment from raw bytes.
    fromByteString :: ByteString -> m (Segment msg)
    -- | Convert a segment to raw bytes.
    toByteString :: Segment msg -> m ByteString
-- | @'getSegment' msg i@ returns segment @i@ of @msg@. Throws
-- 'E.BoundsError' if @i@ is not a valid segment index.
getSegment :: (MonadThrow m, Message m msg) => msg -> Int -> m (Segment msg)
getSegment msg i =
    numSegs msg >>= checkIndex i >> internalGetSeg msg i
-- | @'withCapTable' caps msg@ is @msg@ with its capability table replaced
-- by @caps@. The segment vector is reused as-is.
withCapTable :: V.Vector Client -> ConstMsg -> ConstMsg
withCapTable newCaps ConstMsg{constSegs} =
    ConstMsg{constSegs, constCaps = newCaps}
-- | The capability table of a 'ConstMsg'.
getCapTable :: ConstMsg -> V.Vector Client
getCapTable ConstMsg{constCaps} = constCaps
-- | @'getCap' msg i@ returns entry @i@ of the message's capability table.
-- An out-of-range index yields 'nullClient' rather than an error.
getCap :: (MonadThrow m, Message m msg) => msg -> Int -> m Client
getCap msg i = do
    capCount <- numCaps msg
    if 0 <= i && i < capCount
        then internalGetCap msg i
        else pure nullClient
-- | @'getWord' msg addr@ reads the 64-bit word at @addr@. Throws
-- 'E.BoundsError' if either the segment index or the word index within
-- that segment is out of range.
getWord :: (MonadThrow m, Message m msg) => msg -> WordAddr -> m Word64
getWord msg WordAt{segIndex = segNum, wordIndex = wordNum} = do
    seg <- getSegment msg segNum
    numWords seg >>= checkIndex wordNum
    read seg wordNum
-- | @'setSegment' msg i seg@ replaces segment @i@ of @msg@ with @seg@.
-- Throws 'E.BoundsError' if @i@ is not an existing segment index.
setSegment :: (WriteCtx m s, MonadThrow m) => MutMsg s -> Int -> Segment (MutMsg s) -> m ()
setSegment msg i seg =
    numSegs msg >>= checkIndex i >> internalSetSeg msg i seg
-- | @'setWord' msg addr val@ stores @val@ at @addr@. Throws 'E.BoundsError'
-- if either the segment index or the word index is out of range.
setWord :: (WriteCtx m s, MonadThrow m) => MutMsg s -> WordAddr -> Word64 -> m ()
setWord msg WordAt{segIndex = segNum, wordIndex = wordNum} val = do
    seg <- getSegment msg segNum
    numWords seg >>= checkIndex wordNum
    write seg wordNum val
-- | @'setCap' msg i cap@ overwrites entry @i@ of the capability table with
-- @cap@. Throws 'E.BoundsError' if @i@ is out of range.
setCap :: (WriteCtx m s, MonadThrow m) => MutMsg s -> Int -> Client -> m ()
setCap msg i cap = do
    numCaps msg >>= checkIndex i
    caps <- readMutVar (mutCaps msg)
    MV.write (AppendVec.getVector caps) i cap
-- | @'appendCap' msg cap@ adds @cap@ to the end of the capability table and
-- returns the index it was stored at. Growth is bounded by 'maxCaps'
-- (the limit passed to 'AppendVec.grow').
appendCap :: WriteCtx m s => MutMsg s -> Client -> m Int
appendCap msg cap = do
    idx <- numCaps msg
    let capsVar = mutCaps msg
    oldTable <- readMutVar capsVar
    newTable <- AppendVec.grow oldTable 1 maxCaps
    -- grow returns a new value, which must be stored back in the MutVar.
    writeMutVar capsVar newTable
    setCap msg idx cap
    pure idx
-- | An immutable message: a vector of segments plus a capability table.
data ConstMsg = ConstMsg
    { constSegs :: V.Vector (Segment ConstMsg)
    -- ^ The message's segments, in order.
    , constCaps :: V.Vector Client
    -- ^ The capability table, indexed by 'getCap'.
    }
    deriving(Eq)
-- | Read-only messages. Every operation is a lookup on immutable vectors
-- (via @pure@ or 'V.indexM'/'SV.indexM'), so it works in any 'Monad'.
instance Monad m => Message m ConstMsg where
    -- A segment is a storable vector of 64-bit words. Words are held in
    -- little-endian byte order; 'read' converts them with 'fromLE64'.
    newtype Segment ConstMsg = ConstSegment { constSegToVec :: SV.Vector Word64 }
        deriving(Eq)

    numSegs ConstMsg{constSegs} = pure $ V.length constSegs
    numCaps ConstMsg{constCaps} = pure $ V.length constCaps
    internalGetSeg ConstMsg{constSegs} i = constSegs `V.indexM` i
    internalGetCap ConstMsg{constCaps} i = constCaps `V.indexM` i

    numWords (ConstSegment vec) = pure $ WordCount $ SV.length vec
    -- SV.slice is an O(1) view into the same buffer; no copy is made.
    slice (WordCount start) (WordCount len) (ConstSegment vec) =
        pure $ ConstSegment (SV.slice start len vec)
    read (ConstSegment vec) i = fromLE64 <$> vec `SV.indexM` fromIntegral i

    -- Reinterprets the bytestring's buffer in place as Word64s; no copy.
    -- NOTE(review): nothing here checks that the buffer is 8-byte aligned
    -- or that its length is a multiple of 8. SV.unsafeCast presumably drops
    -- any trailing partial word -- confirm before relying on it.
    fromByteString (PS fptr offset len) =
        pure $ ConstSegment (SV.unsafeCast $ SV.unsafeFromForeignPtr fptr offset len)
    -- The inverse view: exposes the segment's buffer as a bytestring
    -- without copying.
    toByteString (ConstSegment vec) = pure $ PS fptr offset len where
        (fptr, offset, len) = SV.unsafeToForeignPtr (SV.unsafeCast vec)
-- | 'decode' parses a stream-framed message from a bytestring. The
-- resulting segments are views into the input buffer, not copies.
decode :: MonadThrow m => ByteString -> m ConstMsg
decode = fromByteString >=> decodeSeg
-- | 'encode' serializes a message, in stream framing, to a 'BB.Builder'.
encode :: Monad m => ConstMsg -> m BB.Builder
encode msg = pure (fromJust built)
  where
    -- writeMessage needs a MonadThrow, and Maybe serves. Nothing on the
    -- ConstMsg path throws (numWords and toByteString are plain @pure@),
    -- so the result is always Just.
    built = execWriterT (writeMessage msg emitWord emitSegment)
    emitWord = tell . BB.word32LE
    emitSegment = toByteString >=> tell . BB.byteString
-- | 'decodeSeg' decodes a message from a segment whose contents are the raw,
-- stream-framed bytes of that message. It does the work behind 'decode'.
decodeSeg :: MonadThrow m => Segment ConstMsg -> m ConstMsg
decodeSeg seg = do
    len <- numWords seg
    -- The traversal limit is the blob's length in words. readMessage
    -- invoices each header word and each segment before reading it, so an
    -- oversized header should hit the limit instead of reading past the
    -- end of seg. (This assumes evalLimitT throws once the limit is
    -- exceeded; TODO confirm.)
    flip evalStateT (Nothing, 0) $ evalLimitT len $
        readMessage read32 readSegment
  where
    -- The state is (pending high half of the last word read, index of the
    -- next unread word). Each word yields two 32-bit values: the low half
    -- first, then the high half.
    read32 = do
        (cur, idx) <- get
        case cur of
            Just n -> do
                put (Nothing, idx)
                return n
            Nothing -> do
                word <- lift $ lift $ read seg idx
                put (Just $ hi word, idx + 1)
                return (lo word)
    -- Advance past len words and return a slice over them. Any pending half
    -- is carried over unchanged; readMessage's padding read leaves it as
    -- Nothing by the time segments are read.
    readSegment len = do
        (cur, idx) <- get
        put (cur, idx + len)
        lift $ lift $ slice idx len seg
-- | @'readMessage' read32 readSegment@ reads a stream-framed message. The
-- monadic context is responsible for tracking the read position.
--
-- * @read32@ must read one 32-bit little-endian integer.
-- * @readSegment n@ must read a blob of @n@ 64-bit words.
--
-- The header and every segment are 'invoice'd against the traversal limit
-- before they are read, so the limit bounds the total message size in words.
--
-- Throws 'E.InvalidDataError' if the header declares more than
-- 'maxSegments' segments. Errors raised by @read32@, @readSegment@, or the
-- traversal limit propagate unchanged.
readMessage :: (MonadThrow m, MonadLimit m) => m Word32 -> (WordCount -> m (Segment ConstMsg)) -> m ConstMsg
readMessage read32 readSegment = do
    -- The first header word holds the segment count (minus one) and the
    -- size of the first segment.
    invoice 1
    numSegs' <- read32
    -- Widen before adding one. In Word32, a count field of maxBound would
    -- wrap around to zero segments.
    let numSegs = fromIntegral numSegs' + 1 :: Word64
    when (numSegs > fromIntegral maxSegments) $
        throwM $ E.InvalidDataError "Message has too many segments"
    -- The remaining segment sizes (plus padding) fill numSegs `div` 2
    -- further header words.
    invoice (fromIntegral (numSegs `div` 2))
    segSizes <- V.replicateM (fromIntegral numSegs) read32
    -- The header is padded to a whole word when the segment count is even.
    when (even numSegs) $ void read32
    V.mapM_ (invoice . fromIntegral) segSizes
    constSegs <- V.mapM (readSegment . fromIntegral) segSizes
    pure ConstMsg{constSegs, constCaps = V.empty}
-- | @'writeMessage' msg write32 writeSegment@ writes @msg@ in stream
-- framing. @write32@ must emit one 32-bit little-endian word, and
-- @writeSegment@ must emit a segment's raw contents.
--
-- Output order: segment count minus one, each segment's size in words, a
-- zero padding word when the segment count is even, then the segments.
writeMessage :: MonadThrow m => ConstMsg -> (Word32 -> m ()) -> (Segment ConstMsg -> m ()) -> m ()
writeMessage ConstMsg{constSegs} write32 writeSegment = do
    let segCount = V.length constSegs
    write32 (fromIntegral segCount - 1)
    V.forM_ constSegs (numWords >=> write32 . fromIntegral)
    when (even segCount) $ write32 0
    V.mapM_ writeSegment constSegs
-- | @'hPutMsg' handle msg@ writes @msg@ to @handle@ in stream framing
-- (see 'encode').
hPutMsg :: Handle -> ConstMsg -> IO ()
hPutMsg handle = encode >=> BB.hPutBuilder handle
-- | Write a message to standard output; equivalent to @'hPutMsg' 'stdout'@.
putMsg :: ConstMsg -> IO ()
putMsg msg = hPutMsg stdout msg
-- | @'hGetMsg' handle limit@ reads a stream-framed message from @handle@
-- that is at most @limit@ 64-bit words long, header included.
--
-- Throws 'E.InvalidDataError' if the input ends before the message does.
-- Exceeding @limit@ is reported by the traversal-limit machinery
-- ('evalLimitT').
hGetMsg :: Handle -> WordCount -> IO ConstMsg
hGetMsg handle size =
    evalLimitT size $ readMessage read32 readSegment
  where
    read32 :: LimitT IO Word32
    read32 = lift $ do
        bytes <- BS.hGet handle 4
        case runGetS getWord32le bytes of
            Left _ ->
                -- runGetS fails when fewer than 4 bytes were available,
                -- which means the input ended mid-header.
                throwM $ E.InvalidDataError "Unexpected end of input"
            Right result ->
                pure result
    readSegment :: WordCount -> LimitT IO (Segment ConstMsg)
    readSegment n = lift $ do
        let byteCount = fromIntegral n * 8
        bytes <- BS.hGet handle byteCount
        -- BS.hGet returns a shorter string at end of file. Without this
        -- check, a truncated segment would be silently accepted.
        when (BS.length bytes /= byteCount) $
            throwM $ E.InvalidDataError "Unexpected end of input"
        fromByteString bytes
-- | Read a message of at most the given size (in words) from standard
-- input; equivalent to @'hGetMsg' 'stdin'@.
getMsg :: WordCount -> IO ConstMsg
getMsg limit = hGetMsg stdin limit
-- | A mutable message over the 'PrimMonad' state token @s@. Both tables are
-- held in 'MutVar's. Growing an 'AppendVec' returns a new value, which
-- 'newSegment' and 'appendCap' write back into the corresponding 'MutVar'.
data MutMsg s = MutMsg
    { mutSegs :: MutVar s (AppendVec MV.MVector s (Segment (MutMsg s)))
    -- ^ The segment table.
    , mutCaps :: MutVar s (AppendVec MV.MVector s Client)
    -- ^ The capability table.
    }
    deriving(Eq)
-- | Constraints needed to operate on a @'MutMsg' s@ in monad @m@: @m@ must
-- be a 'PrimMonad' whose state token is @s@, and must support 'MonadThrow'
-- (bounds errors are raised with 'throwM').
type WriteCtx m s = (PrimMonad m, s ~ PrimState m, MonadThrow m)
-- | Mutable messages. The message-level operations read the current table
-- out of its 'MutVar' on every call, so they always see the latest growth.
instance (PrimMonad m, s ~ PrimState m) => Message m (MutMsg s) where
    -- A segment is a growable buffer of 64-bit words, held in little-endian
    -- byte order ('read' converts with 'fromLE64').
    newtype Segment (MutMsg s) = MutSegment (AppendVec SMV.MVector s Word64)

    numWords (MutSegment mseg) = pure $ WordCount $ GMV.length (AppendVec.getVector mseg)
    -- SMV.slice does not copy, so the result shares storage with the
    -- original segment.
    slice (WordCount start) (WordCount len) (MutSegment mseg) =
        pure $ MutSegment $ AppendVec.fromVector $
            SMV.slice start len (AppendVec.getVector mseg)
    read (MutSegment mseg) i = fromLE64 <$> SMV.read (AppendVec.getVector mseg) (fromIntegral i)
    -- First build a zero-copy read-only view, then thaw it. SV.thaw makes a
    -- fresh mutable copy, so the input bytestring is never written through.
    fromByteString bytes = do
        vec <- constSegToVec <$> fromByteString bytes
        MutSegment . AppendVec.fromVector <$> SV.thaw vec
    -- Freeze to a ConstMsg segment (via its Thaw instance) and reuse that
    -- conversion.
    toByteString mseg = do
        seg <- freeze mseg
        toByteString (seg :: Segment ConstMsg)

    numSegs MutMsg{mutSegs} = GMV.length . AppendVec.getVector <$> readMutVar mutSegs
    numCaps MutMsg{mutCaps} = GMV.length . AppendVec.getVector <$> readMutVar mutCaps
    internalGetSeg MutMsg{mutSegs} i = do
        segs <- AppendVec.getVector <$> readMutVar mutSegs
        MV.read segs i
    internalGetCap MutMsg{mutCaps} i = do
        caps <- AppendVec.getVector <$> readMutVar mutCaps
        MV.read caps i
-- | @'internalSetSeg' msg i seg@ stores @seg@ at index @i@ of the segment
-- table without validating @i@ itself. 'setSegment' is the checked wrapper.
internalSetSeg :: WriteCtx m s => MutMsg s -> Int -> Segment (MutMsg s) -> m ()
internalSetSeg msg idx seg = do
    table <- readMutVar (mutSegs msg)
    MV.write (AppendVec.getVector table) idx seg
-- | @'write' seg i val@ stores @val@ at word @i@ of @seg@, converting it to
-- little-endian first. 'setWord' is the bounds-checked alternative.
write :: WriteCtx m s => Segment (MutMsg s) -> WordCount -> Word64 -> m ()
write (MutSegment seg) (WordCount i) =
    SMV.write (AppendVec.getVector seg) i . toLE64
-- | @'grow' seg amount@ extends @seg@ by @amount@ words, passing
-- 'maxSegmentSize' as the size limit. The result is a new segment value.
-- Callers store it back (as 'allocInSeg' does) rather than reusing @seg@.
grow :: WriteCtx m s => Segment (MutMsg s) -> Int -> m (Segment (MutMsg s))
grow (MutSegment vec) amount =
    fmap MutSegment (AppendVec.grow vec amount maxSegmentSize)
-- | @'newSegment' msg sizeHint@ appends an initially empty segment backed by
-- a buffer of @sizeHint@ words. Returns the new segment's index together
-- with the segment itself.
newSegment :: WriteCtx m s => MutMsg s -> Int -> m (Int, Segment (MutMsg s))
newSegment msg sizeHint = do
    -- The new segment goes at the end, so its index is the current count.
    newIndex <- numSegs msg
    -- Make room for one more table entry; growth is bounded by maxSegments.
    let segsVar = mutSegs msg
    oldTable <- readMutVar segsVar
    newTable <- AppendVec.grow oldTable 1 maxSegments
    writeMutVar segsVar newTable
    seg <- MutSegment . AppendVec.makeEmpty <$> SMV.new sizeHint
    setSegment msg newIndex seg
    pure (newIndex, seg)
-- | Like 'alloc', but the caller chooses the segment. @'allocInSeg' msg i n@
-- grows segment @i@ by @n@ words and returns the address of the first word
-- added, which is the segment's length before growing.
allocInSeg :: WriteCtx m s => MutMsg s -> Int -> WordCount -> m WordAddr
allocInSeg msg segNum (WordCount size) = do
    seg@(MutSegment vec) <- getSegment msg segNum
    let start = GMV.length (AppendVec.getVector vec)
    grown <- grow seg size
    setSegment msg segNum grown
    pure WordAt{segIndex = segNum, wordIndex = WordCount start}
-- | @'alloc' msg size@ reserves @size@ words in @msg@ and returns the
-- address of the first one.
alloc :: WriteCtx m s => MutMsg s -> WordCount -> m WordAddr
alloc msg size@(WordCount sizeInt) = do
    -- First try to extend the most recently created segment in place.
    lastIndex <- pred <$> numSegs msg
    MutSegment lastVec <- getSegment msg lastIndex
    if AppendVec.canGrowWithoutCopy lastVec sizeInt
        then allocInSeg msg lastIndex size
        else do
            -- Otherwise open a new segment whose capacity equals the total
            -- capacity of all existing segments, so the overall capacity
            -- doubles. It is at least sizeInt and at most maxSegmentSize.
            table <- readMutVar (mutSegs msg)
            frozen <- V.freeze (AppendVec.getVector table)
            let capacityOf (MutSegment v) = AppendVec.getCapacity v
                totalAllocation = V.sum (V.map capacityOf frozen)
                hint = min maxSegmentSize (max totalAllocation sizeInt)
            (newIndex, _) <- newSegment msg hint
            allocInSeg msg newIndex size
-- | A minimal message with an empty capability table and a single segment
-- holding one zero word. In Cap'n Proto, an all-zero pointer is null, so
-- this is a null root pointer.
empty :: ConstMsg
empty = ConstMsg
    { constCaps = V.empty
    , constSegs = V.singleton (ConstSegment (SV.singleton 0))
    }
-- | @'newMessage' sizeHint@ allocates a new, empty mutable message.
--
-- The first segment is created with a capacity of @sizeHint@ words (32 when
-- 'Nothing'), clamped to at least one word. The first word of segment 0 is
-- then allocated for the root pointer.
newMessage :: WriteCtx m s => Maybe WordCount -> m (MutMsg s)
newMessage Nothing = newMessage (Just 32)
newMessage (Just (WordCount sizeHint)) = do
    mutSegs <- MV.new 1 >>= newMutVar . AppendVec.makeEmpty
    mutCaps <- MV.new 0 >>= newMutVar . AppendVec.makeEmpty
    let msg = MutMsg{mutSegs,mutCaps}
    -- A negative hint would make SMV.new fail. A zero hint would leave no
    -- room for the root pointer in segment 0.
    _ <- newSegment msg (max 1 sizeHint)
    -- Allocate explicitly in segment 0. The root pointer must sit at word 0
    -- of the first segment, and plain 'alloc' may open a new segment
    -- instead of growing this one.
    _ <- allocInSeg msg 0 1
    pure msg
-- | Converts between immutable and mutable segments. Each method delegates
-- to the same-named method on the underlying 'AppendVec', via 'thawSeg' or
-- 'freezeSeg'.
instance Thaw (Segment ConstMsg) where
    type Mutable s (Segment ConstMsg) = Segment (MutMsg s)

    thaw seg = thawSeg thaw seg
    freeze seg = freezeSeg freeze seg
    unsafeThaw seg = thawSeg unsafeThaw seg
    unsafeFreeze seg = freezeSeg unsafeFreeze seg
-- | Lift a thaw operation on the underlying frozen 'AppendVec' to one on
-- segments.
thawSeg
    :: (PrimMonad m, s ~ PrimState m)
    => (AppendVec.FrozenAppendVec SV.Vector Word64 -> m (AppendVec SMV.MVector s Word64))
    -> Segment ConstMsg
    -> m (Segment (MutMsg s))
thawSeg thawVec (ConstSegment vec) =
    fmap MutSegment (thawVec (AppendVec.FrozenAppendVec vec))
-- | Lift a freeze operation on the underlying mutable 'AppendVec' to one on
-- segments.
freezeSeg
    :: (PrimMonad m, s ~ PrimState m)
    => (AppendVec SMV.MVector s Word64 -> m (AppendVec.FrozenAppendVec SV.Vector Word64))
    -> Segment (MutMsg s)
    -> m (Segment ConstMsg)
freezeSeg freezeVec (MutSegment mvec) = do
    frozen <- freezeVec mvec
    pure (ConstSegment (AppendVec.getFrozenVector frozen))
-- | Converts between 'ConstMsg' and 'MutMsg'. Segments use the matching
-- 'Thaw' method on segments. The capability table uses the matching
-- 'V.thaw'/'V.freeze' (or unsafe) vector operation.
instance Thaw ConstMsg where
    type Mutable s ConstMsg = MutMsg s

    thaw msg = thawMsg thaw V.thaw msg
    freeze msg = freezeMsg freeze V.freeze msg
    unsafeThaw msg = thawMsg unsafeThaw V.unsafeThaw msg
    unsafeFreeze msg = freezeMsg unsafeFreeze V.unsafeFreeze msg
-- | Build a 'MutMsg' from a 'ConstMsg', using the supplied functions to
-- convert each segment and the capability table.
thawMsg :: (PrimMonad m, s ~ PrimState m)
    => (Segment ConstMsg -> m (Segment (MutMsg s)))
    -> (V.Vector Client -> m (MV.MVector s Client))
    -> ConstMsg
    -> m (MutMsg s)
thawMsg thawSegment thawCaps ConstMsg{constSegs, constCaps} = do
    -- The boxed vector of thawed segments is freshly built here, so
    -- unsafeThaw-ing it cannot alias anything else.
    segVec <- V.unsafeThaw =<< V.mapM thawSegment constSegs
    mutSegs <- newMutVar (AppendVec.fromVector segVec)
    capVec <- thawCaps constCaps
    mutCaps <- newMutVar (AppendVec.fromVector capVec)
    pure MutMsg{mutSegs, mutCaps}
-- | Build a 'ConstMsg' from a 'MutMsg', using the supplied functions to
-- convert each segment and the capability table.
freezeMsg :: (PrimMonad m, s ~ PrimState m)
    => (Segment (MutMsg s) -> m (Segment ConstMsg))
    -> (MV.MVector s Client -> m (V.Vector Client))
    -> MutMsg s
    -> m ConstMsg
freezeMsg freezeSegment freezeCaps msg = do
    segCount <- numSegs msg
    constSegs <- V.generateM segCount $ \i ->
        internalGetSeg msg i >>= freezeSegment
    caps <- readMutVar (mutCaps msg)
    constCaps <- freezeCaps (AppendVec.getVector caps)
    pure ConstMsg{constSegs, constCaps}
-- | @'checkIndex' i len@ succeeds when @0 <= i < len@. Otherwise it throws
-- 'E.BoundsError' carrying @i@ as the index and @len@ as the max index.
checkIndex :: (Integral a, MonadThrow m) => a -> a -> m ()
checkIndex i len
    | 0 <= i && i < len = pure ()
    | otherwise = throwM E.BoundsError
        { E.index = fromIntegral i
        , E.maxIndex = fromIntegral len
        }