{-# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Finitary.PackWords
(
PackWords(.., Packed)
, intoWords, outOfWords
)
where
import Data.Kind (Type)
import Data.Hashable (Hashable(..), hashByteArrayWithSalt)
import Foreign.Storable (Storable(..))
import GHC.Exts
import GHC.IO
import GHC.Natural (Natural(..))
import GHC.TypeNats
import qualified Data.Binary as Bin
import CoercibleUtils (op, over, over2)
import Control.DeepSeq (NFData(..))
import Data.Finitary (Finitary(..))
import Data.Finite.Internal (Finite(..), getFinite)
import GHC.TypeLits.Extra
import Control.Monad.Primitive (PrimMonad(primitive))
import Data.Primitive.ByteArray (ByteArray(..), MutableByteArray(..))
import qualified Data.Vector.Unboxed.Base as VU
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Primitive as VP
import qualified Data.Vector.Generic.Mutable as VGM
import qualified Data.Vector.Primitive.Mutable as VPM
import Data.Vector.Binary ()
import Data.Vector.Instances ()
#ifdef BIGNUM
import GHC.Num.BigNat (BigNat(..), bigNatCompare, bigNatSize#)
import GHC.Num.Integer (integerToNaturalClamp, integerFromBigNat#)
#else
import GHC.Integer.GMP.Internals
( BigNat(..), bigNatToInteger, compareBigNat, sizeofBigNat# )
#endif
-- | A value of type @a@ stored as its 'Finitary' index, encoded as an array of
-- machine words in a 'ByteArray' (least significant word first, as written by
-- 'intoWords').
--
-- The width of the array is derived from @Cardinality a@ (see 'wordLength'),
-- which is presumably why the role of @a@ is nominal: coercing to a type of a
-- different cardinality would pair the array with the wrong width.
newtype PackWords (a :: Type) = PackedWords ByteArray
  deriving (Eq, Show)
type role PackWords nominal
{-# COMPLETE Packed #-}
-- | Bidirectional view of a 'PackWords' as the value it encodes.
--
-- Matching decodes via 'unpackWords'; using it as a constructor encodes via
-- 'packWords'. The COMPLETE pragma lets a single 'Packed' match count as
-- exhaustive.
pattern Packed :: forall (a :: Type) .
  (Finitary a, 1 <= Cardinality a) =>
  a -> PackWords a
pattern Packed x <- (unpackWords -> x)
  where Packed x = packWords x
-- Orders packed values by reading each word array as a big natural number and
-- comparing those numbers with the big-number library's comparison.
instance (Finitary a, 1 <= Cardinality a) => Ord (PackWords a) where
  compare (PackedWords lhs) (PackedWords rhs) = compareLimbs lhs rhs
    where
      -- Big-number comparison of two raw limb arrays; the primitive used
      -- depends on which big-number backend is selected at build time.
      compareLimbs :: ByteArray -> ByteArray -> Ordering
#ifdef BIGNUM
      compareLimbs (ByteArray l) (ByteArray r) = bigNatCompare l r
#else
      compareLimbs (ByteArray l) (ByteArray r) = compareBigNat (BN# l) (BN# r)
#endif
-- Serialised form: a primitive vector holding exactly 'wordLength' words.
instance (Finitary a, 1 <= Cardinality a) => Bin.Binary (PackWords a) where
  {-# INLINE put #-}
  put = Bin.put . VP.Vector @Word 0 (wordLength @a) . op PackedWords
  -- Decoding validates the payload before trusting it. The other instances
  -- in this module (Hashable, Storable, the unboxed vector instances) read
  -- or copy exactly bytesPerWord * wordLength @a bytes from the array, so an
  -- array of the wrong size (corrupt or hostile input) would otherwise lead
  -- to out-of-bounds memory access. When the decoded vector does not own its
  -- whole backing array (non-zero offset or oversized storage), the words are
  -- copied into a compact array, because the big-number comparison used by
  -- 'Ord' looks at the array's full size.
  {-# INLINE get #-}
  get = do
    decoded <- Bin.get
    case (decoded :: VP.Vector Word) of
      VP.Vector off len arr@(ByteArray ba)
        | len /= expectedWords ->
            fail "Data.Finitary.PackWords: serialised word count does not match the type"
        | off == 0 && I# (sizeofByteArray# ba) == expectedBytes ->
            pure (PackedWords arr)
        | otherwise ->
            case VP.force decoded of
              VP.Vector _ _ compact -> pure (PackedWords compact)
    where
      -- Number of words, and bytes, a valid payload for @a@ must contain.
      expectedWords :: Int
      expectedWords = wordLength @a
      expectedBytes :: Int
      expectedBytes = bytesPerWord * expectedWords
-- Hashes the raw bytes of the word array, salted by the caller's salt.
instance (Finitary a, 1 <= Cardinality a) => Hashable (PackWords a) where
  {-# INLINE hashWithSalt #-}
  hashWithSalt salt (PackedWords (ByteArray ba)) =
    hashByteArrayWithSalt ba 0 byteCount salt
    where
      -- Exactly the bytes occupied by the 'wordLength' words of one value.
      byteCount :: Int
      byteCount = bytesPerWord * wordLength @a
-- The payload is a single 'ByteArray', so normal-form evaluation is delegated
-- to that array's own 'NFData' instance.
instance NFData (PackWords a) where
  {-# INLINE rnf #-}
  rnf (PackedWords arr) = rnf arr
-- A packed value has exactly as many inhabitants as @a@; conversion to and
-- from indices goes straight through the word-array encoders.
instance (Finitary a, 1 <= Cardinality a) => Finitary (PackWords a) where
  type Cardinality (PackWords a) = Cardinality a
  {-# INLINE fromFinite #-}
  fromFinite idx = PackedWords (intoWords idx)
  {-# INLINE toFinite #-}
  toFinite (PackedWords arr) = outOfWords arr
-- Bounds are delegated to 'start' and 'end' from "Data.Finitary", using this
-- type's own 'Finitary' instance.
instance (Finitary a, 1 <= Cardinality a) => Bounded (PackWords a) where
  {-# INLINE maxBound #-}
  maxBound = end
  {-# INLINE minBound #-}
  minBound = start
-- Foreign-memory layout: the value's words are stored contiguously, exactly
-- as held in the 'ByteArray'.
instance (Finitary a, 1 <= Cardinality a) => Storable (PackWords a) where
  -- One value occupies 'wordLength' machine words.
  {-# INLINABLE sizeOf #-}
  sizeOf _ = wordLength @a * bytesPerWord
  -- Aligned like a single 'Word'.
  {-# INLINABLE alignment #-}
  alignment _ = alignment (undefined :: Word)
  -- Copies nbBytes bytes from the address into a freshly allocated array,
  -- then freezes it without copying (the mutable array is not used again).
  {-# INLINABLE peek #-}
  peek (Ptr addr) =
    IO $ \ s1 ->
      case newByteArray# nbBytes s1 of
        (# s2, mba #) -> case copyAddrToByteArray# addr mba 0# nbBytes s2 of
          s3 -> case unsafeFreezeByteArray# mba s3 of
            (# s4, ba #) -> (# s4, PackedWords (ByteArray ba) #)
    where
      nbBytes :: Int#
      !(I# nbBytes) = bytesPerWord * wordLength @a
  -- Copies nbBytes bytes of the value's array to the address.
  -- NOTE(review): assumes the array holds at least nbBytes bytes.
  {-# INLINE poke #-}
  poke (Ptr addr) (PackedWords (ByteArray ba)) =
    IO $ \ s1 ->
      case copyByteArrayToAddr# ba 0# addr nbBytes s1 of
        s2 -> (# s2, () #)
    where
      nbBytes :: Int#
      !(I# nbBytes) = bytesPerWord * wordLength @a
-- | Mutable unboxed vectors of 'PackWords' are backed by a mutable unboxed
-- vector of 'Word'; the 'VGM.MVector' instance lays each element out as
-- 'wordLength' consecutive words.
newtype instance VU.MVector s (PackWords a) = MV_PackWords (VU.MVector s Word)
-- Each element occupies wordLength @a consecutive 'Word's of the underlying
-- 'Word' vector, so element counts and indices are scaled by that factor.
-- Reads copy an element's words into a fresh 'ByteArray'; writes copy the
-- value's words into the element's slot.
instance (Finitary a, 1 <= Cardinality a) => VGM.MVector VU.MVector (PackWords a) where
  -- Length in elements: word length divided by words per element.
  {-# INLINE basicLength #-}
  basicLength = over MV_PackWords ((`div` wordLength @a) . VGM.basicLength)
  {-# INLINE basicOverlaps #-}
  basicOverlaps = over2 MV_PackWords VGM.basicOverlaps
  -- Element offset/length are converted to word offset/length.
  {-# INLINABLE basicUnsafeSlice #-}
  basicUnsafeSlice i len = over MV_PackWords (VGM.basicUnsafeSlice (i * wordLength @a) (len * wordLength @a))
  {-# INLINABLE basicUnsafeNew #-}
  basicUnsafeNew len = MV_PackWords <$> VGM.basicUnsafeNew (len * wordLength @a)
  {-# INLINE basicInitialize #-}
  basicInitialize = VGM.basicInitialize . op MV_PackWords
  -- Reads element i. The slice offset off counts words, hence the
  -- wordSize *# off term; nbBytes *# i skips the preceding elements.
  {-# INLINABLE basicUnsafeRead #-}
  basicUnsafeRead (MV_PackWords (VU.MV_Word (VPM.MVector (I# off) _ (MutableByteArray full_mba)))) (I# i) =
    primitive $ \ s1 ->
      case newByteArray# nbBytes s1 of
        (# s2, elem_mba #) -> case copyMutableByteArray# full_mba (wordSize *# off +# nbBytes *# i) elem_mba 0# nbBytes s2 of
          s3 -> case unsafeFreezeByteArray# elem_mba s3 of
            (# s4, elem_ba #) -> (# s4, PackedWords (ByteArray elem_ba) #)
    where
      nbBytes, wordSize :: Int#
      !(I# nbBytes) = bytesPerWord * wordLength @a
      !(I# wordSize) = bytesPerWord
  -- Writes element i by copying nbBytes bytes of the value over its slot.
  -- NOTE(review): assumes the value's array holds at least nbBytes bytes.
  {-# INLINABLE basicUnsafeWrite #-}
  basicUnsafeWrite (MV_PackWords (VU.MV_Word (VPM.MVector (I# off) _ (MutableByteArray full_mba)))) (I# i) (PackedWords (ByteArray val_ba)) =
    primitive $ \ s1 -> case copyByteArray# val_ba 0# full_mba (wordSize *# off +# nbBytes *# i) nbBytes s1 of
      s2 -> (# s2, () #)
    where
      nbBytes, wordSize :: Int#
      !(I# nbBytes) = bytesPerWord * wordLength @a
      !(I# wordSize) = bytesPerWord
-- | Immutable unboxed vectors of 'PackWords' are backed by an unboxed vector
-- of 'Word', with each element stored as 'wordLength' consecutive words.
newtype instance VU.Vector (PackWords a) = V_PackWords (VU.Vector Word)
-- Immutable counterpart of the mutable instance: lengths and slices are scaled
-- by wordLength @a, and freeze/thaw simply rewrap the underlying 'Word' vector.
instance (Finitary a, 1 <= Cardinality a) => VG.Vector VU.Vector (PackWords a) where
  {-# INLINE basicLength #-}
  basicLength = over V_PackWords ((`div` wordLength @a) . VG.basicLength)
  {-# INLINE basicUnsafeFreeze #-}
  basicUnsafeFreeze = fmap V_PackWords . VG.basicUnsafeFreeze . op MV_PackWords
  {-# INLINE basicUnsafeThaw #-}
  basicUnsafeThaw = fmap MV_PackWords . VG.basicUnsafeThaw . op V_PackWords
  {-# INLINABLE basicUnsafeSlice #-}
  basicUnsafeSlice i len = over V_PackWords (VG.basicUnsafeSlice (i * wordLength @a) (len * wordLength @a))
  -- Indexes element i by copying its bytes (word offset off, element stride
  -- nbBytes) into a fresh array. NOTE(review): the runRW# expression is
  -- wrapped by pure without being forced, so the copy is deferred until the
  -- element is demanded and the source array stays reachable until then.
  {-# INLINABLE basicUnsafeIndexM #-}
  basicUnsafeIndexM (V_PackWords (VU.V_Word (VP.Vector (I# off) _ (ByteArray full_ba)))) (I# i) =
    pure $ runRW# $ \ s1 ->
      case newByteArray# nbBytes s1 of
        (# s2, elem_mba #) -> case copyByteArray# full_ba (wordSize *# off +# nbBytes *# i) elem_mba 0# nbBytes s2 of
          s3 -> case unsafeFreezeByteArray# elem_mba s3 of
            (# _, elem_ba #) -> PackedWords (ByteArray elem_ba)
    where
      nbBytes, wordSize :: Int#
      !(I# nbBytes) = bytesPerWord * wordLength @a
      !(I# wordSize) = bytesPerWord
-- | Enables "Data.Vector.Unboxed" vectors of 'PackWords', using the
-- 'VGM.MVector' and 'VG.Vector' instances defined in this module.
instance (Finitary a, 1 <= Cardinality a) => VU.Unbox (PackWords a)
-- | Number of machine words needed for the 'Finitary' indices of @a@.
type WordLength a = NatWords (Cardinality a)
-- | Number of machine words needed for any index below @n@: the ceiling
-- logarithm of @n@ in base @Cardinality Word@ (this is 0 when @n@ is 1).
type NatWords n = CLog (Cardinality Word) n
{-# INLINE bytesPerWord #-}
-- | Size of one machine 'Word' in bytes, converted to any numeric type.
bytesPerWord :: forall (a :: Type) .
  (Num a) =>
  a
bytesPerWord =
  let wordBytes = sizeOf (0 :: Word)
  in fromIntegral wordBytes
{-# INLINE wordLength #-}
-- | Number of machine words used to store one value of @a@.
--
-- Never less than one. For a type with a single inhabitant the ceiling
-- logarithm @WordLength a@ is 0, and zero words per element would make the
-- unboxed-vector instances divide by zero when recovering their length (and
-- allocate nothing regardless of the requested length). 'natWords' below
-- applies the same floor, so 'intoWords' always produces arrays of exactly
-- this many words.
wordLength :: forall (a :: Type) (b :: Type) .
  (Finitary a, 1 <= Cardinality a, Num b) =>
  b
wordLength = fromIntegral (max 1 (natVal' @(WordLength a) proxy#))
{-# INLINE natWords #-}
-- | Number of machine words 'intoWords' uses for a @'Finite' n@.
--
-- Floored at one exactly like 'wordLength', so an array produced for
-- @n ~ Cardinality a@ always has the width the 'PackWords' instances expect.
natWords :: forall (n :: Nat) (b :: Type) .
  (KnownNat n, 1 <= n, Num b) =>
  b
natWords = fromIntegral (max 1 (natVal' @(NatWords n) proxy#))
{-# INLINE packWords #-}
-- | Encode a value by taking its 'Finitary' index and re-reading that index
-- as a 'PackWords'.
packWords :: forall (a :: Type) .
  (Finitary a, 1 <= Cardinality a) =>
  a -> PackWords a
packWords x = fromFinite (toFinite x)
{-# INLINE unpackWords #-}
-- | Decode a 'PackWords' by taking its 'Finitary' index and converting that
-- index back into a value of @a@.
unpackWords :: forall (a :: Type) .
  (Finitary a, 1 <= Cardinality a) =>
  PackWords a -> a
unpackWords packed = fromFinite (toFinite packed)
{-# INLINABLE intoWords #-}
-- | Encode a @'Finite' n@ as a fixed-width array of natWords @n machine words,
-- least significant word first, with unused high words zero-filled.
--
-- The index is converted to a 'Natural'. A small natural is written into
-- word 0; a big one has its limb array copied in verbatim. The remainder of
-- the buffer is then zeroed. NOTE(review): relies on getFinite f < n always
-- fitting into natWords @n words, so nbBytes -# bytesWritten is never negative.
intoWords :: forall (n :: Nat) .
  (KnownNat n, 1 <= n) =>
  Finite n -> ByteArray
intoWords f = runRW# $ \ s1 ->
  case newByteArray# nbBytes s1 of
    (# s2, mba #) ->
      -- The inner case writes the significant words and reports how many
      -- bytes it wrote; the outer continuation zero-fills the rest.
      case (
        case i of
          -- Small natural: zero writes nothing (the fill covers it);
          -- otherwise the single word goes at index 0.
          NatS# word
            | 0## <- word
            -> (# s2, 0# #)
            | otherwise
            -> case writeWordArray# mba 0# word s2 of
                 s3 -> (# s3, wordSize #)
          -- Big natural: copy its whole limb array to the front of the buffer.
          NatJ# (BN# bigNatArray) ->
            let
              nbBytesWritten :: Int#
#ifdef BIGNUM
              nbBytesWritten = wordSize *# bigNatSize# bigNatArray
#else
              nbBytesWritten = wordSize *# sizeofBigNat# (BN# bigNatArray)
#endif
            in
              case copyByteArray# bigNatArray 0# mba 0# nbBytesWritten s2 of
                s3 -> (# s3, nbBytesWritten #)
        ) of
        (# s3, bytesWritten #) ->
          -- Zero the remaining high bytes so every encoding has the same width.
          case setByteArray# mba bytesWritten (nbBytes -# bytesWritten) 0# s3 of
            s4 -> case unsafeFreezeByteArray# mba s4 of
              (# _, ba #) -> ByteArray ba
  where
    wordSize :: Int#
    !(I# wordSize) = bytesPerWord
    -- Total buffer size in bytes.
    nbBytes :: Int#
    !(I# nbBytes) = I# wordSize * natWords @n
    -- The index as a Natural. Under BIGNUM a negative value would clamp to 0;
    -- getFinite is expected to be non-negative.
    i :: Natural
    i =
#ifdef BIGNUM
      integerToNaturalClamp ( getFinite f )
#else
      fromIntegral ( getFinite f )
#endif
{-# INLINABLE outOfWords #-}
-- | Reinterpret an array of machine words (least significant first, as
-- produced by 'intoWords') as a 'Finite' value.
--
-- 'intoWords' pads every value with high zero words to a fixed width, but
-- the big-number conversions require a canonical limb array: most
-- significant limb non-zero, and zero handled specially. Feeding them the
-- padded array yields a non-canonical 'Integer' that compares unequal to the
-- same number built any other way. So the padding is trimmed first. The
-- input array is reused as-is when it has no padding, and copied only when
-- trimming is needed; zero (including an empty array) maps directly to 0.
outOfWords :: forall (n :: Nat) .
  (KnownNat n) =>
  ByteArray -> Finite n
outOfWords (ByteArray ba) = Finite $ case significantWords of
    0 -> 0
    k | k == totalWords -> fromLimbs (ByteArray ba)
      | otherwise -> fromLimbs (trimTo k)
  where
    -- Number of whole machine words held by the array.
    totalWords :: Int
    totalWords = I# (sizeofByteArray# ba) `quot` bytesPerWord
    -- Count of words up to and including the highest non-zero word.
    significantWords :: Int
    significantWords = go (totalWords - 1)
      where
        go :: Int -> Int
        go i@(I# i#)
          | i < 0 = 0
          | W# (indexWordArray# ba i#) == 0 = go (i - 1)
          | otherwise = i + 1
    -- A fresh array containing only the lowest k words of the input.
    trimTo :: Int -> ByteArray
    trimTo k = runRW# $ \ s1 ->
      case newByteArray# nb s1 of
        (# s2, mba #) -> case copyByteArray# ba 0# mba 0# nb s2 of
          s3 -> case unsafeFreezeByteArray# mba s3 of
            (# _, trimmed #) -> ByteArray trimmed
      where
        nb :: Int#
        !(I# nb) = k * bytesPerWord
    -- Convert a canonical, non-empty limb array to an 'Integer'.
    fromLimbs :: ByteArray -> Integer
#ifdef BIGNUM
    fromLimbs (ByteArray limbs) = integerFromBigNat# limbs
#else
    fromLimbs (ByteArray limbs) = bigNatToInteger (BN# limbs)
#endif