{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_GHC -fplugin=GHC.TypeLits.Extra.Solver #-}
{-# OPTIONS_GHC -fplugin=GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin=GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_HADDOCK show-extensions #-}
#include "MachDeps.h"
module Clash.Class.BitPack
( BitPack (..)
, bitCoerce
, bitCoerceMap
, boolToBV
, boolToBit
, bitToBool
, packXWith
)
where
import Control.Exception (catch, evaluate)
import Data.Binary.IEEE754 (doubleToWord, floatToWord, wordToDouble,
wordToFloat)
#if MIN_VERSION_base(4,12,0)
import Data.Complex (Complex)
import Data.Ord (Down)
#endif
import Data.Int
import Data.Word
import Foreign.C.Types (CUShort)
import GHC.TypeLits (KnownNat, Nat, type (+), type (-))
import Numeric.Half (Half (..))
import GHC.Generics
import GHC.TypeLits.Extra (CLog, Max)
import Prelude hiding (map)
import System.IO.Unsafe (unsafeDupablePerformIO)
import Clash.Promoted.Nat (SNat(..), snatToNum)
import Clash.Class.BitPack.Internal (deriveBitPackTuples)
import Clash.Class.Resize (zeroExtend, resize)
import Clash.Sized.BitVector (Bit, BitVector, (++#))
import Clash.Sized.Internal.BitVector
(pack#, split#, checkUnpackUndef, undefined#, unpack#, unsafeToNatural)
import Clash.XException
-- | Types that can be converted to and from a 'BitVector' whose width,
-- 'BitSize', is known at compile time.
--
-- For types with a 'Generic' instance, the default 'pack' / 'unpack'
-- implementations encode the constructor index (counting from 0) followed by
-- the constructor's fields, using the 'GBitPack' machinery.
class KnownNat (BitSize a) => BitPack a where
  -- | Number of bits in the packed representation of @a@.
  --
  -- The default reserves @CLog 2@ of the constructor count for the
  -- constructor index and adds the field width computed by 'GFieldSize'
  -- (for sum types this is the width of the widest alternative).
  type BitSize a :: Nat
  type BitSize a = (CLog 2 (GConstructorCount (Rep a))) + (GFieldSize (Rep a))

  -- | Convert a value to its bit representation.
  pack :: a -> BitVector (BitSize a)
  -- Generic default: the constructor index (an 'Int', packed and resized to
  -- @constrSize@ bits) is concatenated with '++#' in front of the packed
  -- fields. An 'XException' raised while forcing the argument yields an
  -- undefined vector (see 'packXWith').
  default pack
    :: ( Generic a
       , GBitPack (Rep a)
       , KnownNat (BitSize a)
       , KnownNat constrSize
       , KnownNat fieldSize
       , constrSize ~ CLog 2 (GConstructorCount (Rep a))
       , fieldSize ~ GFieldSize (Rep a)
       , (constrSize + fieldSize) ~ BitSize a
       )
    => a -> BitVector (BitSize a)
  pack = packXWith go
   where
    go a = resize (pack sc) ++# packedFields
     where
      -- Constructor numbering starts at 0 for the whole representation.
      (sc, packedFields) = gPackFields 0 (from a)

  -- | Convert a bit representation back to a value.
  unpack :: BitVector (BitSize a) -> a
  -- Generic default: split off the leading constructor bits, widen them to
  -- an 'Int' and decode them through 'checkUnpackUndef', then let 'gUnpack'
  -- rebuild the matching constructor from the remaining field bits.
  default unpack
    :: ( Generic a
       , GBitPack (Rep a)
       , KnownNat constrSize
       , KnownNat fieldSize
       , constrSize ~ CLog 2 (GConstructorCount (Rep a))
       , fieldSize ~ GFieldSize (Rep a)
       , (constrSize + fieldSize) ~ BitSize a
       )
    => BitVector (BitSize a) -> a
  unpack b =
    to (gUnpack sc 0 bFields)
   where
    -- View pattern: the constructor part of the split is resized to the
    -- width of 'Int' and unpacked to the constructor index @sc@.
    (checkUnpackUndef unpack . resize -> sc, bFields) = split# b
-- | Run a packing function on a value, turning an 'XException' raised while
-- forcing that value into an undefined result ('undefined#').
--
-- Only forcing @x@ to weak head normal form happens inside 'catch'. The
-- application of @f@ is returned unevaluated, so exceptions raised later,
-- while the result itself is evaluated, are not intercepted here. Exceptions
-- other than 'XException' propagate unchanged.
packXWith :: KnownNat n => (a -> BitVector n) -> a -> BitVector n
packXWith f x = unsafeDupablePerformIO (catch guarded onXException)
 where
  guarded = fmap f (evaluate x)
  onXException (XException _) = return undefined#
{-# NOINLINE packXWith #-}
-- | Reinterpret a value of one type as another type of the same bit width by
-- packing it and unpacking the resulting 'BitVector'.
--
-- The left-hand side deliberately takes no arguments, so the INLINE[1]
-- unfolding can be used even where 'bitCoerce' appears unapplied.
bitCoerce
  :: (BitPack a, BitPack b, BitSize a ~ BitSize b)
  => a
  -> b
bitCoerce = \x -> unpack (pack x)
{-# INLINE[1] bitCoerce #-}
-- | Apply a function on @a@ to a value of a same-width type @b@: the input is
-- reinterpreted as @a@ with 'bitCoerce', transformed, and reinterpreted back.
bitCoerceMap
  :: forall a b . (BitPack a, BitPack b, BitSize a ~ BitSize b)
  => (a -> a)
  -> b
  -> b
bitCoerceMap f x = bitCoerce (f (bitCoerce x :: a))
-- | 'Bool' occupies a single bit: 'True' packs to 1 and 'False' to 0.
--
-- 'unpack' maps 1 to 'True' and any other value to 'False'. Input containing
-- undefined bits is handled by 'checkUnpackUndef'.
instance BitPack Bool where
  type BitSize Bool = 1
  -- 'packXWith' turns an 'XException' in the argument into an undefined bit.
  pack = let go b = if b then 1 else 0 in packXWith go
  -- The comparison already is the answer; no conditional is needed.
  unpack = checkUnpackUndef (== 1)
-- | A 'BitVector' is its own bit representation. Packing only routes the
-- value through 'packXWith' so an 'XException' becomes an undefined vector;
-- unpacking is the identity.
instance KnownNat n => BitPack (BitVector n) where
  type BitSize (BitVector n) = n
  pack v = packXWith id v
  unpack = id
-- | A 'Bit' is one bit wide and converts with the 'pack#' / 'unpack#'
-- helpers from "Clash.Sized.Internal.BitVector".
instance BitPack Bit where
  type BitSize Bit = 1
  pack b = packXWith pack# b
  unpack bv = unpack# bv
-- | Signed fixed-width integers convert with 'fromIntegral' in both
-- directions. 'Int' uses the machine word width (@WORD_SIZE_IN_BITS@ from
-- "MachDeps.h"); the sized variants use their nominal widths.
instance BitPack Int where
  type BitSize Int = WORD_SIZE_IN_BITS
  pack i = packXWith fromIntegral i
  unpack bv = checkUnpackUndef fromIntegral bv

instance BitPack Int8 where
  type BitSize Int8 = 8
  pack i = packXWith fromIntegral i
  unpack bv = checkUnpackUndef fromIntegral bv

instance BitPack Int16 where
  type BitSize Int16 = 16
  pack i = packXWith fromIntegral i
  unpack bv = checkUnpackUndef fromIntegral bv

instance BitPack Int32 where
  type BitSize Int32 = 32
  pack i = packXWith fromIntegral i
  unpack bv = checkUnpackUndef fromIntegral bv

instance BitPack Int64 where
  type BitSize Int64 = 64
  pack i = packXWith fromIntegral i
  unpack bv = checkUnpackUndef fromIntegral bv
-- | Unsigned fixed-width integers convert with 'fromIntegral' in both
-- directions. 'Word' uses the machine word width (@WORD_SIZE_IN_BITS@ from
-- "MachDeps.h"); the sized variants use their nominal widths.
instance BitPack Word where
  type BitSize Word = WORD_SIZE_IN_BITS
  pack w = packXWith fromIntegral w
  unpack bv = checkUnpackUndef fromIntegral bv

instance BitPack Word8 where
  type BitSize Word8 = 8
  pack w = packXWith fromIntegral w
  unpack bv = checkUnpackUndef fromIntegral bv

instance BitPack Word16 where
  type BitSize Word16 = 16
  pack w = packXWith fromIntegral w
  unpack bv = checkUnpackUndef fromIntegral bv

instance BitPack Word32 where
  type BitSize Word32 = 32
  pack w = packXWith fromIntegral w
  unpack bv = checkUnpackUndef fromIntegral bv

instance BitPack Word64 where
  type BitSize Word64 = 64
  pack w = packXWith fromIntegral w
  unpack bv = checkUnpackUndef fromIntegral bv
-- | 'Float' is packed as its 32-bit IEEE-754 encoding, using the
-- 'packFloat#' / 'unpackFloat#' helpers.
instance BitPack Float where
  type BitSize Float = 32
  pack f = packXWith packFloat# f
  unpack bv = checkUnpackUndef unpackFloat# bv
-- | Reinterpret a 'Float' as its 32-bit IEEE-754 bit pattern ('floatToWord'),
-- widened into a 'BitVector'.
--
-- NOTE(review): kept NOINLINE, presumably so the Clash compiler can treat it
-- as a primitive — confirm.
packFloat# :: Float -> BitVector 32
packFloat# f = fromIntegral (floatToWord f)
{-# NOINLINE packFloat# #-}
-- | Rebuild a 'Float' from a 32-bit IEEE-754 bit pattern. The bits are read
-- through 'unsafeToNatural' and handed to 'wordToFloat'.
--
-- NOTE(review): 'unsafeToNatural' presumably does not check for undefined
-- bits. The 'Float' instance's 'unpack' wraps this helper in
-- 'checkUnpackUndef'.
unpackFloat# :: BitVector 32 -> Float
unpackFloat# bv = wordToFloat (fromIntegral (unsafeToNatural bv))
{-# NOINLINE unpackFloat# #-}
-- | 'Double' is packed as its 64-bit IEEE-754 encoding, using the
-- 'packDouble#' / 'unpackDouble#' helpers.
instance BitPack Double where
  type BitSize Double = 64
  pack d = packXWith packDouble# d
  unpack bv = checkUnpackUndef unpackDouble# bv
-- | Reinterpret a 'Double' as its 64-bit IEEE-754 bit pattern
-- ('doubleToWord'), converted into a 'BitVector'.
--
-- NOTE(review): kept NOINLINE, presumably so the Clash compiler can treat it
-- as a primitive — confirm.
packDouble# :: Double -> BitVector 64
packDouble# d = fromIntegral (doubleToWord d)
{-# NOINLINE packDouble# #-}
-- | Rebuild a 'Double' from a 64-bit IEEE-754 bit pattern. The bits are read
-- through 'unsafeToNatural' and handed to 'wordToDouble'.
--
-- NOTE(review): 'unsafeToNatural' presumably does not check for undefined
-- bits. The 'Double' instance's 'unpack' wraps this helper in
-- 'checkUnpackUndef'.
unpackDouble# :: BitVector 64 -> Double
unpackDouble# bv = wordToDouble (fromIntegral (unsafeToNatural bv))
{-# NOINLINE unpackDouble# #-}
-- | 'CUShort' is treated as a 16-bit value and converts with 'fromIntegral'
-- in both directions.
instance BitPack CUShort where
  type BitSize CUShort = 16
  pack s = packXWith fromIntegral s
  unpack bv = checkUnpackUndef fromIntegral bv
-- | 'Half' is packed as the 16-bit raw value it wraps, delegating to that
-- value's own 'BitPack' instance in both directions.
instance BitPack Half where
  type BitSize Half = 16
  pack (Half raw) = pack raw
  unpack = checkUnpackUndef (Half . unpack)
-- | Unit carries no information: it packs to the empty vector ('minBound')
-- and unpacking ignores its input. Neither direction forces its argument.
instance BitPack () where
  type BitSize () = 0
  pack = const minBound
  unpack = const ()
-- | A pair is packed as the bits of its first component followed (via '++#')
-- by the bits of its second. 'unpack' splits the vector at the same boundary
-- with 'split#'.
instance (BitPack a, BitPack b) => BitPack (a,b) where
  type BitSize (a,b) = BitSize a + BitSize b
  pack = packXWith packPair
   where
    packPair (x, y) = pack x ++# pack y
  unpack ab = (unpack bitsA, unpack bitsB)
   where
    (bitsA, bitsB) = split# ab
-- | Worker class over 'GHC.Generics' representations that implements the
-- default 'pack' and 'unpack' methods of 'BitPack'.
class GBitPack f where
  -- | Width, in bits, of the field area of representation @f@.
  type GFieldSize f :: Nat

  -- | Number of constructors described by representation @f@.
  type GConstructorCount f :: Nat

  -- | @gPackFields cc x@ packs the fields of @x@. @cc@ is the index of the
  -- first constructor covered by this part of the representation. The
  -- returned 'Int' is the index of the constructor actually present.
  gPackFields :: Int -> f a -> (Int, BitVector (GFieldSize f))

  -- | @gUnpack c cc bits@ rebuilds a value of constructor number @c@ from
  -- its field bits. @cc@ is the index of the first constructor covered by
  -- this part of the representation.
  gUnpack :: Int -> Int -> BitVector (GFieldSize f) -> f a
-- | Metadata wrappers ('M1') are transparent: they add no bits and no
-- constructors, and simply forward to the wrapped representation.
instance GBitPack a => GBitPack (M1 m d a) where
  type GFieldSize (M1 m d a) = GFieldSize a
  type GConstructorCount (M1 m d a) = GConstructorCount a
  gPackFields cc = gPackFields cc . unM1
  gUnpack c cc = M1 . gUnpack c cc
-- | Sums: both alternatives share one field area, sized to the wider of the
-- two ('Max'). The narrower alternative's fields are followed by undefined
-- padding bits. Constructors of the right alternative are numbered after
-- all constructors of the left one.
instance ( KnownNat (GFieldSize g)
         , KnownNat (GFieldSize f)
         , KnownNat (GConstructorCount f)
         , GBitPack f
         , GBitPack g
         ) => GBitPack (f :+: g) where
  type GFieldSize (f :+: g) = Max (GFieldSize f) (GFieldSize g)
  type GConstructorCount (f :+: g) = GConstructorCount f + GConstructorCount g

  -- Left alternative: constructor numbering continues from @cc@ unchanged.
  -- The packed fields are followed by undefined padding up to the 'Max' width.
  gPackFields cc (L1 l) =
    let (sc, packed) = gPackFields cc l in
    let padding = undefined# :: BitVector (Max (GFieldSize f) (GFieldSize g) - GFieldSize f) in
    (sc, packed ++# padding)
  -- Right alternative: its constructors start after the left alternative's,
  -- so the starting index is shifted by @GConstructorCount f@.
  gPackFields cc (R1 r) =
    let cLeft = snatToNum (SNat @(GConstructorCount f)) in
    let (sc, packed) = gPackFields (cc + cLeft) r in
    let padding = undefined# :: BitVector (Max (GFieldSize f) (GFieldSize g) - GFieldSize g) in
    (sc, packed ++# padding)

  -- Pick the branch by checking whether the decoded constructor index @c@
  -- falls among the left alternative's constructors, which are numbered
  -- from @cc@ to @cc + cLeft - 1@.
  gUnpack c cc b =
    let cLeft = snatToNum (SNat @(GConstructorCount f)) in
    if c < cc + cLeft then
      L1 (gUnpack c cc f)
    else
      R1 (gUnpack c (cc + cLeft) g)
   where
    -- Both splits take the field bits from the front of @b@ and discard the
    -- trailing padding. Thanks to laziness, only the split for the chosen
    -- branch is evaluated. The value bindings @f@ / @g@ are distinct from
    -- the type variables @f@ / @g@ used in the type annotations.
    (f, _ :: BitVector (Max (GFieldSize f) (GFieldSize g) - GFieldSize f)) = split# b
    (g, _ :: BitVector (Max (GFieldSize f) (GFieldSize g) - GFieldSize g)) = split# b
-- | Products: the fields of both components are packed side by side, left
-- component first. A product counts as a single constructor.
instance ( KnownNat (GFieldSize g)
         , KnownNat (GFieldSize f)
         , GBitPack f
         , GBitPack g
         ) => GBitPack (f :*: g) where
  type GFieldSize (f :*: g) = GFieldSize f + GFieldSize g
  type GConstructorCount (f :*: g) = 1

  -- The constructor index @cc@ passes through unchanged. 'packXWith' guards
  -- the forcing of the product constructor itself.
  gPackFields cc fg = (cc, packXWith packBoth fg)
   where
    packBoth (lhs :*: rhs) =
      snd (gPackFields cc lhs) ++# snd (gPackFields cc rhs)

  gUnpack c cc b =
    let (front, back) = split# b
    in  gUnpack c cc front :*: gUnpack c cc back
-- | A single field ('K1') is packed with its own 'BitPack' instance and
-- counts as one constructor. The constructor index is passed through, and
-- is ignored when unpacking.
instance BitPack c => GBitPack (K1 i c) where
  type GFieldSize (K1 i c) = BitSize c
  type GConstructorCount (K1 i c) = 1
  gPackFields cc field = (cc, pack (unK1 field))
  gUnpack _ _ = K1 . unpack
-- | Field-less constructors ('U1') occupy no field bits and count as one
-- constructor. Packing still matches on 'U1', so the argument is forced.
instance GBitPack U1 where
  type GFieldSize U1 = 0
  type GConstructorCount U1 = 1
  gPackFields cc u = case u of
    U1 -> (cc, 0)
  gUnpack _ _ _ = U1
-- | 'Either' uses the generic default methods of 'BitPack'.
instance (BitPack a, BitPack b) => BitPack (Either a b)

-- | 'Maybe' uses the generic default methods of 'BitPack'.
instance BitPack a => BitPack (Maybe a)

#if MIN_VERSION_base(4,12,0)
-- Generic-derived instances, compiled only against base >= 4.12, matching
-- the guarded imports of 'Complex' and 'Down'.
-- NOTE(review): presumably because the needed 'Generic' instances are only
-- available from that base version — confirm.
instance BitPack a => BitPack (Complex a)
instance BitPack a => BitPack (Down a)
#endif
-- | Turn a 'Bool' into a 'BitVector' of width @n + 1@. The single bit
-- produced by 'pack' ('True' is 1, 'False' is 0) is widened with
-- 'zeroExtend'.
boolToBV :: KnownNat n => Bool -> BitVector (n + 1)
boolToBV b = zeroExtend (pack b)
-- | Reinterpret a 'Bool' as a 'Bit'. Both are one bit wide, so this is a
-- plain 'bitCoerce'.
boolToBit :: Bool -> Bit
boolToBit = bitCoerce @Bool @Bit
-- | Reinterpret a 'Bit' as a 'Bool'. Both are one bit wide, so this is a
-- plain 'bitCoerce'.
bitToBool :: Bit -> Bool
bitToBool = bitCoerce @Bit @Bool
-- Template Haskell: generate the remaining tuple instances of 'BitPack',
-- given the class, its size family and its two methods by name.
$(deriveBitPackTuples ''BitPack ''BitSize 'pack 'unpack)