module Data.Graph.Haggle.Internal.BitSet (
  BitSet,
  newBitSet,
  setBit,
  testBit
  ) where

import Control.Monad.ST
import qualified Data.Bits as B
import Data.Vector.Unboxed.Mutable ( STVector )
import qualified Data.Vector.Unboxed.Mutable as V
import Data.Word ( Word64 )

-- | A fixed-size mutable bit set usable inside the 'ST' state thread @s@.
data BitSet s =
  BS
    (STVector s Word64)   -- packed storage: bit i lives in word (i / 64)
    {-# UNPACK #-} !Int   -- logical size, i.e. the @n@ given to newBitSet

-- | Number of bits stored in each storage word.
--
-- Derived from 'Word64' itself rather than written as a literal, so it can
-- never disagree with the element type of the backing vector.
bitsPerWord :: Int
bitsPerWord = B.finiteBitSize (0 :: Word64)

-- | Allocate a new 'BitSet' with @n@ bits.  Bits are all
-- initialized to zero.
--
-- Storage is exactly the number of 64-bit words needed to cover bits
-- @0 .. n-1@; a non-positive @n@ allocates no storage at all.
--
-- > bs <- newBitSet n
newBitSet :: Int -> ST s (BitSet s)
newBitSet n = do
  -- Ceiling division without computing @n + bitsPerWord - 1@ (which could
  -- overflow for huge n).  The previous @n `div` bitsPerWord + 1@ wasted a
  -- whole word whenever n was a multiple of the word size.
  let (fullWords, spareBits) = n `quotRem` bitsPerWord
      nWords = fullWords + (if spareBits > 0 then 1 else 0)
  v <- V.replicate (max 0 nWords) 0
  return (BS v n)

-- | Set a bit in the bitset.  Out of range has no effect.
--
-- An index is in range when it lies in @[0, n)@, where @n@ is the size
-- passed to 'newBitSet'.
setBit :: BitSet s -> Int -> ST s ()
setBit (BS v sz) bitIx
  -- Negative indices must be rejected as well: otherwise the word index
  -- below would be negative and the vector access would fail.
  | bitIx < 0 || bitIx >= sz = return ()
  | otherwise = do
      -- bitIx is non-negative here, so quotRem agrees with divMod.
      let (wordIx, bitPos) = bitIx `quotRem` bitsPerWord
      oldWord <- V.read v wordIx
      V.write v wordIx (B.setBit oldWord bitPos)

-- | Return True if the bit is set.  Out of range will return False.
--
-- An index is in range when it lies in @[0, n)@, where @n@ is the size
-- passed to 'newBitSet'.
testBit :: BitSet s -> Int -> ST s Bool
testBit (BS v sz) bitIx
  -- Negative indices must be rejected as well: otherwise the word index
  -- below would be negative and the vector read would fail.
  | bitIx < 0 || bitIx >= sz = return False
  | otherwise = do
      -- bitIx is non-negative here, so quotRem agrees with divMod.
      let (wordIx, bitPos) = bitIx `quotRem` bitsPerWord
      w <- V.read v wordIx
      return (B.testBit w bitPos)