-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Existentials.Diophantine
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Finding minimal natural number solutions to linear Diophantine equations,
-- using explicit quantification.
-----------------------------------------------------------------------------

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.Existentials.Diophantine where

import Data.SBV

--------------------------------------------------------------------------------------------------
-- * Representing solutions
--------------------------------------------------------------------------------------------------
-- | The answer to a linear Diophantine problem over the naturals.
--
-- For a homogeneous problem (every right-hand side is zero), the solutions are the linear
-- combinations of the listed basis vectors. For a non-homogeneous problem, a solution is one of
-- the vectors in the first list plus any linear combination of the vectors in the second list.
data Solution
  = Homogeneous [[Integer]]                 -- ^ basis vectors of the homogeneous system
  | NonHomogeneous [[Integer]] [[Integer]]  -- ^ particular solutions, then the homogeneous basis
  deriving (Show)

--------------------------------------------------------------------------------------------------
-- * Solving diophantine equations
--------------------------------------------------------------------------------------------------
-- | @ldn@: solve a (L)inear (D)iophantine system, returning minimal solutions over the (N)aturals.
--
-- The system is given as a list of equations, one per row: each row pairs its coefficients
-- with its right-hand-side constant. The first argument optionally caps the number of models
-- requested from the solver. Use it when there are too many solutions and you want to limit
-- the search space.
--
-- If every right-hand side is @0@, the result is 'Homogeneous'. Otherwise, each right-hand side
-- is negated and prepended to its row as an extra leading unknown. The basis vectors of that
-- extended homogeneous system are then split on their leading component: vectors starting with
-- @1@ become the particular solutions, and vectors starting with @0@ become the homogeneous part.
-- In both cases the leading component is dropped. Vectors with any other leading value are
-- ignored.
ldn :: Maybe Int -> [([Integer], Integer)] -> IO Solution
ldn mbLim problem = mkSolution <$> basis mbLim (map (map literal) coeffMatrix)
  where (lhs, rhs)  = unzip problem
        homogeneous = all (== 0) rhs

        -- Rows handed to the solver: the original coefficients for a homogeneous system,
        -- otherwise each row extended with its negated constant in front.
        coeffMatrix
          | homogeneous = lhs
          | otherwise   = [negate c : row | (row, c) <- problem]

        -- Interpret the basis vectors returned by the solver.
        mkSolution sols
          | homogeneous = Homogeneous sols
          | otherwise   = NonHomogeneous [xs | (1 : xs) <- sols] [xs | (0 : xs) <- sols]

-- | Find the basis of solutions. By definition, the basis consists of all non-trivial
-- (i.e., non-0) solutions that cannot be written as the sum of two other solutions. We use the
-- mathematically equivalent statement that a solution is in the basis if it is minimal according
-- to the natural componentwise partial order induced by the ordinary less-than relation.
--
-- The first argument is passed to the solver as 'allSatMaxModelCount', bounding how many
-- models are returned. Each row of the matrix is one homogeneous equation whose coefficients,
-- dotted with a solution vector, must sum to @0@.
basis :: Maybe Int -> [[SInteger]] -> IO [[Integer]]
basis mbLim m = extractModels `fmap` allSatWith z3{allSatMaxModelCount = mbLim} cond
 where -- @as@ is the (existential) candidate basis element; @bs@ ranges (universally) over
       -- all vectors. @as@ must be a solution, and no solution @bs@ may lie strictly below it.
       cond = do as <- mkExistVars  n
                 bs <- mkForallVars n
                 return $ ok as .&& (ok bs .=> ((as .== bs) .|| sNot (bs `less` as)))

       -- Number of unknowns: the width of the coefficient rows, or 0 when there are no rows.
       -- Pattern-matched rather than using the partial 'head'.
       n = case m of
             []      -> 0
             (r : _) -> length r

       -- A non-trivial natural-number solution: some component is positive, all components are
       -- non-negative, and every row of the matrix dotted with @xs@ gives 0.
       ok xs = sAny (.> 0) xs .&& sAll (.>= 0) xs .&& sAnd [sum (zipWith (*) r xs) .== 0 | r <- m]

       -- Strict componentwise order: @as@ is <= @bs@ everywhere and < @bs@ somewhere.
       as `less` bs = sAnd (zipWith (.<=) as bs) .&& sOr (zipWith (.<) as bs)

