{-|
Copyright   : (C) 2021-2022, QBayLogic B.V.
License     : BSD2 (see the file LICENSE)
Maintainer  : QBayLogic B.V. <devops@qbaylogic.com>

Random generation of BitVector.
-}

{-# OPTIONS_GHC -fplugin=GHC.TypeLits.KnownNat.Solver #-}

{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}

module Clash.Hedgehog.Sized.BitVector
  ( genDefinedBit
  , genBit
  , genDefinedBitVector
  , genBitVector
  , SomeBitVector(..)
  , genSomeBitVector
  ) where

import GHC.Natural (Natural)
import GHC.TypeNats
import Hedgehog (MonadGen, Range)
import qualified Hedgehog.Gen as Gen

import Clash.Promoted.Nat
import Clash.Sized.Internal.BitVector
import Clash.Sized.Vector (v2bv)
import Clash.XException (errorX)

import Clash.Hedgehog.Sized.Vector (genVec)

-- | Generate a 'Bit' that is always defined.
--
-- The result is picked with 'Gen.element' from 'low' and 'high', so it is
-- never an undefined bit.
genDefinedBit :: (MonadGen m) => m Bit
genDefinedBit = Gen.element definedBits
 where
  -- 'low' comes first, so 'Gen.element' shrinks towards it.
  definedBits = [low, high]

-- | Generate a 'Bit' that may be undefined.
--
-- The result is 'low', 'high', or a bit built with 'errorX'. Forcing that
-- last value throws an 'XException'.
genBit :: (MonadGen m) => m Bit
genBit =
  Gen.element
    [ low
    , high
    , errorX "X"
    ]

-- | Generate a 'BitVector' whose bits are all defined.
--
-- This is a weighted choice between three cases:
--
-- * 60%: each bit is generated independently with 'genDefinedBit';
-- * 20%: 'minBound';
-- * 20%: 'maxBound'.
genDefinedBitVector :: (MonadGen m, KnownNat n) => m (BitVector n)
genDefinedBitVector =
  Gen.frequency
    [ (60, bitwise)
    , (20, Gen.constant minBound)
    , (20, Gen.constant maxBound)
    ]
 where
  -- Build a vector of independently generated defined bits, then pack it.
  bitwise = v2bv <$> genVec genDefinedBit

-- | Generate a 'BitVector' in which some bits may be undefined.
--
-- This is a weighted choice between four cases:
--
-- * 55%: each bit is generated independently with 'genBit', so any bit may
--   be undefined;
-- * 15%: 'minBound';
-- * 15%: 'maxBound';
-- * 15%: 'undefined#' (presumably the fully undefined vector; confirm
--   against "Clash.Sized.Internal.BitVector").
genBitVector :: (MonadGen m, KnownNat n) => m (BitVector n)
genBitVector =
  Gen.frequency
    [ (55, bitwise)
    , (15, Gen.constant minBound)
    , (15, Gen.constant maxBound)
    , (15, Gen.constant undefined#)
    ]
 where
  -- Build a vector of independently generated, possibly undefined bits,
  -- then pack it.
  bitwise = v2bv <$> genVec genBit

-- | A 'BitVector' whose width is @atLeast + n@ for some existentially hidden
-- @n@.
--
-- The 'SNat' field witnesses the extra width @n@ on top of the lower bound
-- @atLeast@, which appears in the type.
data SomeBitVector atLeast where
  SomeBitVector
    :: SNat n
    -> BitVector (atLeast + n)
    -> SomeBitVector atLeast

-- | Shows only the wrapped 'BitVector'. Neither the 'SomeBitVector'
-- constructor nor the 'SNat' witness appears in the output.
instance KnownNat atLeast => Show (SomeBitVector atLeast) where
  show sbv = case sbv of
    -- Match on the 'SNat' pattern rather than '_'. This presumably brings the
    -- hidden width's 'KnownNat' evidence into scope, which 'show' on the
    -- 'BitVector' requires.
    SomeBitVector SNat bits -> show bits

-- | Generate a 'BitVector' of at least @atLeast@ bits, with the exact width
-- hidden inside 'SomeBitVector'.
--
-- The number of extra bits is drawn from the given 'Range' with
-- 'Gen.integral'. The bits themselves come from the supplied generator,
-- instantiated at width @atLeast + extra@.
genSomeBitVector
  :: forall m atLeast
   . (MonadGen m, KnownNat atLeast)
  => Range Natural
  -- ^ Range for the number of bits beyond @atLeast@
  -> (forall n. KnownNat n => m (BitVector n))
  -- ^ Generator for a bit vector of any statically known width
  -> m (SomeBitVector atLeast)
genSomeBitVector rangeBv genBv =
  Gen.integral rangeBv >>= \extra ->
    -- Promote the runtime count to a type-level natural so that 'genBv'
    -- can be used at width @atLeast + extra@.
    case someNatVal extra of
      SomeNat px -> fmap (SomeBitVector (snatProxy px)) genBv