-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.BitPrecise.MergeSort
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Symbolic implementation of merge-sort and its correctness.
-----------------------------------------------------------------------------

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.BitPrecise.MergeSort where

import Data.SBV
import Data.SBV.Tools.CodeGen

-----------------------------------------------------------------------------
-- * Implementing Merge-Sort
-----------------------------------------------------------------------------
-- | Element type of the lists we sort. For simplicity we use 'SWord8' (a
-- symbolic 8-bit unsigned word), but any symbolic type would do as long as it
-- supports the symbolic comparisons ('.<', '.<=', '.==') and 'ite'-merging
-- that the sorting and checking functions in this module rely on.
type E = SWord8

-- | Merge two sorted lists into a single sorted list.
--
-- Both inputs are expected to already be in non-decreasing order. Since the
-- elements are symbolic, we cannot decide concretely which head is smaller;
-- instead 'ite' symbolically selects between the two possible continuations.
-- Both branches have length @length xs + length ys@, which is what allows the
-- two candidate lists to be merged element-wise by 'ite'.
--
-- On ties (@x .== y@) the element from the second list is emitted first.
merge :: [E] -> [E] -> [E]
merge []          ys          = ys
merge xs          []          = xs
merge xs@(x : xr) ys@(y : yr) =
  ite (x .< y)
      (x : merge xr ys)   -- take the head of the first list
      (y : merge xs yr)   -- otherwise take the head of the second list

-- | Simple merge-sort implementation. As long as the input has at least two
-- elements, we split it into two halves, sort each half recursively, and
-- combine the results with 'merge'.
--
-- The recursion is driven only by the (concrete) length of the list, so the
-- shape of the resulting computation depends on the input size, never on the
-- symbolic element values.
mergeSort :: [E] -> [E]
mergeSort []  = []
mergeSort [x] = [x]
mergeSort xs  = merge (mergeSort th) (mergeSort bh)
  where
    -- 'th' is the first half (length rounded down), 'bh' the remainder.
    (th, bh) = splitAt (length xs `div` 2) xs

-----------------------------------------------------------------------------
-- * Proving correctness
-- ${props}
-----------------------------------------------------------------------------
{- $props
There are two main parts to proving that a sorting algorithm is correct:

       * Prove that the output is non-decreasing
 
       * Prove that the output is a permutation of the input
-}

-- | Check whether a given sequence is non-decreasing, i.e., every element is
-- less than or equal to its immediate successor. Empty and singleton lists
-- are trivially non-decreasing (the conjunction of no conditions is 'sTrue').
nonDecreasing :: [E] -> SBool
nonDecreasing xs = sAnd (zipWith (.<=) xs (drop 1 xs))

-- | Check whether two given sequences are permutations of each other.
--
-- We check that each sequence is contained in the other when both are viewed
-- as multisets: every element of one list must be matched against a distinct,
-- not-yet-used element of the other. Performing this check in both directions
-- accounts for possibly duplicated elements (different multiplicities fail).
isPermutationOf :: [E] -> [E] -> SBool
isPermutationOf as bs = contained as bs .&& contained bs as
  where
    -- Every element of the first list can be matched to a distinct element of
    -- the second. Each candidate in the second list is paired with a flag that
    -- stays 'sTrue' while that candidate is still available for matching.
    contained :: [E] -> [E] -> SBool
    contained xs ys = go xs (zip ys (repeat sTrue))

    -- Match the elements one by one, threading the updated availability flags.
    go :: [E] -> [(E, SBool)] -> SBool
    go []       _  = sTrue
    go (x : xr) ys = let (found, ys') = mark x ys
                     in found .&& go xr ys'

    -- Mark off one still-available instance of 'x' in the list, if possible,
    -- reporting whether one was found. Both 'ite' branches must keep the list
    -- the same length for the symbolic results to merge properly.
    mark :: E -> [(E, SBool)] -> (SBool, [(E, SBool)])
    mark _ []              = (sFalse, [])
    mark x ((y, v) : rest) =
      ite (v .&& x .== y)
          -- This branch is only taken when 'v' holds; the candidate is used up.
          (sTrue, (y, sFalse) : rest)
          (let (r, rest') = mark x rest in (r, (y, v) : rest'))

-- | Asserting correctness of merge-sort for a list of the given size: the
-- output of 'mergeSort' must be non-decreasing and a permutation of its input.
--
-- Note that we can only check correctness for fixed-size lists. Also, the
-- proof gets harder for the backend SMT solver as the list size increases. A
-- value around 5 or 6 should be fairly easy to prove. For instance, we have:
--
-- >>> correctness 5
-- Q.E.D.
correctness :: Int -> IO ThmResult
correctness n = prove $ do
  -- 'n' fresh, universally quantified symbolic elements.
  xs <- mkFreeVars n
  let ys = mergeSort xs
  pure $ nonDecreasing ys .&& isPermutationOf xs ys

-----------------------------------------------------------------------------
-- * Generating C code
-----------------------------------------------------------------------------

-- | Generate C code for merge-sorting an array of size @n@. Again, we're
-- restricted to fixed-size inputs. The generated function reads the input
-- array @xs@ and writes the sorted result to the output array @ys@.
--
-- While the output is not how one would code merge sort in C by hand, it is a
-- faithful rendering of all the operations merge-sort performs, as described
-- by its Haskell counterpart 'mergeSort'.
codeGen :: Int -> IO ()
codeGen n = compileToC (Just dir) "mergeSort" $ do
  xs <- cgInputArr n "xs"
  cgOutputArr "ys" (mergeSort xs)
  where
    -- Target directory handed to 'compileToC', e.g. "mergeSort5" for n = 5.
    dir = "mergeSort" ++ show n