--------------------------------------------------------------------------------------------------
-- * Examples
--------------------------------------------------------------------------------------------------

-- | Solve the equation:
--
--    @2x + y - z = 2@
--
-- We have:
--
-- >>> test
-- NonHomogeneous [[1,0,0],[0,2,0]] [[0,1,1],[1,0,2]]
--
-- which means that the solutions are of the form:
--
--    @(1, 0, 0) + k (0, 1, 1) + k' (1, 0, 2) = (1+k', k, k+2k')@
--
-- OR
--
--    @(0, 2, 0) + k (0, 1, 1) + k' (1, 0, 2) = (k', 2+k, k+2k')@
--
-- for arbitrary @k@, @k'@. It's easy to see that these are really solutions
-- to the equation given. It's harder to see that they cover all possibilities,
-- but a moment's thought reveals that this is indeed the case.
test :: IO Solution
test = ldn Nothing [equation]
  where -- coefficients of x, y and z, paired with the right-hand side of 2x + y - z = 2
        equation = ([2, 1, -1], 2)

-- | A puzzle: five sailors and a monkey escape from a shipwreck and reach an island with
-- coconuts. Before dawn, they gather a few of them and decide to sleep first and share
-- them the next day. At night, however, one of them wakes up, counts the nuts, makes five parts,
-- gives the one remaining nut to the monkey, puts his share away, and goes back to sleep. All the
-- other sailors do the same, one by one. When they all wake up in the morning, they again make
-- five shares, and give the last remaining nut to the monkey. How many nuts were there at the
-- beginning?
--
-- We can model this as a series of Diophantine equations:
--
-- @
--       x_0 = 5 x_1 + 1
--     4 x_1 = 5 x_2 + 1
--     4 x_2 = 5 x_3 + 1
--     4 x_3 = 5 x_4 + 1
--     4 x_4 = 5 x_5 + 1
--     4 x_5 = 5 x_6 + 1
-- @
--
-- We need to solve for x_0 over the naturals. If you run this program, z3 takes its time (quite
-- long!), but it eventually computes [15621,3124,2499,1999,1599,1279,1023] as the answer.
--
-- That is:
--
-- @
--   * There was a total of 15621 coconuts
--   * 1st sailor: 15621 = 3124*5+1, leaving 15621-3124-1 = 12496
--   * 2nd sailor: 12496 = 2499*5+1, leaving 12496-2499-1 =  9996
--   * 3rd sailor:  9996 = 1999*5+1, leaving  9996-1999-1 =  7996
--   * 4th sailor:  7996 = 1599*5+1, leaving  7996-1599-1 =  6396
--   * 5th sailor:  6396 = 1279*5+1, leaving  6396-1279-1 =  5116
--   * In the morning, they had: 5116 = 1023*5+1.
-- @
--
-- Note that this is the minimum solution, that is, we are guaranteed that there is
-- no solution with fewer coconuts. In fact, any member of @[15625*k-4 | k <- [1..]]@
-- is a solution, i.e., so are @31246@, @46871@, @62496@, @78121@, etc.
--
-- Note that we iteratively deepen our search by requesting an increasing number of
-- solutions (1, 2, 3, ...) until a particular solution shows up, to avoid the all-sat pitfall.
sailors :: IO [Integer]
sailors = go 1
  where -- Ask for at most @i@ models. If no particular solution is among them, widen the limit.
        go :: Int -> IO [Integer]
        go i = do soln <- ldn (Just i) equations
                  case soln of
                    NonHomogeneous (xs : _) _ -> return xs
                    _                         -> go (i + 1)

        -- Seven unknowns x_0 .. x_6 and six equations, each with right-hand side 1.
        -- Row 0 encodes x_0 - 5 x_1 = 1; row k (k = 1..5) encodes 4 x_k - 5 x_(k+1) = 1.
        equations :: [([Integer], Integer)]
        equations = [ ([coeffAt k j | j <- [0 .. 6]], 1) | k <- [0 .. 5] ]

        -- Coefficient of unknown @j@ in row @k@.
        coeffAt :: Int -> Int -> Integer
        coeffAt k j
          | j == k     = if k == 0 then 1 else 4
          | j == k + 1 = -5
          | otherwise  = 0