{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module AI.Search.FiniteDomain.Int.Constraint
( (#=)
, (#/=)
, (#<)
, (#<=)
, (#>)
, (#>=)
, (/\)
, (\/)
, Constraint
, FD
, Labeling(..)
, allDifferent
, between
, initNewVar
, labeling
, newVar
, not'
, runFD
) where
import Control.Monad ( forM, forM_ )
import Control.Monad.ST ( ST, runST )
import Data.List ( find )
import qualified Numeric.Domain as D
import Data.Propagator.Cell as P ( Cell, cell, connect, label, propagateMany
, readCell, succeeded, sync, syncWith )
import Control.Monad.Trans.State ( State, evalState, execState, get, modify, put )
import AI.Search.FiniteDomain.Int.Cell ( domainJoin, eqJoin, mustHoldJoin )
import AI.Search.FiniteDomain.Int.Expression ( Expression, cellifyExpression, var )
-- | Builder monad for finite-domain problems.
--
-- The state pairs the constraints recorded so far (most recent first) with
-- the index that the next call to 'newVar' will hand out.
newtype FD a = FD { unFD :: State ([IntConstraint], Int) a }
  deriving (Functor, Applicative, Monad)

-- | An 'FD' action without a result, used for computations whose only
-- purpose is recording constraints.
type Constraint = FD ()
-- | Record a constraint by prepending it to the constraint store; the
-- variable counter is left untouched.
addConstraint :: IntConstraint -> FD ()
addConstraint new = FD (modify prepend)
  where
    prepend (recorded, nextVar) = (new : recorded, nextVar)
-- | Run an 'FD' computation from an empty constraint store, with variable
-- numbering starting at 0, and return only its result.
runFD :: FD a -> a
runFD fd = evalState (unFD fd) ([], 0)
-- | Internal syntax tree of constraints, built by the public operators and
-- later translated into propagator cells.
--
-- Constructor order is significant for the derived 'Ord' instance.
data IntConstraint
  = Equals Expression Expression        -- ^ both sides take the same value
  | NotEquals Expression Expression     -- ^ the sides take different values
  | LessThan Expression Expression      -- ^ left is strictly less than right
  | And IntConstraint IntConstraint     -- ^ both sub-constraints hold
  | Or IntConstraint IntConstraint      -- ^ at least one sub-constraint holds
  deriving (Show, Eq, Ord)
infix 4 #=
-- | @left #= right@ requires both expressions to take the same value.
(#=) :: Expression -> Expression -> Constraint
left #= right = addConstraint (Equals left right)
infix 4 #/=
-- | @left #/= right@ requires the expressions to take different values.
(#/=) :: Expression -> Expression -> Constraint
left #/= right = addConstraint (NotEquals left right)
infix 4 #<
-- | @left #< right@ requires @left@ to be strictly less than @right@.
(#<) :: Expression -> Expression -> Constraint
left #< right = addConstraint (LessThan left right)
infix 4 #<=
-- | @left #<= right@ requires @left@ to be at most @right@. It is recorded
-- as the disjunction of strict less-than and equality.
(#<=) :: Expression -> Expression -> Constraint
left #<= right = addConstraint (Or (LessThan left right) (Equals left right))
infix 4 #>
-- | @left #> right@ is the mirror image of '#<': it means @right #< left@.
(#>) :: Expression -> Expression -> Constraint
left #> right = right #< left
infix 4 #>=
-- | @left #>= right@ is the mirror image of '#<=': it means @right #<= left@.
(#>=) :: Expression -> Expression -> Constraint
left #>= right = right #<= left
infixl 3 /\
-- | Conjunction: records the constraints of both operands, left first.
(/\) :: Constraint -> Constraint -> Constraint
left /\ right = left >> right
infixl 2 \/
-- | Disjunction of two constraints.
--
-- Each operand is run against an empty store, continuing the variable
-- numbering, so variables allocated in either branch stay reserved. The
-- constraints recorded by each branch are conjoined with 'And', and the
-- two conjunctions are joined with 'Or'.
--
-- If one branch records no constraints, the other branch's constraints are
-- appended to the store unchanged, without an 'Or'.
-- NOTE(review): a branch with no constraints is trivially true, so
-- logically the whole disjunction would be trivially true. Confirm that
-- keeping the other branch unconditionally is intended.
(\/) :: Constraint -> Constraint -> Constraint
left \/ right = FD $ do
  (existing, firstFree) <- get
  let isolate c ix = execState (unFD c) ([], ix)
      (leftCs, afterLeft) = isolate left firstFree
      (rightCs, afterRight) = isolate right afterLeft
  case (leftCs, rightCs) of
    ([], _) -> put (existing ++ rightCs, afterRight)
    (_, []) -> put (existing ++ leftCs, afterRight)
    (l : ls, r : rs) ->
      put (Or (foldl And l ls) (foldl And r rs) : existing, afterRight)
-- | Negate a constraint.
--
-- The argument runs against an empty store, continuing the variable
-- numbering. The constraints it records are conjoined and then negated.
-- De Morgan's laws push the negation down to the comparisons, and a
-- negated @a < b@ becomes @b < a || b == a@.
--
-- NOTE(review): if the argument records no constraints, the store is left
-- unchanged instead of receiving an unsatisfiable constraint. Confirm this
-- is intended.
not' :: Constraint -> Constraint
not' constraint = FD $ do
  (existing, firstFree) <- get
  let (inner, afterInner) = execState (unFD constraint) ([], firstFree)
  case inner of
    [] -> put (existing, afterInner)
    c : cs -> put (negation (foldl And c cs) : existing, afterInner)
  where
    negation ic = case ic of
      Equals l r -> NotEquals l r
      NotEquals l r -> Equals l r
      LessThan l r -> Or (LessThan r l) (Equals r l)
      And l r -> Or (negation l) (negation r)
      Or l r -> And (negation l) (negation r)
-- | Require all expressions to be pairwise distinct by recording one '#/='
-- constraint for every pair.
allDifferent :: [Expression] -> Constraint
allDifferent exprs =
  case exprs of
    [] -> pure ()
    e : rest -> mapM_ (e #/=) rest >> allDifferent rest
-- | @between low high target@ requires @low <= target <= high@.
between :: Expression   -- ^ lower bound (inclusive)
        -> Expression   -- ^ upper bound (inclusive)
        -> Expression   -- ^ constrained expression
        -> Constraint
between low high target = target #>= low /\ target #<= high
-- | Allocate a fresh variable. It is numbered with the current counter
-- value, and the counter is then incremented.
newVar :: FD Expression
newVar = FD (get >>= claim)
  where
    claim (recorded, ix) = var ix <$ put (recorded, ix + 1)
-- | Allocate a fresh variable constrained to equal the given expression,
-- and return the variable.
initNewVar :: Expression -> FD Expression
initNewVar initExpr = do
  fresh <- newVar
  fresh <$ (fresh #= initExpr)
-- | Outcome of 'labeling'.
data Labeling a
  = Unsolvable [D.Domain Int]   -- ^ no solution; domains seen after propagation
  | Unbounded [D.Domain Int]    -- ^ some requested domain is infinite
  | Solutions [a]               -- ^ every assignment found by labeling
  deriving (Show, Eq)
-- | Maps over the solutions; failure cases pass their domains through
-- unchanged.
instance Functor Labeling where
  fmap f l = case l of
    Solutions xs -> Solutions (map f xs)
    Unbounded ds -> Unbounded ds
    Unsolvable ds -> Unsolvable ds
-- | A failure in the function position wins first, then a failure in the
-- argument position. Two solution lists combine as a cartesian product.
--
-- NOTE(review): @Solutions [] <*> Unsolvable ds@ yields @Unsolvable ds@,
-- whereas binding @Solutions []@ yields @Solutions []@. So '(<*>)' and
-- 'ap' disagree in that case.
instance Applicative Labeling where
  pure x = Solutions [x]
  lf <*> lx = case (lf, lx) of
    (Unsolvable ds, _) -> Unsolvable ds
    (Unbounded ds, _) -> Unbounded ds
    (_, Unsolvable ds) -> Unsolvable ds
    (_, Unbounded ds) -> Unbounded ds
    (Solutions fs, Solutions xs) -> Solutions [f x | f <- fs, x <- xs]
-- | Sequencing over the solutions of a 'Labeling'.
--
-- 'Unsolvable' and 'Unbounded' short-circuit. For @Solutions xs@, the
-- continuation is applied to each element in order. The first element
-- whose result is 'Unsolvable' or 'Unbounded' determines the overall
-- result. Otherwise the solution lists are concatenated in order.
--
-- Per-element results are collected as a reversed list of chunks and
-- flattened once at the end. This keeps the cost linear in the number of
-- produced solutions; repeated appends to an accumulator were quadratic.
--
-- NOTE(review): failures are reported from the first failing element. As a
-- result, which failure (constructor and domains) is returned can depend
-- on how binds are bracketed, so associativity holds only up to that.
instance Monad Labeling where
  Unsolvable ds >>= _ = Unsolvable ds
  Unbounded ds >>= _ = Unbounded ds
  Solutions xs >>= f = go [] xs
    where
      go chunks [] = Solutions (concat (reverse chunks))
      go chunks (y:ys) =
        case f y of
          Unsolvable ds -> Unsolvable ds
          Unbounded ds -> Unbounded ds
          Solutions bs -> go (bs : chunks) ys
-- | Solve the constraints recorded so far and enumerate assignments for the
-- given variables.
--
-- Returns 'Unsolvable' when propagation does not succeed or labeling finds
-- no assignment. Returns 'Unbounded' when some requested variable still has
-- an infinite domain after propagation. Both carry the domains of the
-- requested variables as read after propagation. Otherwise the result is
-- 'Solutions' with the assignments produced by 'label'.
--
-- Requested variables that no constraint mentions all share a single cell
-- that starts at 'D.maxDomain'.
labeling :: [Expression] -> FD (Labeling [Int])
labeling vars = do
  recorded <- FD (fst <$> get)
  pure $ runST $ do
    (holds, known, domainCells) <- cellifyConstraints recorded []
    mustHold <- cell True mustHoldJoin
    unconstrained <- cell D.maxDomain domainJoin
    -- every top-level constraint's logic cell feeds the must-hold cell
    mapM_ (\c -> connect c mustHold pure) holds
    let cellOf v = maybe unconstrained snd (find ((== v) . fst) known)
        userCells = map cellOf vars
    outcome <- propagateMany domainCells
    snapshot <- mapM P.readCell userCells
    if not (succeeded outcome)
      then pure (Unsolvable snapshot)
      else if any D.isInfinite snapshot
        then pure (Unbounded snapshot)
        else do
          found <- label (concat . D.elems) D.singleton userCells
          pure $ case found of
            [] -> Unsolvable snapshot
            xs -> Solutions xs
-- | A cell holding the integer domain of an expression.
type DomainCell s = Cell s (D.Domain Int)

-- | A variable expression paired with the domain cell that tracks it.
type VarCell s = (Expression, DomainCell s)

-- | A cell holding whether a constraint is (still) satisfied.
type LogicCell s = Cell s Bool
-- | Cellify a list of constraints in order, threading the variable
-- environment from one constraint to the next.
--
-- Returns one logic cell per constraint (in input order), the final
-- variable environment, and all domain cells produced along the way.
cellifyConstraints
  :: [IntConstraint]
  -> [VarCell s]
  -> ST s ([LogicCell s], [VarCell s], [DomainCell s])
cellifyConstraints [] known = pure ([], known, [])
cellifyConstraints (con : rest) known = do
  (holds, afterCon, conCells) <- cellifyConstraint con known
  (moreHolds, afterRest, restCells) <- cellifyConstraints rest afterCon
  pure (holds : moreHolds, afterRest, conCells ++ restCells)
-- | Translate one constraint into propagator cells.
--
-- Takes the constraint and the current variable environment (each
-- variable expression paired with its domain cell). Returns:
--
-- * a logic cell (initialised to 'True') tracking the constraint,
-- * the extended variable environment,
-- * the domain cells produced while cellifying the constraint's expressions.
cellifyConstraint
  :: IntConstraint
  -> [VarCell s]
  -> ST s (LogicCell s, [VarCell s], [DomainCell s])
cellifyConstraint constraint vars =
  case constraint of
    -- Comparisons: cellify both sides and wire them with the given
    -- propagator (see 'binary' below).
    Equals left right ->
      binary left right sync
    NotEquals left right ->
      binary left right (syncWith D.notEqual D.notEqual)
    LessThan left right ->
      binary left right (syncWith D.greaterThanDomain D.lessThanDomain)
    -- Conjunction: both sides share and extend the same variable
    -- environment. The result cell receives the AND of the two side cells.
    And left right -> do
      (ls, nvs, xs) <- cellifyConstraint left vars
      (rs, rvs, ys) <- cellifyConstraint right nvs
      newCell <- cell True eqJoin
      connect ls newCell (\ld -> (ld &&) <$> P.readCell rs)
      connect rs newCell (\rd -> (&& rd) <$> P.readCell ls)
      pure (newCell, rvs, xs ++ ys)
    -- Disjunction: each side is cellified against an EMPTY environment, so
    -- it works on its own private copies of the variables it mentions.
    -- The result cell receives the OR of the two side cells.
    Or left right -> do
      (ls, lvs, xs) <- cellifyConstraint left []
      (rs, rvs, ys) <- cellifyConstraint right []
      newCell <- cell True eqJoin
      connect ls newCell (\ld -> (ld ||) <$> P.readCell rs)
      connect rs newCell (\rd -> (|| rd) <$> P.readCell ls)
      -- pairs: variables mentioned by both branches; solos: by only one
      let (pairs, solos) = split lvs rvs
      -- For shared variables, the outer cell (reused from `vars`, or
      -- created at 'D.maxDomain') feeds both branch copies. It receives
      -- back the 'D.union' of the two branch domains.
      pairNews <-
        forM pairs $ \(v,lc,rc) -> do
          (varCell, new) <-
            case find ((v ==) . fst) vars of
              Just av -> pure (snd av, [])
              Nothing -> do
                varCell <- cell D.maxDomain domainJoin
                pure (varCell, [(v, varCell)])
          connect varCell lc pure
          connect varCell rc pure
          connect lc varCell (\ld -> D.union ld <$> P.readCell rc)
          connect rc varCell (\rd -> (`D.union` rd) <$> P.readCell lc)
          pure new
      -- For variables mentioned by only one branch, information flows only
      -- from the outer cell into the branch copy, never back. Presumably
      -- this is because the other branch may hold without restricting it.
      soloNews <-
        forM solos $ \(v,sc) -> do
          (varCell, new) <-
            case find ((v ==) . fst) vars of
              Just av -> pure (snd av, [])
              Nothing -> do
                varCell <- cell D.maxDomain domainJoin
                pure (varCell, [(v, varCell)])
          connect varCell sc pure
          pure new
      -- NOTE(review): outer cells created above are added to the
      -- environment but not to the returned domain-cell list. Confirm that
      -- 'propagateMany' reaches them through their connections.
      pure (newCell, concat pairNews ++ concat soloNews ++ vars, xs ++ ys)
  where
    -- Cellify both expressions (threading the environment), connect their
    -- cells with `wire`, and feed `not (D.null d)` of each expression cell
    -- into a fresh logic cell.
    binary left right wire = do
      (ls, lcs, xs) <- cellifyExpression left vars
      (rs, rcs, ys) <- cellifyExpression right lcs
      newCell <- cell True eqJoin
      _ <- wire ls rs
      connect ls newCell (pure . not . D.null)
      connect rs newCell (pure . not . D.null)
      pure (newCell, rcs, xs ++ ys)
-- | Remove the first entry whose key equals the given key.
--
-- Returns the matching entry together with the remaining entries in their
-- original order, or 'Nothing' when no entry matches. The type is
-- generalised from @[VarCell s]@ to any association list with an 'Eq' key.
-- @VarCell s@ is an instance of it, so existing callers are unaffected.
extract :: Eq k => k -> [(k, v)] -> Maybe ((k, v), [(k, v)])
extract key entries =
  case break ((key ==) . fst) entries of
    (before, hit : after) -> Just (hit, before ++ after)
    (_, []) -> Nothing
-- | Partition two variable environments by their shared variables.
--
-- Each variable of the first list that also occurs in the second list is
-- returned as a triple: the variable, its cell from the first list, and
-- its cell from the second list. The second component holds everything
-- else: first the unmatched entries of the first list, in order, then the
-- unmatched entries of the second list, in their original order.
split :: [VarCell s]
      -> [VarCell s]
      -> ([(Expression, DomainCell s, DomainCell s)], [VarCell s])
split lefts rights =
  case lefts of
    [] -> ([], rights)
    entry : rest ->
      case extract (fst entry) rights of
        Nothing ->
          let (pairs, solos) = split rest rights
          in (pairs, entry : solos)
        Just (match, remaining) ->
          let (pairs, solos) = split rest remaining
          in ((fst entry, snd entry, snd match) : pairs, solos)