-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Puzzles.MagicSquare
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Solves the magic-square puzzle. An NxN magic square is one where the entries
-- are the numbers from 1 to NxN, each used exactly once, such that the sums of
-- all rows, all columns and both diagonals are the same.
-----------------------------------------------------------------------------

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.Puzzles.MagicSquare where

import Data.List (genericLength, transpose)

import Data.SBV

-- | A single cell of the square: a symbolic 32-bit word.
type Elem = SWord32

-- | One row of the square, given as a list of cells.
type Row = [Elem]

-- | A whole square, given as a list of rows.
type Board = [Row]

-- | @check low high xs@ is true exactly when every element of @xs@ is at
-- least @low@ and at most @high@ (both bounds inclusive).
check :: Elem -> Elem -> [Elem] -> SBool
check low high = sAll withinBounds
  where
    -- Per-element bound test, combined over the whole list by 'sAll'.
    withinBounds x = x .>= low .&& x .<= high

-- | Get the main diagonal of a matrix: the first element of the first row,
-- the second element of the second row, and so on.
--
-- For a square matrix the result has one element per row. For ragged input
-- the result stops at the first row that is too short to contribute an
-- element. 'drop' is used instead of the partial 'tail' so that a short or
-- empty row ends the diagonal instead of crashing. This also keeps the
-- module clean under @-Werror@ on GHCs that warn about 'tail' (@-Wx-partial@).
diag :: [[a]] -> [a]
diag ((a:_):rs) = a : diag (map (drop 1) rs)
diag _          = []

-- | Constraint stating that @rows@ forms a magic square. It is the
-- conjunction of:
--
--   * the board is square (every row is as long as there are rows);
--   * the main diagonal, the anti-diagonal, every row and every column all
--     have the same sum;
--   * all cells are pairwise distinct;
--   * every cell on those lines lies between 1 and @n*n@, where @n@ is the
--     number of rows.
isMagic :: Board -> SBool
isMagic rows = sAnd (fromBool isSquare : equalSums : cellsDistinct : cellsInRange)
  where
    -- Every line whose total must match: both diagonals, then rows, then
    -- columns.
    allLines      = mainDiag : antiDiag : rows ++ transpose rows
    mainDiag      = diag rows
    antiDiag      = diag (map reverse rows)

    -- Side length of the board. Its use with 'literal' below fixes it to
    -- the element's underlying word type.
    sideLen       = genericLength rows
    isSquare      = all (\r -> genericLength r == sideLen) rows

    equalSums     = allEqual (map sum allLines)
    cellsDistinct = distinct (concat rows)
    cellsInRange  = map (check (literal 1) (literal (sideLen * sideLen))) allLines

-- | Group a list of elements into sublists of length @i@. The last group
-- holds whatever is left over and may be shorter.
--
-- An empty list always yields no groups. A non-positive @i@ on a non-empty
-- list is rejected with 'error': 'splitAt' would make no progress there, so
-- the recursion would otherwise never terminate.
chunk :: Int -> [a] -> [[a]]
chunk _ [] = []
chunk i xs
  | i <= 0    = error $ "chunk: group size must be positive, received: " ++ show i
  | otherwise = let (f, r) = splitAt i xs in f : chunk i r

-- | Given @n@, @magic n@ prints all solutions to the @nxn@ magic square
-- problem.
--
-- A negative @n@ is rejected with a message. Otherwise 'allSat' is asked for
-- every assignment of @n*n@ symbolic cells that satisfies 'isMagic'. Each
-- model is printed as a grid, followed by a re-check of the concrete board
-- through 'isMagic'. The number of solutions is reported at the end.
magic :: Int -> IO ()
magic n
  | n < 0 = putStrLn $ "n must be non-negative, received: " ++ show n
  | True = do
      putStrLn $ "Finding all " ++ show n ++ "-magic squares.."
      res <- allSat $ (isMagic . chunk n) `fmap` mkExistVars n2
      cnt <- displayModels id disp res
      putStrLn $ "Found: " ++ show cnt ++ " solution(s)."
  where
    n2 = n * n

    -- Width every printed cell is padded to: wide enough for the largest
    -- entry (n*n), and never less than 2 columns. For n <= 9 the width is
    -- always 2, so the output layout is unchanged there.
    cellWidth = max 2 (length (show n2))

    -- Print one model. The solver must return exactly n*n values. The error
    -- reports how many came back (lmod) versus how many were expected (n2).
    disp i (_, model)
      | lmod /= n2
      = error $ "Impossible! Backend solver returned " ++ show lmod
             ++ " values, was expecting: " ++ show n2
      | True
      = do putStrLn $ "Solution #" ++ show i
           mapM_ printRow board
           putStrLn $ "Valid Check: " ++ show (isMagic sboard)
           putStrLn "Done."
      where
        lmod = length model
        board = chunk n model
        -- Re-lift the concrete values so the board can be fed back to isMagic.
        sboard = map (map literal) board
        padCell z = let s = show z in replicate (cellWidth - length s) ' ' ++ s
        printRow r = do
          putStr "   "
          mapM_ (\x -> putStr (padCell x ++ " ")) r
          putStrLn ""