{-|
Module      : What4.Utils.BVDomain.Bitwise
Copyright   : (c) Galois Inc, 2020
License     : BSD3
Maintainer  : huffman@galois.com

Provides a bitwise implementation of bitvector abstract domains.
-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

module What4.Utils.BVDomain.Bitwise
  ( Domain(..)
  , bitle
  , proper
  , bvdMask
  , member
  , pmember
  , size
  , asSingleton
  , nonempty
  , eq
  , domainsOverlap
  , bitbounds
  -- * Operations
  , any
  , singleton
  , range
  , interval
  , union
  , intersection
  , concat
  , select
  , zext
  , sext
  , testBit
  -- ** shifts and rotates
  , shl
  , lshr
  , ashr
  , rol
  , ror
  -- ** bitwise logical
  , and
  , or
  , xor
  , not

  -- * Correctness properties
  , genDomain
  , genElement
  , genPair
  , correct_any
  , correct_singleton
  , correct_overlap
  , correct_union
  , correct_intersection
  , correct_zero_ext
  , correct_sign_ext
  , correct_concat
  , correct_shrink
  , correct_trunc
  , correct_select
  , correct_shl
  , correct_lshr
  , correct_ashr
  , correct_rol
  , correct_ror
  , correct_eq
  , correct_and
  , correct_or
  , correct_not
  , correct_xor
  , correct_testBit
  ) where

import           Data.Bits hiding (testBit, xor)
import qualified Data.Bits as Bits
import           Data.Parameterized.NatRepr
import           Numeric.Natural
import           GHC.TypeNats
import           Test.Verification (Property, property, (==>), Gen, chooseInteger)

import qualified Prelude
import           Prelude hiding (any, concat, negate, and, or, not)

import qualified What4.Utils.Arithmetic as Arith

-- | A bitwise interval domain, defined via a
--   bitwise upper and lower bound.  The ordering
--   used here to construct the interval is the pointwise
--   ordering on bits.  In particular @x [= y iff x .|. y == y@,
--   and a value @x@ is in the set defined by the pair @(lo,hi)@
--   just when @lo [= x && x [= hi@.
data Domain (w :: Nat) =
  BVBitInterval !Integer !Integer !Integer
  -- ^ @BVBitInterval mask lo hi@.
  --  @mask@ caches the value of @2^w - 1@.  Bits set in @lo@ are
  --  known to be 1 in every member; bits clear in @hi@ are known to
  --  be 0 in every member.
 deriving (Show)

-- | Check the representation invariants of a width-@w@ domain: the
--   cached mask equals @2^w - 1@, both bounds fit inside that mask, and
--   the lower bound is bitwise below the upper bound.
proper :: NatRepr w -> Domain w -> Bool
proper w (BVBitInterval mask lo hi) =
  Prelude.and
    [ mask == maxUnsigned w
    , lo `bitle` mask
    , hi `bitle` mask
    , lo `bitle` hi
    ]

-- | Test whether an integer, after masking with the domain's cached
--   mask, lies between the domain's bitwise lower and upper bounds.
member :: Domain w -> Integer -> Bool
member (BVBitInterval mask lo hi) x =
  let v = mask .&. x
  in  lo `bitle` v && v `bitle` hi

-- | Count the concrete values in the domain.  A nonempty domain has one
--   free choice for each bit position where the bounds differ, so its
--   size is @2^k@ for @k@ such positions; an empty domain has size 0.
size :: Domain w -> Integer
size (BVBitInterval _ lo hi) =
  if bitle lo hi
    then Bits.bit (Bits.popCount (lo `Bits.xor` hi))
    else 0

-- | Pointwise ordering on bits: @bitle x y@ holds exactly when every
--   bit set in @x@ is also set in @y@.
bitle :: Integer -> Integer -> Bool
bitle x y = (x .&. complement y) == 0

-- | Extract the @2^w - 1@ mask cached inside the domain.
bvdMask :: Domain w -> Integer
bvdMask d = case d of
  BVBitInterval m _ _ -> m

-- | Random generator for domains of width @w@.  The upper bound is the
--   lower bound or-ed with a second random value, so every generated
--   domain is nonempty.
genDomain :: NatRepr w -> Gen (Domain w)
genDomain w = do
  lo    <- chooseInteger (0, m)
  extra <- chooseInteger (0, m)
  let hi = lo .|. extra
  pure $! interval m lo hi
 where
  m = maxUnsigned w

-- | Generate a random element of a domain.
--
--   To get a good statistical distribution of values, this chooses
--   random bits only for the "unknown" positions (where the bounds
--   differ).  It then scatters those bits, lowest first, into the
--   unknown positions of the lower bound.
genElement :: Domain w -> Gen Integer
genElement (BVBitInterval _mask lo hi) =
  scatter lo 0 <$> chooseInteger (0, bit (Bits.popCount unknown) - 1)
 where
  unknown = lo `Bits.xor` hi

  -- Walk bit positions upward.  Each unknown position consumes the
  -- lowest remaining random bit; stop once no random bits remain.
  scatter acc pos bits
    | bits == 0 = acc
    | Prelude.not (Bits.testBit unknown pos) = scatter acc (pos + 1) bits
    | Bits.testBit bits 0 = scatter (setBit acc pos) (pos + 1) (bits `shiftR` 1)
    | otherwise = scatter acc (pos + 1) (bits `shiftR` 1)

{- A faster generator, but I worry that it
   doesn't have very good statistical properties...

genElement :: Domain w -> Gen Integer
genElement (BVBitInterval mask lo hi) =
  do let u = Bits.xor lo hi
     x <- chooseInteger (0, mask)
     pure ((x .&. u) .|. lo)
-}

-- | Generate a random nonempty domain of width @w@ together with a value
--   drawn from that domain.
genPair :: NatRepr w -> Gen (Domain w, Integer)
genPair w = do
  d <- genDomain w
  v <- genElement d
  pure (d, v)

-- | Unchecked constructor: wraps a mask and raw bounds without masking
--   or validating them.  Intended for internal use.
interval :: Integer -> Integer -> Integer -> Domain w
interval = BVBitInterval

-- | Build a width-@w@ domain from bitwise lower and upper bounds.  Both
--   bounds are truncated to @w@ bits.  No check is made that the lower
--   bound lies bitwise below the upper bound.
range :: NatRepr w -> Integer -> Integer -> Domain w
range w lo hi =
  let m = maxUnsigned w
  in  BVBitInterval m (m .&. lo) (m .&. hi)

-- | Project out the bitwise @(lower, upper)@ bounds of a domain.
bitbounds :: Domain w -> (Integer, Integer)
bitbounds d = case d of
  BVBitInterval _ lower upper -> (lower, upper)

-- | Return the unique member when the two bounds coincide, and
--   'Nothing' otherwise.
asSingleton :: Domain w -> Maybe Integer
asSingleton (BVBitInterval _ lo hi)
  | lo == hi  = Just lo
  | otherwise = Nothing

-- | Returns true iff the domain contains at least one element, which is
--   the case exactly when its lower bound is bitwise below its upper
--   bound.
nonempty :: Domain w -> Bool
nonempty d = case d of
  BVBitInterval _ lower upper -> lower `bitle` upper

-- | The domain containing exactly one value: the argument truncated to
--   @w@ bits.
singleton :: NatRepr w -> Integer -> Domain w
singleton w x =
  let m = maxUnsigned w
      v = m .&. x
  in  BVBitInterval m v v

-- | The full domain of width @w@.  Every bit is unknown (lower bound 0,
--   upper bound all ones), so every @w@-bit value is a member.
any :: NatRepr w -> Domain w
any w = let m = maxUnsigned w in BVBitInterval m 0 m

-- | Returns true iff the two domains share at least one value, which is
--   decided by checking that their intersection is nonempty.
domainsOverlap :: Domain w -> Domain w -> Bool
domainsOverlap a b = nonempty $ intersection a b

-- | Abstract equality test.  Returns @Just True@ when both domains are
--   the same singleton.  Returns @Just False@ when they are different
--   singletons or share no value.  Returns 'Nothing' otherwise.
eq :: Domain w -> Domain w -> Maybe Bool
eq a b =
  case (asSingleton a, asSingleton b) of
    (Just x, Just y) -> Just (x == y)
    _ | domainsOverlap a b -> Nothing
      | otherwise          -> Just False

-- | Intersect two domains by or-ing their lower bounds and and-ing their
--   upper bounds.  The result may be empty (improper) when the inputs
--   disagree on a known bit.  The first argument's mask is kept.
intersection :: Domain w -> Domain w -> Domain w
intersection (BVBitInterval m lo1 hi1) (BVBitInterval _ lo2 hi2) =
  BVBitInterval m lo hi
 where
  lo = lo1 .|. lo2
  hi = hi1 .&. hi2

-- | A bitwise domain containing every member of either argument.  A bit
--   stays known only when both arguments agree on it.  The first
--   argument's mask is kept.
union :: Domain w -> Domain w -> Domain w
union (BVBitInterval m lo1 hi1) (BVBitInterval _ lo2 hi2) =
  BVBitInterval m lo hi
 where
  lo = lo1 .&. lo2
  hi = hi1 .|. hi2

-- | @concat u a v b@ is the domain of all concatenations of a member of
--   @a@ with a member of @b@.  The @u@ bits from @a@ become the
--   most-significant bits and the @v@ bits from @b@ the least-significant
--   bits.
concat :: NatRepr u -> Domain u -> NatRepr v -> Domain v -> Domain (u + v)
concat u (BVBitInterval _ aLo aHi) v (BVBitInterval _ bLo bHi) =
  BVBitInterval (maxUnsigned (addNat u v)) (glue aLo bLo) (glue aHi bHi)
  where
    glue high low = (high `shiftL` widthVal v) + low

-- | @shrink i a@ discards the @i@ least significant bits of every member
--   of @a@ by shifting the mask and both bounds right by @i@.
shrink ::
  NatRepr i ->
  Domain (i + n) -> Domain n
shrink i (BVBitInterval mask lo hi) =
  let k = widthVal i
  in  BVBitInterval (mask `shiftR` k) (lo `shiftR` k) (hi `shiftR` k)

-- | @trunc n d@ keeps only the @n@ least significant bits of each member
--   of @d@ by re-masking both bounds at width @n@.
trunc ::
  (n <= w) =>
  NatRepr n ->
  Domain w ->
  Domain n
trunc n (BVBitInterval _mask lower upper) = range n lower upper

-- | @select i n a@ extracts the @n@-bit field starting at bit index @i@
--   from each member of @a@.  It first truncates to the low @i + n@
--   bits, then drops the low @i@ bits.
select ::
  (1 <= n, i + n <= w) =>
  NatRepr i ->
  NatRepr n ->
  Domain w -> Domain n
select i n = shrink i . trunc (addNat i n)

-- | Zero-extend a domain to width @u@ by re-masking its bounds at the
--   wider width.  For a proper input, the bounds have no bits above
--   width @w@, so the new high bits are known to be zero.
zext :: (1 <= w, w+1 <= u) => Domain w -> NatRepr u -> Domain u
zext d u = case d of
  BVBitInterval _ lower upper -> range u lower upper

-- | Sign-extend a domain of width @w@ to width @u@.  Each bound is read
--   as a signed @w@-bit value and re-masked at width @u@, so the new
--   high bits copy whatever is known about the sign bit.
sext :: (1 <= w, w+1 <= u) => NatRepr w -> Domain w -> NatRepr u -> Domain u
sext w (BVBitInterval _ lo hi) u = range u (signed lo) (signed hi)
  where
  signed v = toSigned w v

-- | Report the value of bit @i@ when both bounds agree on it, or
--   'Nothing' when the bit is unknown.
testBit :: Domain w -> Natural -> Maybe Bool
testBit (BVBitInterval _mask lo hi) i
  | Bits.testBit lo k == Bits.testBit hi k = Just (Bits.testBit lo k)
  | otherwise = Nothing
  where
  k = fromIntegral i

-- | Shift every member left by @y@ bits, truncating to the domain's mask.
--   Shift amounts larger than the width @w@ are clamped to @w@.
shl :: NatRepr w -> Domain w -> Integer -> Domain w
shl w (BVBitInterval mask lo hi) y =
  BVBitInterval mask (go lo) (go hi)
  where
  amt :: Int
  amt = fromInteger (min y (intValue w))
  go v = mask .&. (v `shiftL` amt)

-- | Rotate every member left by @y@ bits by rotating both bounds with
--   'Arith.rotateLeft' at width @w@.
rol :: NatRepr w -> Domain w -> Integer -> Domain w
rol w (BVBitInterval mask lo hi) y = BVBitInterval mask (rot lo) (rot hi)
  where
  rot v = Arith.rotateLeft w v y

-- | Rotate every member right by @y@ bits by rotating both bounds with
--   'Arith.rotateRight' at width @w@.
ror :: NatRepr w -> Domain w -> Integer -> Domain w
ror w (BVBitInterval mask lo hi) y = BVBitInterval mask (rot lo) (rot hi)
  where
  rot v = Arith.rotateRight w v y

-- | Logical right shift of every member by @y@ bits.  Shift amounts
--   larger than the width @w@ are clamped to @w@.
lshr :: NatRepr w -> Domain w -> Integer -> Domain w
lshr w (BVBitInterval mask lo hi) y =
  let amt = fromInteger (min y (intValue w)) :: Int
  in  BVBitInterval mask (lo `shiftR` amt) (hi `shiftR` amt)

-- | Arithmetic right shift of every member by @y@ bits (clamped to
--   @w@).  Each bound is read as a signed @w@-bit value, shifted, and
--   re-masked.
ashr :: (1 <= w) => NatRepr w -> Domain w -> Integer -> Domain w
ashr w (BVBitInterval mask lo hi) y = BVBitInterval mask (sar lo) (sar hi)
  where
  amt :: Int
  amt = fromInteger (min y (intValue w))
  sar v = mask .&. (toSigned w v `shiftR` amt)

-- | Bitwise complement.  Known-one and known-zero bits swap roles, so
--   the new lower bound is the old upper bound flipped within the mask,
--   and vice versa.
not :: Domain w -> Domain w
not (BVBitInterval mask alo ahi) = BVBitInterval mask (flipBits ahi) (flipBits alo)
  where
  flipBits v = mask `Bits.xor` v

-- | Bitwise conjunction, computed by and-ing the lower bounds together
--   and the upper bounds together.
and :: Domain w -> Domain w -> Domain w
and (BVBitInterval mask lo1 hi1) (BVBitInterval _ lo2 hi2) =
  BVBitInterval mask lower upper
  where
  lower = lo1 .&. lo2
  upper = hi1 .&. hi2

-- | Bitwise disjunction, computed by or-ing the lower bounds together
--   and the upper bounds together.
or :: Domain w -> Domain w -> Domain w
or (BVBitInterval mask lo1 hi1) (BVBitInterval _ lo2 hi2) =
  BVBitInterval mask lower upper
  where
  lower = lo1 .|. lo2
  upper = hi1 .|. hi2

-- | Bitwise exclusive-or.  A result bit is unknown whenever it is unknown
--   in either argument.  Every other result bit is the exclusive-or of the
--   two lower bounds at that position.
xor :: Domain w -> Domain w -> Domain w
xor (BVBitInterval mask alo ahi) (BVBitInterval _ blo bhi) =
  BVBitInterval mask (known .&. complement unknown) (known .|. unknown)
  where
  unknown = (alo `Bits.xor` ahi) .|. (blo `Bits.xor` bhi)
  known   = alo `Bits.xor` blo


---------------------------------------------------------------------------------------
-- Correctness properties

-- | Check that a domain is proper for width @n@ and that the given value
--   is one of its members.
pmember :: NatRepr n -> Domain n -> Integer -> Bool
pmember n a x
  | proper n a = member a x
  | otherwise  = False

-- | 'any' is proper and contains every integer after masking to @n@ bits.
correct_any :: (1 <= n) => NatRepr n -> Integer -> Property
correct_any n x = property $ pmember n (any n) x

-- | A singleton domain is proper and contains exactly its own value.
correct_singleton :: (1 <= n) => NatRepr n -> Integer -> Integer -> Property
correct_singleton n x y =
  let ux = toUnsigned n x
      uy = toUnsigned n y
  in  property (pmember n (singleton n ux) uy == (ux == uy))

-- | Domains that share a member are reported as overlapping.
correct_overlap :: Domain n -> Domain n -> Integer -> Property
correct_overlap a b x = inBoth ==> domainsOverlap a b
  where
  inBoth = member a x && member b x

-- | A member of either domain is a member of their union, and the union
--   is proper.
correct_union :: (1 <= n) => NatRepr n -> Domain n -> Domain n -> Integer -> Property
correct_union n a b x = inEither ==> pmember n (union a b) x
  where
  inEither = member a x || member b x

-- | A value in both domains is in their intersection.  Properness is not
--   checked, because an intersection may be empty.
correct_intersection :: (1 <= n) => Domain n -> Domain n -> Integer -> Property
correct_intersection a b x = inBoth ==> member (intersection a b) x
  where
  inBoth = member a x && member b x

-- | Zero extension keeps the unsigned reading of every member.
correct_zero_ext :: (1 <= w, w+1 <= u) => NatRepr w -> Domain w -> NatRepr u -> Integer -> Property
correct_zero_ext w a u x =
  let v = toUnsigned w x
  in  member a v ==> pmember u (zext a u) v

-- | Sign extension keeps the signed reading of every member.
correct_sign_ext :: (1 <= w, w+1 <= u) => NatRepr w -> Domain w -> NatRepr u -> Integer -> Property
correct_sign_ext w a u x =
  let v = toSigned w x
  in  member a v ==> pmember u (sext w a u) v

-- | Concatenating a member of each domain yields a member of the
--   concatenated domain.
correct_concat :: NatRepr m -> (Domain m,Integer) -> NatRepr n -> (Domain n,Integer) -> Property
correct_concat m (a,x) n (b,y) =
  member a highPart ==> member b lowPart ==> pmember (addNat m n) (concat m a n b) joined
  where
  highPart = toUnsigned m x
  lowPart  = toUnsigned n y
  joined   = (highPart `shiftL` widthVal n) .|. lowPart

-- | Dropping the low @i@ bits of a member yields a member of the shrunk
--   domain.
correct_shrink :: NatRepr i -> NatRepr n -> (Domain (i + n), Integer) -> Property
correct_shrink i n (a,x) =
  let v = bvdMask a .&. x
  in  member a v ==> pmember n (shrink i a) (v `shiftR` widthVal i)

-- | Truncating a member to @n@ bits yields a member of the truncated
--   domain.
correct_trunc :: (n <= w) => NatRepr n -> (Domain w, Integer) -> Property
correct_trunc n (a,x) =
  let v = bvdMask a .&. x
  in  member a v ==> pmember n (trunc n a) (toUnsigned n v)

-- | Extracting a bit field from a member yields a member of the selected
--   domain.
correct_select :: (1 <= n, i + n <= w) =>
  NatRepr i -> NatRepr n -> (Domain w, Integer) -> Property
correct_select i n (a, x) = member a x ==> pmember n (select i n a) field
  where
  field = toUnsigned n ((bvdMask a .&. x) `shiftR` widthVal i)

-- | When 'eq' gives a definite answer, it agrees with comparing any pair
--   of members.
correct_eq :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> (Domain n, Integer) -> Property
correct_eq n (a,x) (b,y) =
  member a x ==> member b y ==> maybe True agrees (eq a b)
  where
  sameValue = toUnsigned n x == toUnsigned n y
  agrees expected = expected == sameValue

-- | Shifting a member left yields a member of the shifted domain.
correct_shl :: (1 <= n) => NatRepr n -> (Domain n,Integer) -> Integer -> Property
correct_shl n (a,x) y = member a x ==> pmember n (shl n a y) shifted
  where
  amt :: Int
  amt = fromInteger (min (intValue n) y)
  shifted = toUnsigned n x `shiftL` amt

-- | Logically shifting a member right yields a member of the shifted
--   domain.
correct_lshr :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> Integer -> Property
correct_lshr n (a,x) y = member a x ==> pmember n (lshr n a y) shifted
  where
  amt :: Int
  amt = fromInteger (min (intValue n) y)
  shifted = toUnsigned n x `shiftR` amt

-- | Arithmetically shifting a member right yields a member of the
--   shifted domain.
correct_ashr :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> Integer -> Property
correct_ashr n (a,x) y = member a x ==> pmember n (ashr n a y) shifted
  where
  amt :: Int
  amt = fromInteger (min (intValue n) y)
  shifted = toSigned n x `shiftR` amt

-- | Rotating a member left yields a member of the rotated domain.
correct_rol :: (1 <= n) => NatRepr n -> (Domain n,Integer) -> Integer -> Property
correct_rol n (a,x) y = member a x ==> pmember n (rol n a y) rotated
  where
  rotated = Arith.rotateLeft n x y

-- | Rotating a member right yields a member of the rotated domain.
correct_ror :: (1 <= n) => NatRepr n -> (Domain n,Integer) -> Integer -> Property
correct_ror n (a,x) y = member a x ==> pmember n (ror n a y) rotated
  where
  rotated = Arith.rotateRight n x y

-- | Complementing a member yields a member of the complemented domain.
correct_not :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> Property
correct_not n (a,x) = member a x ==> pmember n (not a) flipped
  where
  flipped = complement x

-- | The conjunction of two members is a member of the conjoined domain.
correct_and :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> (Domain n, Integer) -> Property
correct_and n (a,x) (b,y) = member a x ==> member b y ==> pmember n (and a b) combined
  where
  combined = x .&. y

-- | The disjunction of two members is a member of the disjoined domain.
correct_or :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> (Domain n, Integer) -> Property
correct_or n (a,x) (b,y) = member a x ==> member b y ==> pmember n (or a b) combined
  where
  combined = x .|. y

-- | The exclusive-or of two members is a member of the xor-ed domain.
correct_xor :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> (Domain n, Integer) -> Property
correct_xor n (a,x) (b,y) = member a x ==> member b y ==> pmember n (xor a b) combined
  where
  combined = x `Bits.xor` y

-- | Whenever 'testBit' reports a definite value for bit @i@ (with @i@
--   below the width @n@), every member of the domain has that value at
--   bit @i@.
--
--   The @member a x@ precondition is required.  Without it the property
--   claims that a known bit of @a@ constrains an arbitrary integer @x@,
--   which is false.  Every other property in this module assumes
--   membership in the same way.
correct_testBit :: (1 <= n) => NatRepr n -> (Domain n, Integer) -> Natural -> Property
correct_testBit n (a,x) i =
  i < natValue n ==> member a x ==>
    case testBit a i of
      Just True  -> Bits.testBit x (fromIntegral i)
      Just False -> Prelude.not (Bits.testBit x (fromIntegral i))
      Nothing    -> True