{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}
module Flat.Class
(
Flat(..)
, getSize
, module GHC.Generics
)
where
import Data.Bits
import Flat.Decoder
import Flat.Encoder
import Data.Word
import GHC.Generics
import GHC.TypeLits
import Prelude hiding (mempty)
#define INL 2
#if INL == 1
import GHC.Exts (inline)
#endif
-- | Size of a value as reported by its @size@ method, obtained by
-- starting the running total at zero.
getSize :: Flat a => a -> NumBits
getSize value = size value 0
-- | Types that can be encoded, decoded and sized.
--
-- Each method has a default implementation that goes through the
-- "GHC.Generics" representation of the type, so a type with a @Generic@
-- instance needs no hand-written method definitions.
class Flat a where
  -- | Produce the encoding of a value.
  encode :: a -> Encoding
  default encode :: (Generic a, GEncode (Rep a)) => a -> Encoding
  encode = gencode . from

  -- | Decoder for a value of this type.
  decode :: Get a
  default decode :: (Generic a, GDecode (Rep a)) => Get a
  decode = to <$> gget

  -- | Add the size of the given value to the running total passed as
  -- second argument and return the new total.
  size :: a -> NumBits -> NumBits
  default size :: (Generic a, GSize (Rep a)) => a -> NumBits -> NumBits
  size !value !total = gsize total (from value)

  -- The inlining policy of the default methods is chosen at compile time:
  -- 2 or more forces inlining, 1 adds no pragma here, 0 forbids inlining.
