{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module What4.Utils.BVDomain.Bitwise
( Domain(..)
, bitle
, proper
, bvdMask
, member
, pmember
, size
, asSingleton
, nonempty
, eq
, domainsOverlap
, bitbounds
, any
, singleton
, range
, interval
, union
, intersection
, concat
, select
, zext
, sext
, testBit
, shl
, lshr
, ashr
, rol
, ror
, and
, or
, xor
, not
, genDomain
, genElement
, genPair
, correct_any
, correct_singleton
, correct_overlap
, correct_union
, correct_intersection
, correct_zero_ext
, correct_sign_ext
, correct_concat
, correct_shrink
, correct_trunc
, correct_select
, correct_shl
, correct_lshr
, correct_ashr
, correct_rol
, correct_ror
, correct_eq
, correct_and
, correct_or
, correct_not
, correct_xor
, correct_testBit
) where
import Data.Bits hiding (testBit, xor)
import qualified Data.Bits as Bits
import Data.Parameterized.NatRepr
import Numeric.Natural
import GHC.TypeNats
import Test.Verification (Property, property, (==>), Gen, chooseInteger)
import qualified Prelude
import Prelude hiding (any, concat, negate, and, or, not)
import qualified What4.Utils.Arithmetic as Arith
-- | A bitwise interval domain for bitvectors of width @w@.
--
--   The fields of @BVBitInterval mask lo hi@ are, in order: a mask
--   (required by 'proper' to equal @maxUnsigned w@), the bitwise lower
--   bound, and the bitwise upper bound.  A value @x@ belongs to the
--   domain when @lo@ is bitwise below @x .&. mask@ and that masked value
--   is bitwise below @hi@ (see 'member').
data Domain (w :: Nat)
  = BVBitInterval !Integer !Integer !Integer
  deriving (Show)
-- | Check the representation invariant of a domain at width @w@: the
--   stored mask equals @maxUnsigned w@, both bounds are bitwise below
--   the mask, and the low bound is bitwise below the high bound.
proper :: NatRepr w -> Domain w -> Bool
proper w (BVBitInterval m l h)
  | m /= maxUnsigned w = False
  | otherwise = bitle l m && bitle h m && bitle l h
-- | Test whether an integer belongs to the domain.  The candidate is
--   first masked with the domain's mask; the masked value must then sit
--   bitwise between the low and high bounds.
member :: Domain w -> Integer -> Bool
member (BVBitInterval m l h) x = aboveLow && belowHigh
  where
    masked = x .&. m
    aboveLow = bitle l masked
    belowHigh = bitle masked h
-- | Number of values in the domain: @2^k@, where @k@ is the number of
--   bit positions at which the two bounds differ, or @0@ when the low
--   bound is not bitwise below the high bound.
size :: Domain w -> Integer
size (BVBitInterval _ l h) =
  if bitle l h
    then Bits.bit (Bits.popCount (Bits.xor l h))
    else 0
-- | Bitwise "less than or equal": @bitle x y@ holds exactly when every
--   bit set in @x@ is also set in @y@.
bitle :: Integer -> Integer -> Bool
bitle x y = x .&. y == x
-- | Project the mask field (the first constructor field) out of a domain.
bvdMask :: Domain w -> Integer
bvdMask d =
  case d of
    BVBitInterval m _ _ -> m
-- | Generate a random domain of width @w@.  Two integers are drawn from
--   @[0, maxUnsigned w]@; the first is the low bound, and the high bound
--   is the bitwise OR of both draws, so the low bound is always bitwise
--   below the high bound.
genDomain :: NatRepr w -> Gen (Domain w)
genDomain w = do
  l <- pick
  extra <- pick
  pure $! interval m l (l .|. extra)
  where
    m = maxUnsigned w
    pick = chooseInteger (0, m)
-- | Generate a random value from a domain.  An integer with as many
--   random bits as there are positions where the bounds differ is drawn,
--   and those bits are then distributed, lowest first, over exactly
--   those differing positions of the low bound.
genElement :: Domain w -> Gen Integer
genElement (BVBitInterval _ l h) =
  scatter l 0 <$> chooseInteger (0, bit freeCount - 1)
  where
    -- Bit positions at which the two bounds disagree.
    free = Bits.xor l h
    freeCount = Bits.popCount free

    -- Walk upward through bit positions, consuming one random bit for
    -- each free position; stop once no random bits remain.
    scatter acc pos bits
      | bits == 0 = acc
      | Prelude.not (Bits.testBit free pos) = scatter acc (pos + 1) bits
      | Bits.testBit bits 0 = scatter (setBit acc pos) (pos + 1) (bits `shiftR` 1)
      | otherwise = scatter acc (pos + 1) (bits `shiftR` 1)
-- | Generate a random domain of width @w@ (via 'genDomain') paired with
--   a value drawn from it (via 'genElement').
genPair :: NatRepr w -> Gen (Domain w, Integer)
genPair w = genDomain w >>= \dom -> (,) dom <$> genElement dom
-- | Build a domain directly from a mask, a low bound and a high bound.
--   The arguments are stored as given; no masking or validation is done.
interval :: Integer -> Integer -> Integer -> Domain w
interval = BVBitInterval
-- | Build a domain of width @w@ from a low and a high bound, masking
--   both bounds with @maxUnsigned w@, which is also stored as the mask.
range :: NatRepr w -> Integer -> Integer -> Domain w
range w lo hi = BVBitInterval m (clip lo) (clip hi)
  where
    m = maxUnsigned w
    clip = (.&. m)
-- | Return the @(low, high)@ bounds of a domain, discarding the mask.
bitbounds :: Domain w -> (Integer, Integer)
bitbounds d =
  case d of
    BVBitInterval _ l h -> (l, h)
-- | If the low and high bounds coincide, return that value; otherwise
--   return 'Nothing'.
asSingleton :: Domain w -> Maybe Integer
asSingleton (BVBitInterval _ l h)
  | l == h = Just l
  | otherwise = Nothing
-- | A domain is nonempty exactly when its low bound is bitwise below
--   its high bound.
nonempty :: Domain w -> Bool
nonempty = uncurry bitle . bitbounds
-- | Build the domain of width @w@ whose low and high bounds are both
--   @x@ masked with @maxUnsigned w@.
singleton :: NatRepr w -> Integer -> Domain w
singleton w x = range w x x
-- | The domain of width @w@ with low bound @0@ and high bound
--   @maxUnsigned w@, i.e. no bit is constrained.
any :: NatRepr w -> Domain w
any w =
  let top = maxUnsigned w
   in BVBitInterval top 0 top
-- | Two domains overlap when their 'intersection' is 'nonempty'.
domainsOverlap :: Domain w -> Domain w -> Bool
domainsOverlap a = nonempty . intersection a
-- | Abstract equality test.  When both domains are singletons the
--   answer is the comparison of their values; otherwise the answer is
--   @Just False@ if the domains do not overlap and 'Nothing' (unknown)
--   if they do.
eq :: Domain w -> Domain w -> Maybe Bool
eq a b =
  case (asSingleton a, asSingleton b) of
    (Just x, Just y) -> Just (x == y)
    _ | domainsOverlap a b -> Nothing
      | otherwise -> Just False
-- | Intersect two domains: the new low bound is the OR of the low
--   bounds and the new high bound is the AND of the high bounds.  The
--   mask is taken from the first argument.
intersection :: Domain w -> Domain w -> Domain w
intersection (BVBitInterval m l1 h1) (BVBitInterval _ l2 h2) =
  BVBitInterval m newLo newHi
  where
    newLo = l1 .|. l2
    newHi = h1 .&. h2
-- | Join two domains: the new low bound is the AND of the low bounds
--   and the new high bound is the OR of the high bounds.  The mask is
--   taken from the first argument.
union :: Domain w -> Domain w -> Domain w
union (BVBitInterval m l1 h1) (BVBitInterval _ l2 h2) =
  BVBitInterval m newLo newHi
  where
    newLo = l1 .&. l2
    newHi = h1 .|. h2
-- | Concatenate two domains.  Each bound of the first domain is shifted
--   left by the width @v@ of the second domain and added to the
--   corresponding bound of the second; the resulting mask is
--   @maxUnsigned@ of the combined width.
concat :: NatRepr u -> Domain u -> NatRepr v -> Domain v -> Domain (u + v)
concat u (BVBitInterval _ upperLo upperHi) v (BVBitInterval _ lowerLo lowerHi) =
  BVBitInterval (maxUnsigned (addNat u v)) (glue upperLo lowerLo) (glue upperHi lowerHi)
  where
    lowWidth = widthVal v
    glue upper lower = shiftL upper lowWidth + lower
-- | Drop the @i@ least-significant bits of a domain by shifting the
--   mask and both bounds right by @widthVal i@.
shrink ::
  NatRepr i ->
  Domain (i + n) ->
  Domain n
shrink i (BVBitInterval m l h) =
  BVBitInterval (dropLow m) (dropLow l) (dropLow h)
  where
    dropLow = (`shiftR` widthVal i)
-- | Truncate a domain to width @n@ by rebuilding it with 'range', which
--   masks both bounds to @n@ bits.
trunc ::
  (n <= w) =>
  NatRepr n ->
  Domain w ->
  Domain n
trunc n d = uncurry (range n) (bitbounds d)
-- | Select @n@ bits starting at bit index @i@: first truncate the domain
--   to @i + n@ bits, then drop the low @i@ bits.
select ::
  (1 <= n, i + n <= w) =>
  NatRepr i ->
  NatRepr n ->
  Domain w ->
  Domain n
select i n = shrink i . trunc (addNat i n)
-- | Widen a domain to width @u@ by keeping the bounds unchanged and
--   rebuilding with 'range' at the new width.
zext :: (1 <= w, w+1 <= u) => Domain w -> NatRepr u -> Domain u
zext d u =
  case bitbounds d of
    (l, h) -> range u l h
-- | Widen a domain to width @u@ after passing each bound through
--   @toSigned w@; 'range' then masks the results to @u@ bits.
sext :: (1 <= w, w+1 <= u) => NatRepr w -> Domain w -> NatRepr u -> Domain u
sext w (BVBitInterval _ l h) u = range u (signed l) (signed h)
  where
    signed = toSigned w
-- | Query bit @i@ of a domain.  If the low and high bounds agree at that
--   position the common value is returned; otherwise the bit is unknown
--   and the result is 'Nothing'.
testBit :: Domain w -> Natural -> Maybe Bool
testBit (BVBitInterval _ l h) i
  | atLo == atHi = Just atLo
  | otherwise = Nothing
  where
    idx = fromIntegral i
    atLo = Bits.testBit l idx
    atHi = Bits.testBit h idx
-- | Shift both bounds left by @y@ (clamped above by the width @w@) and
--   mask the results back to the domain's mask.
shl :: NatRepr w -> Domain w -> Integer -> Domain w
shl w (BVBitInterval m l h) y = BVBitInterval m (go l) (go h)
  where
    amount = fromInteger (min y (intValue w))
    go v = m .&. shiftL v amount
-- | Apply @Arith.rotateLeft w _ y@ to both bounds, keeping the mask.
rol :: NatRepr w -> Domain w -> Integer -> Domain w
rol w (BVBitInterval m l h) y = BVBitInterval m (spin l) (spin h)
  where
    spin v = Arith.rotateLeft w v y
-- | Apply @Arith.rotateRight w _ y@ to both bounds, keeping the mask.
ror :: NatRepr w -> Domain w -> Integer -> Domain w
ror w (BVBitInterval m l h) y = BVBitInterval m (spin l) (spin h)
  where
    spin v = Arith.rotateRight w v y
-- | Shift both bounds right by @y@ (clamped above by the width @w@),
--   keeping the mask.
lshr :: NatRepr w -> Domain w -> Integer -> Domain w
lshr w (BVBitInterval m l h) y = BVBitInterval m (down l) (down h)
  where
    amount = fromInteger (min y (intValue w))
    down = (`shiftR` amount)
-- | Shift both bounds right by @y@ (clamped above by the width @w@)
--   after passing each through @toSigned w@, then mask the results back
--   to the domain's mask.
ashr :: (1 <= w) => NatRepr w -> Domain w -> Integer -> Domain w
ashr w (BVBitInterval m l h) y = BVBitInterval m (go l) (go h)
  where
    amount = fromInteger (min y (intValue w))
    go v = shiftR (toSigned w v) amount .&. m
-- | Bitwise complement of a domain.  Each bound is XORed with the mask,
--   and the two bounds swap roles: the flipped high bound becomes the
--   new low bound and vice versa.
not :: Domain w -> Domain w
not (BVBitInterval m l h) = BVBitInterval m (flipBits h) (flipBits l)
  where
    flipBits = Bits.xor m
-- | Bitwise AND of two domains, computed bound-by-bound: low with low,
--   high with high.  The mask is taken from the first argument.
and :: Domain w -> Domain w -> Domain w
and a b = BVBitInterval (bvdMask a) (aLo .&. bLo) (aHi .&. bHi)
  where
    (aLo, aHi) = bitbounds a
    (bLo, bHi) = bitbounds b
-- | Bitwise OR of two domains, computed bound-by-bound: low with low,
--   high with high.  The mask is taken from the first argument.
or :: Domain w -> Domain w -> Domain w
or a b = BVBitInterval (bvdMask a) (aLo .|. bLo) (aHi .|. bHi)
  where
    (aLo, aHi) = bitbounds a
    (bLo, bHi) = bitbounds b
-- | Bitwise XOR of two domains.  A result bit is unknown wherever either
--   input has differing bounds; the high bound sets every unknown bit on
--   top of the XOR of the two low bounds, and the low bound clears them.
--   The mask is taken from the first argument.
xor :: Domain w -> Domain w -> Domain w
xor (BVBitInterval m alo ahi) (BVBitInterval _ blo bhi) =
  BVBitInterval m lower upper
  where
    -- Positions that are undetermined in at least one argument.
    unknown = Bits.xor alo ahi .|. Bits.xor blo bhi
    upper = Bits.xor alo blo .|. unknown
    lower = Bits.xor upper unknown
-- | Membership test that additionally requires the domain to satisfy
--   'proper' at width @n@.
pmember :: NatRepr n -> Domain n -> Integer -> Bool
pmember n a x
  | proper n a = member a x
  | otherwise = False
-- | Property: every integer is a proper member of @any n@.
correct_any :: (1 <= n) => NatRepr n -> Integer -> Property
correct_any n = property . pmember n (any n)
-- | Property: after passing both inputs through @toUnsigned n@, @y@ is a
--   proper member of @singleton n x@ exactly when the two values are
--   equal.
correct_singleton :: (1 <= n) => NatRepr n -> Integer -> Integer -> Property
correct_singleton n x y = property (inDomain == sameValue)
  where
    ux = toUnsigned n x
    uy = toUnsigned n y
    inDomain = pmember n (singleton n ux) uy
    sameValue = ux == uy
-- | Property: if @x@ is a member of both domains, then the domains are
--   reported as overlapping.
correct_overlap :: Domain n -> Domain n -> Integer -> Property
correct_overlap a b x = inBoth ==> domainsOverlap a b
  where
    inBoth = member a x && member b x
-- | Property: if @x@ is a member of either domain, then it is a proper
--   member of their 'union'.
correct_union :: (1 <= n) => NatRepr n -> Domain n -> Domain n -> Integer -> Property
correct_union n a b x = inEither ==> pmember n (union a b) x
  where
    inEither = member a x || member b x
-- | Property: if @x@ is a member of both domains, then it is a member
--   of their 'intersection'.
correct_intersection :: (1 <= n) => Domain n -> Domain n -> Integer -> Property
correct_intersection a b x = inBoth ==> member (intersection a b) x
  where
    inBoth = member a x && member b x
-- | Property: with @x@ first passed through @toUnsigned w@, membership
--   in @a@ implies proper membership in @zext a u@.
correct_zero_ext :: (1 <= w, w+1 <= u) => NatRepr w -> Domain w -> NatRepr u -> Integer -> Property
correct_zero_ext w a u x =
  let ux = toUnsigned w x
   in member a ux ==> pmember u (zext a u) ux
-- | Property: with @x@ first passed through @toSigned w@, membership in
--   @a@ implies proper membership in @sext w a u@.
correct_sign_ext :: (1 <= w, w+1 <= u) => NatRepr w -> Domain w -> NatRepr u -> Integer -> Property
correct_sign_ext w a u x =
  let sx = toSigned w x
   in member a sx ==> pmember u (sext w a u) sx
-- | Property: if @x@ and @y@ (each passed through @toUnsigned@ at its
--   own width) are members of @a@ and @b@, then @x@ shifted left by the
--   width of @b@ and ORed with @y@ is a proper member of the
--   concatenated domain.
correct_concat :: NatRepr m -> (Domain m,Integer) -> NatRepr n -> (Domain n,Integer) -> Property
correct_concat m (a, x) n (b, y) =
  member a ux ==> member b uy ==> pmember (addNat m n) (concat m a n b) joined
  where
    ux = toUnsigned m x
    uy = toUnsigned n y
    joined = shiftL ux (widthVal n) .|. uy
-- | Property: if @x@ (masked with the domain's mask) is a member of @a@,
--   then that masked value shifted right by @widthVal i@ is a proper
--   member of @shrink i a@.
correct_shrink :: NatRepr i -> NatRepr n -> (Domain (i + n), Integer) -> Property
correct_shrink i n (a, x) =
  let mx = x .&. bvdMask a
   in member a mx ==> pmember n (shrink i a) (shiftR mx (widthVal i))
-- | Property: if @x@ (masked with the domain's mask) is a member of @a@,
--   then @toUnsigned n@ of that masked value is a proper member of
--   @trunc n a@.
correct_trunc :: (n <= w) => NatRepr n -> (Domain w, Integer) -> Property
correct_trunc n (a, x) =
  let mx = x .&. bvdMask a
   in member a mx ==> pmember n (trunc n a) (toUnsigned n mx)
-- | Property: if @x@ is a member of @a@, then @x@ masked with the
--   domain's mask, shifted right by @widthVal i@ and passed through
--   @toUnsigned n@, is a proper member of @select i n a@.
correct_select :: (1 <= n, i + n <= w) =>
  NatRepr i -> NatRepr n -> (Domain w, Integer) -> Property
correct_select i n (a, x) = member a x ==> pmember n (select i n a) expected
  where
    masked = x .&. bvdMask a
    expected = toUnsigned n (shiftR masked (widthVal i))
-- | Property: for members @x@ of @a@ and @y@ of @b@, a definite answer
--   from 'eq' must agree with comparing @toUnsigned n x@ against
--   @toUnsigned n y@; an unknown answer is always acceptable.
correct_eq :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> (Domain n, Integer) -> Property
correct_eq n (a, x) (b, y) =
  member a x ==> member b y ==> verdict
  where
    same = toUnsigned n x == toUnsigned n y
    verdict =
      case eq a b of
        Nothing -> True
        Just expected -> expected == same
-- | Property: if @x@ is a member of @a@, then @toUnsigned n x@ shifted
--   left by @y@ (clamped above by the width) is a proper member of
--   @shl n a y@.
correct_shl :: (1 <= n) => NatRepr n -> (Domain n,Integer) -> Integer -> Property
correct_shl n (a, x) y = member a x ==> pmember n (shl n a y) shifted
  where
    amount = fromInteger (min (intValue n) y)
    shifted = shiftL (toUnsigned n x) amount
-- | Property: if @x@ is a member of @a@, then @toUnsigned n x@ shifted
--   right by @y@ (clamped above by the width) is a proper member of
--   @lshr n a y@.
correct_lshr :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> Integer -> Property
correct_lshr n (a, x) y = member a x ==> pmember n (lshr n a y) shifted
  where
    amount = fromInteger (min (intValue n) y)
    shifted = shiftR (toUnsigned n x) amount
-- | Property: if @x@ is a member of @a@, then @toSigned n x@ shifted
--   right by @y@ (clamped above by the width) is a proper member of
--   @ashr n a y@.
correct_ashr :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> Integer -> Property
correct_ashr n (a, x) y = member a x ==> pmember n (ashr n a y) shifted
  where
    amount = fromInteger (min (intValue n) y)
    shifted = shiftR (toSigned n x) amount
-- | Property: if @x@ is a member of @a@, then @Arith.rotateLeft n x y@
--   is a proper member of @rol n a y@.
correct_rol :: (1 <= n) => NatRepr n -> (Domain n,Integer) -> Integer -> Property
correct_rol n (a, x) y = member a x ==> pmember n (rol n a y) rotated
  where
    rotated = Arith.rotateLeft n x y
-- | Property: if @x@ is a member of @a@, then @Arith.rotateRight n x y@
--   is a proper member of @ror n a y@.
correct_ror :: (1 <= n) => NatRepr n -> (Domain n,Integer) -> Integer -> Property
correct_ror n (a, x) y = member a x ==> pmember n (ror n a y) rotated
  where
    rotated = Arith.rotateRight n x y
-- | Property: if @x@ is a member of @a@, then @complement x@ is a proper
--   member of @not a@.
correct_not :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> Property
correct_not n (a, x) =
  let flipped = complement x
   in member a x ==> pmember n (not a) flipped
-- | Property: for members @x@ of @a@ and @y@ of @b@, @x .&. y@ is a
--   proper member of @and a b@.
correct_and :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> (Domain n, Integer) -> Property
correct_and n (a, x) (b, y) =
  member a x ==> member b y ==> pmember n combined (x .&. y)
  where
    combined = and a b
-- | Property: for members @x@ of @a@ and @y@ of @b@, @x .|. y@ is a
--   proper member of @or a b@.
correct_or :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> (Domain n, Integer) -> Property
correct_or n (a, x) (b, y) =
  member a x ==> member b y ==> pmember n combined (x .|. y)
  where
    combined = or a b
-- | Property: for members @x@ of @a@ and @y@ of @b@, the bitwise XOR of
--   @x@ and @y@ is a proper member of @xor a b@.
correct_xor :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> (Domain n, Integer) -> Property
correct_xor n (a, x) (b, y) =
  member a x ==> member b y ==> pmember n combined (Bits.xor x y)
  where
    combined = xor a b
-- | Property: if @x@ is a member of @a@ and @i@ is a valid bit index
--   for width @n@, then any definite answer from @testBit a i@ matches
--   bit @i@ of @x@; an unknown answer is always acceptable.
--
--   The @member a x@ precondition mirrors the other pair-taking
--   properties in this module.  Without it the statement is false for
--   any pair whose integer is not drawn from the domain.
correct_testBit :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> Natural -> Property
correct_testBit n (a, x) i =
  member a x ==> i < natValue n ==>
    case testBit a i of
      Just expected -> Bits.testBit x (fromIntegral i) == expected
      Nothing -> True