-- | Data null-padded to a given length.

{-# LANGUAGE OverloadedStrings #-}

module Binrep.Type.NullPadded where

import Binrep
import Bytezap.Poke qualified as BZ
import Bytezap.Struct qualified as BZ.Struct
import FlatParse.Basic qualified as FP
import Raehik.Compat.FlatParse.Basic.WithLength qualified as FP
import Control.Monad.Combinators ( skipCount )

import Binrep.Util ( tshow )

import Refined
import Refined.Unsafe

import GHC.TypeNats
import Util.TypeNats ( natValInt )

import Data.Typeable ( typeRep )

import Bytezap.Parser.Struct qualified as BZG
import GHC.Exts ( Int(I#) )

-- | Refinement predicate tag: the wrapped value's serialized length (as
--   reported by 'blen') is at most @n@ bytes.
data NullPad (n :: Natural)

{- | A type which is to be null-padded to a given total length.

Given some @ra :: 'NullPadded' n a@, it is guaranteed that

@
'blen' ('unrefine' ra) '<=' 'natValInt' \@n
@

thus

@
'natValInt' \@n '-' 'blen' ('unrefine' ra) '>=' 0
@

That is, the serialized stored data will not be longer than the total length.

Note that the invariant is about the /wrapped/ value. 'blen' applied to the
refined value itself uses the 'CBLen'-based instance below.
-}
type NullPadded n a = Refined (NullPad n) a

-- | The constant serialized length of a null-padded value is @n@.
instance IsCBLen (NullPadded n a) where
    type CBLen (NullPadded n a) = n

-- | Byte length of a null-padded value, derived from its 'CBLen'.
deriving via ViaCBLen (NullPadded n a)
    instance KnownNat n => BLen (NullPadded n a)

-- | Assert that term will fit: succeed when the value's 'blen' does not
--   exceed @n@, otherwise report both lengths.
instance (BLen a, KnownNat n) => Predicate (NullPad n) a where
    validate proxy x
      | actual > limit = throwRefineOtherException (typeRep proxy) msg
      | otherwise      = success
      where
        limit, actual :: Int
        limit  = natValInt @n
        actual = blen x
        msg    = "too long: " <> tshow actual <> " > " <> tshow limit

-- | Struct serializer: the payload's poke (occupying 'blen' bytes) followed
--   by enough @0x00@ bytes to reach @n@ bytes in total.
instance (BLen a, KnownNat n, Put a) => PutC (NullPadded n a) where
    putC refined =
        BZ.Struct.sequencePokes payloadPoke payloadLen paddingPoke
      where
        payload     = unrefine refined
        payloadLen  = blen payload
        payloadPoke = BZ.toStructPoke (put payload)
        -- The 'NullPad' refinement guarantees payloadLen <= n, so the
        -- padding count is never negative.
        paddingPoke = BZ.Struct.replicateByte (natValInt @n - payloadLen) 0x00

-- | Serializer: write the payload, then append @0x00@ bytes up to a total of
--   @n@ bytes.
instance (BLen a, KnownNat n, Put a) => Put (NullPadded n a) where
    put refined =
        let payload = unrefine refined
            -- Non-negative: the 'NullPad' refinement ensures blen payload <= n.
            padding = natValInt @n - blen payload
        in  put payload <> BZ.replicateByte padding 0x00

-- | Run a @Getter a@ isolated to @n@ bytes.
--
-- The flatparse 'get' for @a@ is handed to 'fpToBz' together with the byte
-- count @n@. Its result is wrapped without re-checking the refinement.
instance (KnownNat n, Get a) => GetC (NullPadded n a) where
    getC = case natValInt @n of
      I# n# -> fpToBz get n# $ \parsed _ ->
        -- TODO: consume/verify the trailing null bytes. Unlike the 'Get'
        -- instance, nothing here checks that the padding is 0x00.
        BZG.constParse (reallyUnsafeRefine parsed)

-- | Parser: parse the payload, require that it fit within @n@ bytes, then
--   require every remaining byte up to @n@ to be @0x00@.
--
-- Fails with an 'EFailNamed' error reporting both lengths when the payload
-- consumed more than @n@ bytes.
instance (Get a, KnownNat n) => Get (NullPadded n a) where
    get = do
        (a, len) <- FP.parseWithLength get
        let total      = natValInt @n
            paddingLen = total - len
        if   paddingLen < 0
        then eBase $ EFailNamed $
                 "NullPadded: too long: " <> show len <> " > " <> show total
        else do
            skipCount paddingLen (FP.word8 0x00)
            -- len <= total was just checked, which is the 'NullPad'
            -- invariant (assuming 'blen' agrees with the parsed length).
            pure $ reallyUnsafeRefine a