#if INL>=2
  {-# INLINE size #-}
  {-# INLINE decode #-}
  {-# INLINE encode #-}
#elif INL == 1
#elif INL == 0
  {-# NOINLINE size #-}
  {-# NOINLINE decode #-}
  {-# NOINLINE encode #-}
#endif
-- | Generic encoder, defined by induction over the shape of a
-- "GHC.Generics" representation type.
class GEncode f where
  gencode :: f a -> Encoding
-- | Metadata wrapper: unwrap it and encode the representation inside.
instance {-# OVERLAPPABLE #-} GEncode f => GEncode (M1 i c f) where
  {-# INLINE gencode #-}
  gencode = gencode . unM1
-- | Datatype wrapper directly around a constructor wrapper (no sum in
-- between): strip both metadata layers in one step and encode the
-- constructor contents.
instance {-# OVERLAPPING #-} GEncode a => GEncode (D1 i (C1 c a)) where
  {-# INLINE gencode #-}
  gencode = gencode . unM1 . unM1
-- | Representation of a type without constructors. There is no value to
-- encode, so the method is the error value @unused@.
instance GEncode V1 where
  {-# INLINE gencode #-}
  gencode = unused
-- | Constructor without fields: contributes the empty encoding.
instance GEncode U1 where
  {-# INLINE gencode #-}
  gencode U1 = mempty
-- | Constructor field: hand over to the @Flat@ instance of the field type.
-- When the inlining level is 1 the call to @encode@ is wrapped in
-- @inline@.
instance Flat a => GEncode (K1 i a) where
  {-# INLINE gencode #-}
#if INL == 1
  gencode field = inline encode (unK1 field)
#else
  gencode = encode . unK1
#endif
-- | Product: the encoding of the left component followed by the encoding
-- of the right component.
instance (GEncode a, GEncode b) => GEncode (a :*: b) where
  {-# INLINE gencode #-}
  gencode (left :*: right) = gencode left <> gencode right
-- | Sum: accepted only when the type has at most 512 constructors (checked
-- at the type level). Encoding starts @gencodeSum@ with an empty
-- constructor code (value 0, 0 bits).
instance (NumConstructors (a :+: b) <= 512, GEncodeSum (a :+: b)) => GEncode (a :+: b) where
  {-# INLINE gencode #-}
  gencode = gencodeSum 0 0
-- | Encoder for the alternatives of a sum.
--
-- The first argument is the constructor code accumulated so far, the
-- second is the number of bits accumulated in that code, and the third is
-- the part of the sum still to be traversed.
class GEncodeSum f where
  gencodeSum :: Word16 -> NumBits -> f a -> Encoding
-- | One step down the sum: shift the code left by one position, append a
-- 0 bit for the left branch or a 1 bit for the right branch, and count
-- one more bit.
instance (GEncodeSum a, GEncodeSum b) => GEncodeSum (a :+: b) where
  {-# INLINE gencodeSum #-}
  gencodeSum !code !numBits (L1 !x) =
    gencodeSum (code `unsafeShiftL` 1) (numBits + 1) x
  gencodeSum !code !numBits (R1 !x) =
    gencodeSum ((code `unsafeShiftL` 1) .|. 1) (numBits + 1) x
-- | A constructor has been reached: emit the accumulated code with
-- @eBits16@ (given the bit count and the code), then the encoding of the
-- constructor itself.
instance GEncode a => GEncodeSum (C1 c a) where
  {-# INLINE gencodeSum #-}
  gencodeSum !code !numBits con = eBits16 numBits code <> gencode con
-- | Generic decoder, defined by induction over the shape of a
-- "GHC.Generics" representation type.
class GDecode f where
  gget :: Get (f t)
-- | Metadata wrapper: decode the inner representation and wrap the result.
instance GDecode a => GDecode (M1 i c a) where
  {-# INLINE gget #-}
  gget = M1 <$> gget
-- | Representation of a type without constructors. No value can be
-- produced, so the decoder is the error value @unused@.
instance GDecode V1 where
  {-# INLINE gget #-}
  gget = unused
-- | Constructor without fields: yields @U1@ without running any decoder.
instance GDecode U1 where
  {-# INLINE gget #-}
  gget = pure U1
-- | Product: decode the left component, then the right one, and pair them.
instance (GDecode a, GDecode b) => GDecode (a :*: b) where
  {-# INLINE gget #-}
  gget = (:*:) <$> gget <*> gget
-- | Constructor field: hand over to the @Flat@ instance of the field type.
-- When the inlining level is 1 the reference to @decode@ is wrapped in
-- @inline@.
instance Flat a => GDecode (K1 i a) where
  {-# INLINE gget #-}
#if INL == 1
  gget = K1 <$> inline decode
#else
  gget = K1 <$> decode
#endif
-- Compile-time switches selecting which sum decoders are built.
#define DEC_CONS
#define DEC_CONS48
#define DEC_BOOLC
#define DEC_BOOL

-- Sum decoder driven by a single boolean read with dBool: True selects the
-- right alternative, False the left one. The instance head is picked by
-- the switches above (any sum, or only a sum of exactly two constructors)
-- and the body below is shared by both heads.
#ifdef DEC_BOOLG
instance (GDecode a, GDecode b) => GDecode (a :+: b)
#endif
#ifdef DEC_BOOLC
instance {-# OVERLAPPING #-} (GDecode a,GDecode b) => GDecode (C1 m1 a :+: C1 m2 b)
#endif
#ifdef DEC_BOOL
  where
    {-# INLINE gget #-}
    gget = do
      !goRight <- dBool
      !alternative <- if goRight then R1 <$> gget else L1 <$> gget
      return alternative
#endif
#ifdef DEC_CONS
-- | General sum decoder, accepted only when the type has at most 512
-- constructors (checked at the type level). It obtains a constructor
-- state from @consOpen@ and passes it to @getSum@.
instance {-# OVERLAPPABLE #-} (NumConstructors (a :+: b) <= 512, GDecodeSum (a :+: b)) => GDecode (a :+: b) where
  {-# INLINE gget #-}
  gget = consOpen >>= getSum
-- | Decoder for the alternatives of a sum, threading the constructor
-- state through the traversal of the sum.
class GDecodeSum f where
  getSum :: ConsState -> Get (f a)
#ifdef DEC_CONS48
-- | Two levels of a sum resolved at once: a tag taken from the constructor
-- state with @consBits cs 2@ selects one of the four sub-alternatives
-- (0 = left/left, 1 = left/right, 2 = right/left, otherwise right/right).
instance {-# OVERLAPPING #-} (GDecodeSum n1,GDecodeSum n2,GDecodeSum n3,GDecodeSum n4) => GDecodeSum ((n1 :+: n2) :+: (n3 :+: n4))
  where
    {-# INLINE getSum #-}
    getSum cs =
      let (rest, tag) = consBits cs 2
       in case tag of
            0 -> L1 . L1 <$> getSum rest
            1 -> L1 . R1 <$> getSum rest
            2 -> R1 . L1 <$> getSum rest
            _ -> R1 . R1 <$> getSum rest
-- | Three levels of a sum resolved at once: a tag taken from the
-- constructor state with @consBits cs 3@ selects one of the eight
-- sub-alternatives, numbered left to right from 0; any other tag value
-- falls through to the last alternative.
instance {-# OVERLAPPING #-} (GDecodeSum n1,GDecodeSum n2,GDecodeSum n3,GDecodeSum n4,GDecodeSum n5,GDecodeSum n6,GDecodeSum n7,GDecodeSum n8) => GDecodeSum (((n1 :+: n2) :+: (n3 :+: n4)) :+: ((n5 :+: n6) :+: (n7 :+: n8)))
  where
    {-# INLINE getSum #-}
    getSum cs =
      let (rest, tag) = consBits cs 3
       in case tag of
            0 -> L1 . L1 . L1 <$> getSum rest
            1 -> L1 . L1 . R1 <$> getSum rest
            2 -> L1 . R1 . L1 <$> getSum rest
            3 -> L1 . R1 . R1 <$> getSum rest
            4 -> R1 . L1 . L1 <$> getSum rest
            5 -> R1 . L1 . R1 <$> getSum rest
            6 -> R1 . R1 . L1 <$> getSum rest
            _ -> R1 . R1 . R1 <$> getSum rest
instance {-# OVERLAPPABLE #-} (GDecodeSum a, GDecodeSum b) => GDecodeSum (a :+: b) where
#else
instance (GDecodeSum a, GDecodeSum b) => GDecodeSum (a :+: b) where
#endif
  -- One level of a sum: a boolean taken from the constructor state with
  -- consBool selects the right alternative when True, the left otherwise.
  -- The instance is marked overlappable only when the multi-level
  -- instances are compiled in.
  {-# INLINE getSum #-}
  getSum cs =
    let (rest, goRight) = consBool cs
     in if goRight then R1 <$> getSum rest else L1 <$> getSum rest
-- | A constructor has been reached: pass the second field of the
-- constructor state to @consClose@, then decode the constructor contents.
instance GDecode a => GDecodeSum (C1 c a) where
  {-# INLINE getSum #-}
  getSum (ConsState _ usedBits) = consClose usedBits >> gget
#endif
#ifdef DEC_BOOL48
-- | Two levels of a sum resolved at once: a tag read with @dBEBits8 2@
-- selects one of the four sub-alternatives (0 = left/left,
-- 1 = left/right, 2 = right/left, otherwise right/right).
instance {-# OVERLAPPING #-} (GDecode n1,GDecode n2,GDecode n3,GDecode n4) => GDecode ((n1 :+: n2) :+: (n3 :+: n4))
  where
    {-# INLINE gget #-}
    gget = do
      !tag <- dBEBits8 2
      case tag of
        0 -> L1 . L1 <$> gget
        1 -> L1 . R1 <$> gget
        2 -> R1 . L1 <$> gget
        _ -> R1 . R1 <$> gget
-- | Three levels of a sum resolved at once: a tag read with @dBEBits8 3@
-- selects one of the eight sub-alternatives, numbered left to right from
-- 0; any other tag value falls through to the last alternative.
instance {-# OVERLAPPING #-} (GDecode n1,GDecode n2,GDecode n3,GDecode n4,GDecode n5,GDecode n6,GDecode n7,GDecode n8) => GDecode (((n1 :+: n2) :+: (n3 :+: n4)) :+: ((n5 :+: n6) :+: (n7 :+: n8)))
  where
    {-# INLINE gget #-}
    gget = do
      !tag <- dBEBits8 3
      case tag of
        0 -> L1 . L1 . L1 <$> gget
        1 -> L1 . L1 . R1 <$> gget
        2 -> L1 . R1 . L1 <$> gget
        3 -> L1 . R1 . R1 <$> gget
        4 -> R1 . L1 . L1 <$> gget
        5 -> R1 . L1 . R1 <$> gget
        6 -> R1 . R1 . L1 <$> gget
        _ -> R1 . R1 . R1 <$> gget
#endif
-- | Generic size computation: takes a running total and a value of the
-- representation type and returns the updated total.
class GSize f where
  gsize :: NumBits -> f a -> NumBits
-- | Metadata wrapper: size of the wrapped representation.
instance GSize f => GSize (M1 i c f) where
  {-# INLINE gsize #-}
  gsize !total = gsize total . unM1
-- | Representation of a type without constructors: the running total is
-- returned unchanged.
instance GSize V1 where
  {-# INLINE gsize #-}
  gsize !total _ = total
-- | Constructor without fields: the running total is returned unchanged.
instance GSize U1 where
  {-# INLINE gsize #-}
  gsize !total _ = total
-- | Constructor field: hand over to the @size@ method of the field type.
-- When the inlining level is 1 the call is wrapped in @inline@.
instance Flat a => GSize (K1 i a) where
  {-# INLINE gsize #-}
#if INL == 1
  gsize !total field = inline size (unK1 field) total
#else
  gsize !total field = size (unK1 field) total
#endif
-- | Product: add the size of the left component to the running total,
-- force that intermediate total, then add the size of the right component.
instance (GSize a, GSize b) => GSize (a :*: b) where
  {-# INLINE gsize #-}
  gsize !total (left :*: right) = gsize afterLeft right
    where
      !afterLeft = gsize total left
-- Compile-time choice between the alternative size computations for sums.
#define SIZ_ADD

-- Size of a sum computed by threading the running total through gsizeSum.
-- The INLINE pragma matches every other GSize instance in this module, so
-- that the sum case is not the one link of the generic chain left without
-- an inlining request.
#ifdef SIZ_ADD
instance (GSizeSum (a :+: b)) => GSize (a :+: b) where
  gsize !n = gsizeSum n
  {-# INLINE gsize #-}
#endif

-- Variant that sizes the sum starting from zero and then adds the result
-- to the running total.
#ifdef SIZ_NUM
instance (GSizeSum (a :+: b)) => GSize (a :+: b) where
  gsize !n x = n + gsizeSum 0 x
  {-# INLINE gsize #-}
#endif
#ifdef SIZ_MAX
-- | Alternative size computation for sums: add the value returned by
-- @gsizeMax@ to the running total, then add the size of the selected
-- constructor through @gsizeNxt@.
instance (GSizeNxt (a :+: b),GSizeMax (a:+:b)) => GSize (a :+: b) where
  {-# INLINE gsize #-}
  gsize !total x = gsizeNxt (gsizeMax x + total) x

#ifdef SIZ_MAX_VAL
-- | Value-level bound for a sum: 0 for a constructor, and for a sum one
-- more than the larger of the bounds of its two sides.
class GSizeMax (f :: * -> *) where
  gsizeMax :: f a -> NumBits

instance (GSizeMax f, GSizeMax g) => GSizeMax (f :+: g) where
  {-# INLINE gsizeMax #-}
  -- The argument is never inspected; only its type selects the instance.
  gsizeMax _ = 1 + max (gsizeMax (undefined::f a )) (gsizeMax (undefined::g a))

instance (GSize a) => GSizeMax (C1 c a) where
  {-# INLINE gsizeMax #-}
  gsizeMax _ = 0
#endif

#ifdef SIZ_MAX_PROX
-- | Type-level counterpart of the bound above: 0 for a constructor, and
-- for a sum one more than the larger of the bounds of its two sides.
type family ConsSize (a :: * -> *) :: Nat where
  ConsSize (C1 c a) = 0
  ConsSize (x :+: y) = 1 + Max (ConsSize x) (ConsSize y)

-- | Larger of two type-level naturals.
type family Max (n :: Nat) (m :: Nat) :: Nat where
  Max n m = If (n <=? m) m n

-- | Type-level conditional over naturals.
type family If c (t::Nat) (e::Nat) where
  If 'True t e = t
  If 'False t e = e
#endif

-- | Walks down a sum to the selected constructor without changing the
-- running total, then adds the size of that constructor.
class GSizeNxt (f :: * -> *) where
  gsizeNxt :: NumBits -> f a -> NumBits

instance (GSizeNxt a, GSizeNxt b) => GSizeNxt (a :+: b) where
  {-# INLINE gsizeNxt #-}
  gsizeNxt total (L1 !l) = gsizeNxt total l
  gsizeNxt total (R1 !r) = gsizeNxt total r

instance (GSize a) => GSizeNxt (C1 c a) where
  {-# INLINE gsizeNxt #-}
  gsizeNxt !total !con = gsize total con
#endif
-- | Size computation for the alternatives of a sum: takes a running total
-- and the part of the sum still to be traversed.
class GSizeSum (f :: * -> *) where
  gsizeSum :: NumBits -> f a -> NumBits
-- | One step down the sum: add one to the running total for the branch
-- taken and continue with the selected side.
instance (GSizeSum a, GSizeSum b) => GSizeSum (a :+: b) where
  {-# INLINE gsizeSum #-}
  gsizeSum !total (L1 !l) = gsizeSum (total + 1) l
  gsizeSum !total (R1 !r) = gsizeSum (total + 1) r
-- | A constructor has been reached: add the size of the constructor to
-- the running total.
instance (GSize a) => GSizeSum (C1 c a) where
  {-# INLINE gsizeSum #-}
  gsizeSum !total !con = gsize total con
-- | Type-level count of the constructors in a generic representation:
-- one for a single constructor, and for a sum the total of both sides.
type family NumConstructors (rep :: * -> *) :: Nat where
  NumConstructors (C1 meta fields) = 1
  NumConstructors (l :+: r) = NumConstructors l + NumConstructors r
-- | Placeholder of any type that raises an error when evaluated. It is
-- used as the method body of the instances for @V1@, the representation
-- of types without constructors.
unused :: forall a . a
unused = error message
  where
    message = "Now, now, you could not possibly have meant this.."