{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
module Data.Bit.F2Poly
module Data.Bit.F2PolyTS
( F2Poly
, unF2Poly
, toF2Poly
) where
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Control.Monad.ST
import Data.Bit.Immutable
import Data.Bit.Internal
import Data.Bit.Mutable
import Data.Bit.ImmutableTS
import Data.Bit.InternalTS
import Data.Bit.MutableTS
import Data.Bit.Utils
import Data.Bits
import Data.Coerce
import Data.Primitive.ByteArray
import Data.Typeable
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU
import GHC.Generics
#if UseIntegerGmp
import qualified Data.Vector.Primitive as P
import GHC.Exts
import GHC.Integer.GMP.Internals
import GHC.Integer.Logarithms
import Unsafe.Coerce
-- | Polynomial over GF(2), stored as an unboxed vector of 'Bit' coefficients.
--
-- Values produced by this module are normalised with 'dropWhileEnd'
-- (see '_isValid' for the invariant that is expected to hold).
newtype F2Poly = F2Poly {
  unF2Poly :: U.Vector Bit
  -- ^ Unwrap the underlying vector of coefficients. Index @i@ of the
  -- vector corresponds to bit @i@ of the 'Integer' returned by 'toInteger'.
  }
  deriving (Eq, Ord, Show, Typeable, Generic, NFData)
-- | Wrap a vector of coefficients into an 'F2Poly'.
--
-- The input is round-tripped through 'cloneToWords' / 'castFromWords'
-- (presumably yielding a fresh, word-aligned copy -- confirm against
-- those helpers) and then normalised with 'dropWhileEnd'.
toF2Poly :: U.Vector Bit -> F2Poly
toF2Poly = F2Poly . dropWhileEnd . castFromWords . cloneToWords
-- | Internal sanity check of the 'F2Poly' representation.
--
-- Returns 'True' when the first 'BitVec' field (apparently the bit offset)
-- is zero and the stored length equals the length obtained by running
-- 'dropWhileEnd' over the whole backing byte array.
_isValid :: F2Poly -> Bool
_isValid (F2Poly (BitVec o l arr)) = o == 0 && l == l'
  where
    -- Length of the entire backing array (bytes * 8 bits) after normalisation.
    l' = U.length $ dropWhileEnd $ BitVec 0 (sizeofByteArray arr `shiftL` 3) arr
-- | Arithmetic on 'F2Poly'.
--
-- Addition and subtraction are both 'xorBits', so 'negate' and 'abs' are
-- the identity. 'signum' always returns the one-bit polynomial.
-- Multiplication is 'karatsuba' followed by 'dropWhileEnd'.
--
-- 'fromInteger' raises an 'error' for every negative argument, in both the
-- GMP-backed and the portable implementation.
instance Num F2Poly where
  (+) = coerce xorBits
  (-) = coerce xorBits
  negate = id
  abs    = id
  signum = const (F2Poly (U.singleton (Bit True)))
  (*) = coerce ((dropWhileEnd .) . karatsuba)
#if UseIntegerGmp
  fromInteger !n = case n of
    S# i#
      -- A negative small integer would otherwise be reinterpreted as a
      -- huge unsigned word; reject it just like the 'Jn#' case below.
      | n < 0     -> error "F2Poly.fromInteger: argument must be non-negative"
      | otherwise -> F2Poly $ BitVec 0 (wordSize - I# (word2Int# (clz# (int2Word# i#))))
                            $ fromBigNat $ wordToBigNat (int2Word# i#)
    Jp# bn# -> F2Poly $ BitVec 0 (I# (integerLog2# n) + 1) $ fromBigNat bn#
    Jn#{}   -> error "F2Poly.fromInteger: argument must be non-negative"
#else
  fromInteger n
    | n < 0     = error "F2Poly.fromInteger: argument must be non-negative"
    | otherwise = F2Poly (dropWhileEnd (integerToBits n))
#endif
  {-# INLINE (+)         #-}
  {-# INLINE (-)         #-}
  {-# INLINE negate      #-}
  {-# INLINE abs         #-}
  {-# INLINE signum      #-}
  {-# INLINE (*)         #-}
  {-# INLINE fromInteger #-}
-- | Conversions between 'F2Poly' and 'Int'.
--
-- 'fromEnum' goes through 'fromIntegral' (and therefore 'toInteger').
-- 'toEnum' raises an 'error' for negative arguments instead of
-- reinterpreting their two's complement bit pattern as coefficients.
instance Enum F2Poly where
  fromEnum = fromIntegral
#if UseIntegerGmp
  toEnum !(I# i#)
    | I# i# < 0 = error "F2Poly.toEnum: argument must be non-negative"
    -- The length is the position of the highest set bit of the word.
    | otherwise = F2Poly $ BitVec 0 (wordSize - I# (word2Int# (clz# (int2Word# i#))))
                         $ fromBigNat $ wordToBigNat (int2Word# i#)
#else
  toEnum n
    | n < 0     = error "F2Poly.toEnum: argument must be non-negative"
    | otherwise = fromIntegral n
#endif
-- | A polynomial is converted to a 'Rational' via its 'Integer' image
-- (see 'toInteger').
instance Real F2Poly where
  toRational poly = fromInteger (toInteger poly)
-- | Division with remainder on 'F2Poly'.
--
-- 'quotRem' delegates to 'quotRemBits' and normalises both results with
-- 'dropWhileEnd'. 'divMod' and 'mod' are aliases of 'quotRem' and 'rem'.
-- Dividing by the empty polynomial throws 'DivideByZero' (raised inside
-- 'quotRemBits').
instance Integral F2Poly where
  toInteger = bitsToInteger . unF2Poly
  quotRem (F2Poly xs) (F2Poly ys) = (F2Poly (dropWhileEnd qs), F2Poly (dropWhileEnd rs))
    where
      (qs, rs) = quotRemBits xs ys
  divMod = quotRem
  mod = rem
-- | Bitwise XOR of two bit vectors of possibly different lengths.
--
-- * If either argument has length zero, the other one is returned as is.
-- * With @UseIntegerGmp@ and both arguments at offset zero, the work is
--   delegated to 'xorBigNat' on the backing arrays.
-- * Otherwise the overlapping prefix is XORed word by word and the tail of
--   the longer argument is copied verbatim; the result is normalised with
--   'dropWhileEnd'.
xorBits
  :: U.Vector Bit
  -> U.Vector Bit
  -> U.Vector Bit
xorBits (BitVec _ 0 _) ys = ys
xorBits xs (BitVec _ 0 _) = xs
#if UseIntegerGmp
xorBits (BitVec 0 lx xarr) (BitVec 0 ly yarr) = case lx `compare` ly of
  LT -> BitVec 0 ly zs
  -- Equal lengths: high bits may cancel, so clamp the length to what the
  -- result array actually holds (bytes * 8) and strip trailing zeros.
  EQ -> dropWhileEnd $ BitVec 0 (lx `min` (sizeofByteArray zs `shiftL` 3)) zs
  GT -> BitVec 0 lx zs
  where
    zs = fromBigNat (toBigNat xarr `xorBigNat` toBigNat yarr)
#endif
xorBits xs ys = dropWhileEnd $ runST $ do
  let lx = U.length xs
      ly = U.length ys
      (shorterLen, longerLen, longer) = if lx >= ly then (ly, lx, xs) else (lx, ly, ys)
  zs <- MU.replicate longerLen (Bit False)
  -- XOR the common prefix one machine word at a time.
  forM_ [0, wordSize .. shorterLen - 1] $ \i ->
    writeWord zs i (indexWord xs i `xor` indexWord ys i)
  -- Everything past the shorter argument comes from the longer one only.
  U.unsafeCopy (MU.drop shorterLen zs) (U.drop shorterLen longer)
  U.unsafeFreeze zs
-- | Operand length (in bits) at or below which 'karatsuba' stops recursing
-- and falls back to 'mulBits'. 'karatsuba' raises an error if this value
-- is smaller than @2 * wordSize@.
karatsubaThreshold :: Int
karatsubaThreshold = 2048
-- | Multiply two bit vectors as polynomials using Karatsuba splitting.
--
-- * Equal arguments are squared with 'sqrBits'.
-- * If either argument is no longer than 'karatsubaThreshold', the
--   multiplication is done by 'mulBits'.
-- * Otherwise both arguments are split at a word-aligned position @m@, three
--   recursive products are computed and recombined word by word with XOR.
--
-- Raises an 'error' when 'karatsubaThreshold' is below @2 * wordSize@.
-- The result is not normalised; callers apply 'dropWhileEnd' themselves.
karatsuba :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
karatsuba xs ys
  | karatsubaThreshold < 2 * wordSize
  = error $ "karatsubaThreshold must be >= " ++ show (2 * wordSize)
  | xs == ys = sqrBits xs
  | lenXs <= karatsubaThreshold || lenYs <= karatsubaThreshold
  = mulBits xs ys
  | otherwise = runST $ do
    zs <- MU.unsafeNew lenZs
    -- Every word of the result is written exactly once, so the buffer does
    -- not need to be zero-initialised.
    forM_ [0, wordSize .. lenZs - 1] $ \k -> do
      let z0  = indexWord0 zs0  k
          z11 = indexWord0 zs11 (k - m)
          z10 = indexWord0 zs0  (k - m)
          z12 = indexWord0 zs2  (k - m)
          z2  = indexWord0 zs2  (k - 2 * m)
      -- zs0 + (zs11 + zs0 + zs2) shifted by m + zs2 shifted by 2 * m.
      writeWord zs k (z0 `xor` z11 `xor` z10 `xor` z12 `xor` z2)
    U.unsafeFreeze zs
  where
    lenXs = U.length xs
    lenYs = U.length ys
    lenZs = lenXs + lenYs - 1

    -- Split point: half of the shorter argument, rounded down to a multiple
    -- of the word size so that the slices stay word-aligned.
    m'    = ((lenXs `min` lenYs) + 1) `quot` 2
    m     = m' - modWordSize m'

    xs0  = U.unsafeSlice 0 m xs
    xs1  = U.unsafeSlice m (lenXs - m) xs
    ys0  = U.unsafeSlice 0 m ys
    ys1  = U.unsafeSlice m (lenYs - m) ys

    xs01 = xorBits xs0 xs1
    ys01 = xorBits ys0 ys1
    zs0  = karatsuba xs0 ys0
    zs2  = karatsuba xs1 ys1
    zs11 = karatsuba xs01 ys01
-- | Read one machine word starting at bit position @i@, treating every
-- position outside of the vector as zero.
--
-- Unlike 'indexWord', the position may be negative or lie beyond the end of
-- the vector: bits before index 0 and after the last element read as 0.
indexWord0 :: U.Vector Bit -> Int -> Word
indexWord0 bv i
  -- The whole word lies before the start of the vector.
  | i <= - wordSize         = 0
  -- The whole word lies past the end of the vector.
  | lenI <= 0               = 0
  -- The word straddles the start of the vector.
  | i < 0, lenI >= wordSize = word0
  | i < 0                   = word0 .&. loMask lenI
  -- The word starts inside the vector; mask it if it runs past the end.
  | lenI >= wordSize        = word
  | otherwise               = word .&. loMask lenI
  where
    -- Number of vector bits available from position i onwards.
    lenI  = U.length bv - i
    word  = indexWord bv i
    word0 = indexWord bv 0 `unsafeShiftL` (- i)
-- | Polynomial multiplication of two bit vectors.
--
-- Returns the empty vector if either argument is empty; otherwise calls
-- 'mulBits'' with the longer argument first.
mulBits :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
mulBits xs ys
  | lenXs == 0 || lenYs == 0 = U.empty
  | lenXs >= lenYs           = mulBits' xs ys
  | otherwise                = mulBits' ys xs
  where
    lenXs = U.length xs
    lenYs = U.length ys
-- | Shift-and-XOR multiplication: for every set bit @k@ of @ys@, XOR @xs@
-- into the accumulator at offset @k@.
--
-- The result has length @lenXs + lenYs - 1@; 'mulBits' ensures that both
-- arguments are non-empty before calling this function.
mulBits' :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
mulBits' xs ys = runST $ do
  zs <- MU.replicate lenZs (Bit False)
  forM_ [0 .. lenYs - 1] $ \k ->
    when (unBit (U.unsafeIndex ys k)) $
      zipInPlace xor xs (MU.unsafeSlice k (lenZs - k) zs)
  U.unsafeFreeze zs
  where
    lenXs = U.length xs
    lenYs = U.length ys
    lenZs = lenXs + lenYs - 1
-- | Square a bit vector as a polynomial.
--
-- Each input word is expanded by 'sparseBits' into two output words
-- (presumably by interleaving zero bits -- confirm against 'sparseBits'),
-- which are stored at twice the input bit offset. The result occupies two
-- whole words per input word and is not normalised.
sqrBits :: U.Vector Bit -> U.Vector Bit
sqrBits xs = runST $ do
  squared <- MU.replicate (mulWordSize (nWords srcLen `shiftL` 1)) (Bit False)
  forM_ [0, wordSize .. srcLen - 1] $ \off -> do
    let (lo, hi) = sparseBits (indexWord xs off)
        dst      = off `shiftL` 1
    writeWord squared dst lo
    writeWord squared (dst + wordSize) hi
  U.unsafeFreeze squared
  where
    srcLen = U.length xs
-- | Long division of bit vectors viewed as polynomials.
--
-- Returns @(quotient, remainder)@. Throws 'DivideByZero' when the divisor
-- is empty. When the dividend is shorter than the divisor, the quotient is
-- empty and the dividend is returned unchanged as the remainder. Neither
-- result is normalised.
quotRemBits :: U.Vector Bit -> U.Vector Bit -> (U.Vector Bit, U.Vector Bit)
quotRemBits xs ys
  | U.null ys                = throw DivideByZero
  | dividendLen < divisorLen = (U.empty, xs)
  | otherwise                = runST $ do
    quotient <- MU.replicate quotientLen (Bit False)
    -- Working copy of the dividend, reduced in place.
    work <- MU.replicate dividendLen (Bit False)
    U.unsafeCopy work xs
    forM_ [quotientLen - 1, quotientLen - 2 .. 0] $ \shift -> do
      Bit leading <- MU.unsafeRead work (divisorLen - 1 + shift)
      -- A set leading bit means the shifted divisor has to be subtracted
      -- (XORed) at this position.
      when leading $ do
        MU.unsafeWrite quotient shift (Bit True)
        zipInPlace xor ys (MU.drop shift work)
    frozenQuotient  <- U.unsafeFreeze quotient
    frozenRemainder <- U.unsafeFreeze (MU.unsafeSlice 0 divisorLen work)
    pure (frozenQuotient, frozenRemainder)
  where
    dividendLen = U.length xs
    divisorLen  = U.length ys
    quotientLen = dividendLen - divisorLen + 1
-- | Strip trailing zero bits, returning a prefix slice of the input.
--
-- The vector is scanned from the end one word at a time; the new length is
-- the position just past the highest set bit.
dropWhileEnd
  :: U.Vector Bit
  -> U.Vector Bit
dropWhileEnd xs = U.unsafeSlice 0 (go (U.length xs)) xs
  where
    -- @go n@ is the normalised length given that all bits at positions
    -- @>= n@ are already known to be droppable.
    go n
      -- Fewer than a full word left: inspect the low @n@ bits of word 0.
      | n < wordSize = wordSize - countLeadingZeros (indexWord xs 0 .&. loMask n)
      | otherwise    = case indexWord xs (n - wordSize) of
        0 -> go (n - wordSize)
        w -> n - countLeadingZeros w
#if UseIntegerGmp

-- | Extract the raw byte array backing a bit vector, going through
-- 'cloneToWords' and 'toPrimVector'. An empty input is replaced by a
-- single zero word.
bitsToByteArray :: U.Vector Bit -> ByteArray#
bitsToByteArray xs = arr
  where
    ys = if U.null xs then U.singleton 0 else cloneToWords xs
    !(P.Vector _ _ (ByteArray arr)) = toPrimVector ys

-- | Reinterpret a 'BigNat' as a 'ByteArray' without copying.
--
-- NOTE(review): the 'unsafeCoerce' relies on both types being
-- single-constructor wrappers around 'ByteArray#' (cf. the 'BN#' and
-- 'ByteArray' constructors used in this module) -- confirm against the
-- versions of integer-gmp and primitive in use.
fromBigNat :: BigNat -> ByteArray
fromBigNat = unsafeCoerce

-- | Reinterpret a 'ByteArray' as a 'BigNat' without copying; the same
-- representation assumption as for 'fromBigNat' applies.
toBigNat :: ByteArray -> BigNat
toBigNat = unsafeCoerce

-- | Convert a bit vector to an 'Integer' via 'bigNatToInteger'.
bitsToInteger :: U.Vector Bit -> Integer
bitsToInteger xs = bigNatToInteger (BN# (bitsToByteArray xs))

#else

-- | Convert an 'Integer' to a bit vector of 'bitLen' bits, where element
-- @i@ is bit @i@ of the argument. The result may carry trailing zeros.
integerToBits :: Integer -> U.Vector Bit
integerToBits x = U.generate (bitLen x) (Bit . testBit x)

-- | Upper bound on the number of bits needed for the argument: the first
-- value @a@ in the doubling sequence starting at @2 ^ lgWordSize@ such that
-- @x < 2 ^ a@.
bitLen :: Integer -> Int
bitLen x = until (\a -> x < (1 `shiftL` a)) (`shiftL` 1) (1 `shiftL` lgWordSize)

-- | Convert a bit vector to an 'Integer' by setting bit @i@ for every set
-- element at index @i@.
bitsToInteger :: U.Vector Bit -> Integer
bitsToInteger = U.ifoldl' (\acc i (Bit b) -> if b then acc `setBit` i else acc) 0

#endif