-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Puzzles.MagicSquare
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Solves the magic-square puzzle. An NxN magic square is one whose entries
-- are the numbers from 1 to NxN, each used exactly once, such that all rows,
-- all columns, and both main diagonals have the same sum.
-----------------------------------------------------------------------------

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.Puzzles.MagicSquare where

import Data.List (genericLength, transpose)

import Data.SBV

-- | Board entries are symbolic 32-bit unsigned words ('SWord32').
type Elem  = SWord32

-- | A row is a list of board entries.
type Row   = [Elem]

-- | The puzzle board is a list of rows. Nothing here forces it to be
-- square; 'isMagic' checks that separately.
type Board = [Row]

-- | @check low high xs@ is the symbolic assertion that every element of
-- @xs@ lies in the inclusive range from @low@ to @high@.
check :: Elem -> Elem -> [Elem] -> SBool
check low high = sAll withinLimits
  where withinLimits x = x .>= low .&& x .<= high

-- | Get the main (top-left to bottom-right) diagonal of a list of rows.
--
-- Each step takes the head of the first row and then drops the first
-- element of every remaining row. The walk stops at the first row that has
-- no element at the diagonal position. For a square matrix this is exactly
-- the diagonal. For ragged or non-square input it yields a (possibly
-- shorter) prefix instead of crashing. 'drop' is used rather than 'tail'
-- because 'tail' fails on an empty row.
diag :: [[a]] -> [a]
diag ((a:_):rs) = a : diag (map (drop 1) rs)
diag _          = []

-- | Test if a given board is a magic square. The resulting predicate is the
-- conjunction of:
--
--   * the board being square (every row is as long as the number of rows),
--   * both main diagonals, every row and every column having equal sums,
--   * all entries being pairwise distinct, and
--   * every entry lying within @1 .. n*n@, where @n@ is the number of rows.
isMagic :: Board -> SBool
isMagic rows = sAnd $  [ fromBool squareShaped
                       , allEqual (map sum lineGroups)
                       , distinct (concat rows)
                       ]
                    ++ map withinLimits lineGroups
  where
    -- Main diagonal, anti-diagonal, then all rows, then all columns.
    lineGroups   = diag rows : diag (map reverse rows) : rows ++ transpose rows
    -- Side length; its type is fixed to Word32 by the 'literal' use below.
    sideLen      = genericLength rows
    squareShaped = all ((== sideLen) . genericLength) rows
    withinLimits = check (literal 1) (literal (sideLen * sideLen))

-- | Split a list into consecutive sublists of length @i@. The last group
-- may be shorter when @i@ does not divide the list's length.
--
-- A non-positive @i@ can never make progress (@splitAt 0@ consumes nothing).
-- In that case the whole non-empty list is returned as a single group,
-- rather than looping forever.
chunk :: Int -> [a] -> [[a]]
chunk _ [] = []
chunk i xs
  | i <= 0    = [xs]
  | otherwise = let (f, r) = splitAt i xs in f : chunk i r

-- | Given @n@, @magic n@ prints all solutions to the @n@x@n@ magic square
-- problem.
--
-- A negative @n@ only prints a diagnostic. Otherwise the function:
--
--   * asks the solver for every assignment of @n*n@ symbolic words that
--     'isMagic' accepts,
--   * prints each solution as a grid, followed by a concrete re-evaluation
--     of 'isMagic' on that grid, and
--   * finally reports how many solutions were found.
magic :: Int -> IO ()
magic n
  | n < 0     = putStrLn $ "n must be non-negative, received: " ++ show n
  | otherwise = do
      putStrLn $ "Finding all " ++ show n ++ "-magic squares.."
      res <- allSat $ (isMagic . chunk n) `fmap` mkExistVars n2
      cnt <- displayModels id disp res
      putStrLn $ "Found: " ++ show cnt ++ " solution(s)."
  where
    n2 = n * n

    -- Print one model. The solver must hand back exactly n*n values;
    -- any other count means something went wrong internally.
    disp i (_, model)
      | lmod /= n2 = error $ "Impossible! Backend solver returned " ++ show lmod
                          ++ " values, was expecting: " ++ show n2
      | otherwise  = do
          putStrLn $ "Solution #" ++ show i
          mapM_ printRow board
          -- Re-check the predicate on the concrete board as a sanity check.
          putStrLn $ "Valid Check: " ++ show (isMagic sboard)
          putStrLn "Done."
      where
        lmod   = length model
        board  = chunk n model
        sboard = map (map literal) board
        -- Left-pad one-character renderings to width two so columns line up.
        sh2 z = let s = show z in if length s < 2 then ' ':s else s
        printRow r = putStr "   " >> mapM_ (\x -> putStr (sh2 x ++ " ")) r >> putStrLn ""