-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.WeakestPreconditions.IntDiv
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Proof of correctness of an imperative integer division algorithm, using
-- weakest preconditions. The algorithm simply keeps subtracting the divisor
-- until the desired quotient and remainder are found.
-----------------------------------------------------------------------------

{-# LANGUAGE DeriveAnyClass        #-}
{-# LANGUAGE DeriveGeneric         #-}
{-# LANGUAGE DeriveTraversable     #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns        #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.WeakestPreconditions.IntDiv where

import Data.SBV
import Data.SBV.Control

import Data.SBV.Tools.WeakestPreconditions

import GHC.Generics (Generic)

-- * Program state

-- | Program state of the division algorithm. The type parameter @a@ is the type
-- of each of the four components; the proof in this module instantiates it to
-- 'SInteger' (see 'D'). The 'Mergeable' instance is obtained with
-- @DeriveAnyClass@ from the 'Generic' representation.
data DivS a = DivS
  { x :: a  -- ^ The dividend
  , y :: a  -- ^ The divisor
  , q :: a  -- ^ The quotient
  , r :: a  -- ^ The remainder
  }
  deriving (Show, Generic, Mergeable, Functor, Foldable, Traversable)

-- | Show instance for 'DivS' over symbolic values. The derived 'Show' instance
-- would also work, but this one is prettier: each component is printed as its
-- concrete value when that value is known, and as a placeholder otherwise. The
-- @OVERLAPS@ pragma lets this instance take precedence over the derived one for
-- states of the form @DivS (SBV a)@.
instance {-# OVERLAPS #-} (SymVal a, Show a) => Show (DivS (SBV a)) where
  show (DivS sx sy sq sr) = concat [ "{x = ", render sx
                                   , ", y = ", render sy
                                   , ", q = ", render sq
                                   , ", r = ", render sr
                                   , "}"
                                   ]
    where
      -- Concrete value when 'unliteral' yields one, a placeholder otherwise.
      render v = maybe "<symbolic>" show (unliteral v)

-- | 'Fresh' instance for the program state: each of the four components is
-- created with 'freshVar_', in the order @x@, @y@, @q@, @r@.
instance SymVal a => Fresh IO (DivS (SBV a)) where
  fresh = do vx <- freshVar_
             vy <- freshVar_
             vq <- freshVar_
             vr <- freshVar_
             pure (DivS vx vy vq vr)

-- | Helper type synonym: the program state with every component an 'SInteger'.
type D = DivS SInteger

-- * The algorithm

-- | The imperative division algorithm, assuming non-negative @x@ and strictly positive @y@:
--
-- @
--    r = x                     -- set remainder to x
--    q = 0                     -- set quotient  to 0
--    while y <= r              -- while we can still subtract
--      r = r - y               -- reduce the remainder
--      q = q + 1               -- increase the quotient
-- @
--
-- Each loop must be annotated explicitly with its invariant and (optionally)
-- its termination measure; both are taken as parameters here. Note that the
-- opening assertion only checks @x >= 0@ and @y >= 0@; strict positivity of
-- @y@ is supplied by the precondition 'pre'.
algorithm :: Invariant D -> Maybe (Measure D) -> Stmt D
algorithm inv msr = Seq [ assert "x, y >= 0" inputsNonNegative
                        , Assign initialize
                        , While "y <= r" inv msr canSubtract (Assign subtractOnce)
                        ]
  where
    -- Condition checked by the opening assertion.
    inputsNonNegative st = x st .>= 0 .&& y st .>= 0
    -- r := x; q := 0
    initialize st = st{r = x st, q = 0}
    -- Loop guard: another subtraction is possible.
    canSubtract st = y st .<= r st
    -- Loop body: r := r - y; q := q + 1
    subtractOnce st = st{r = r st - y st, q = q st + 1}

-- | Precondition for our program: @x@ must be non-negative and @y@ must be strictly positive.
-- The program's opening 'Data.SBV.Tools.WeakestPreconditions.assert' checks only
-- @x >= 0@ and @y >= 0@, so without this precondition that assertion could fail
-- (for instance on a negative @x@). The stronger requirement @y > 0@ is what makes
-- every loop iteration strictly decrease the remainder.
pre :: D -> SBool
pre st = x st .>= 0 .&& y st .> 0

-- | Postcondition for our program: the remainder must be non-negative and
-- less than @y@, and @x = q*y + r@ must hold.
post :: D -> SBool
post st =     r st .>= 0
          .&& r st .<  y st
          .&& x st .== q st * y st + r st

-- | Stability: the dividend @x@ and the divisor @y@ must remain unchanged by
-- the program.
noChange :: Stable D
noChange = [stable lbl sel | (lbl, sel) <- [("x", x), ("y", y)]]

-- | A program is the algorithm, together with its pre- and post-conditions and
-- stability requirements. No additional setup is needed.
imperativeDiv :: Invariant D -> Maybe (Measure D) -> Program D
imperativeDiv loopInv loopMsr =
  Program { program       = algorithm loopInv loopMsr
          , precondition  = pre
          , postcondition = post
          , stability     = noChange
          , setup         = pure ()
          }

-- * Correctness

-- | The loop invariant: @y@ is strictly positive, @r@ is non-negative, and
-- @x = q * y + r@ holds. The @y > 0@ conjunct is needed to establish that the
-- measure decreases; it is guaranteed by the precondition, and the loop body
-- never changes @y@.
invariant :: Invariant D
invariant st = divisorPositive .&& remainderNonNeg .&& decomposes
  where
    divisorPositive = y st .> 0
    remainderNonNeg = r st .>= 0
    decomposes      = x st .== q st * y st + r st

-- | The termination measure: the remainder @r@. Each iteration replaces @r@ by
-- @r - y@, which is strictly smaller because the invariant guarantees @y > 0@;
-- the loop only runs while @r@ is at least @y@, so @r@ stays non-negative.
measure :: Measure D
measure st = [r st]

-- | Check that the program terminates and the post condition holds, using
-- 'invariant' as the loop invariant and 'measure' as the termination measure,
-- with verbose output enabled. We have:
--
-- >>> correctness
-- Total correctness is established.
-- Q.E.D.
correctness :: IO ()
correctness = do
    result <- wpProveWith cfg (imperativeDiv invariant (Just measure))
    print result
  where cfg = defaultWPCfg{wpVerbose = True}