{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- A mixed-integer programming solver built on top of
-- "ToySolver.Arith.Simplex2".
--
-- The solver takes an LP together with a set of variables that must take
-- integer values and searches for an optimal integer-feasible solution by
-- branch-and-bound, occasionally strengthening a node with a Gomory cut.
-- The search can be run on several worker threads (see 'setNThread').
--
-- The two language extensions above are required by the pattern type
-- signatures used in the branch-and-bound driver (the @restore@ argument
-- of 'mask' and the 'SomeException' match).
module ToySolver.Arith.MIPSolver2
  ( Solver
  , newSolver
  , optimize
  , getBestSolution
  , getBestValue
  , getBestModel
  , setNThread
  , setLogger
  , setOnUpdateBestSolution
  , setShowRational
  ) where
import Prelude hiding (log)
import Control.Monad
import Control.Exception
import Control.Concurrent
import Control.Concurrent.STM
import Data.Default.Class
import Data.List
import Data.OptDir
import Data.Ord
import Data.IORef
import Data.Maybe
import qualified Data.IntSet as IS
import qualified Data.IntMap as IM
import qualified Data.Map as Map
import qualified Data.Sequence as Seq
import qualified Data.Foldable as F
import Data.VectorSpace
import Data.Time
import System.CPUTime
import System.Timeout
import Text.Printf
import qualified ToySolver.Data.LA as LA
import ToySolver.Data.ArithRel ((.<=.), (.>=.))
import qualified ToySolver.Arith.Simplex2 as Simplex2
import ToySolver.Arith.Simplex2 (OptResult (..), Var, Model)
import ToySolver.Internal.Util (isInteger, fracPart)
-- | Mutable state of one MIP solver instance.
data Solver = MIP
  { -- | Private copy of the LP handed to 'newSolver'; used as the root
    -- relaxation.
    mipRootLP :: Simplex2.Solver
    -- | Variables that are required to take integer values.
  , mipIVs :: IS.IntSet
    -- | Best integer-feasible node found so far, if any.
  , mipBest :: TVar (Maybe Node)
    -- | Requested number of worker threads (see 'setNThread').
  , mipNThread :: IORef Int
    -- | Optional sink for progress messages (see 'setLogger').
  , mipLogger :: IORef (Maybe (String -> IO ()))
    -- | Callback run with the model and objective value each time the best
    -- solution is replaced (see 'setOnUpdateBestSolution').
  , mipOnUpdateBestSolution :: IORef (Model -> Rational -> IO ())
    -- | Flag forwarded to @Simplex2.showValue@ when formatting objective
    -- values for the log (see 'setShowRational').
  , mipShowRational :: IORef Bool
  }
-- | A node of the branch-and-bound tree.
data Node = Node
  { -- | LP solver holding this node's constraints.
    ndLP :: Simplex2.Solver
    -- | Depth of the node in the search tree; the root has depth 0.
  , ndDepth :: !Int
    -- | Objective value associated with the node.  For a freshly created
    -- child this is the value of its parent's relaxation.
  , ndValue :: Rational
  }
-- | Create a MIP solver from an LP and the set of integer variables.
--
-- The given LP solver is cloned, so the caller's solver is not modified.
-- For every integer variable, a fractional lower bound is rounded up and a
-- fractional upper bound is rounded down, since no integer value lies
-- between the fractional bound and the rounded one.
--
-- The thread count starts at 0 (treated as 1 by the search), no logger is
-- installed, values are not shown as rationals, and the update callback
-- does nothing.
newSolver :: Simplex2.Solver -> IS.IntSet -> IO Solver
newSolver lp ivs = do
  lp2 <- Simplex2.cloneSolver lp

  forM_ (IS.toList ivs) $ \v -> do
    lb <- Simplex2.getLB lp2 v
    case lb of
      Just l | not (isInteger l) ->
        Simplex2.assertLower lp2 v (fromInteger (ceiling l))
      _ -> return ()
    ub <- Simplex2.getUB lp2 v
    case ub of
      Just u | not (isInteger u) ->
        -- Tighten the *upper* bound.  Asserting a lower bound of
        -- @floor u@ here would wrongly exclude every value below it.
        Simplex2.assertUpper lp2 v (fromInteger (floor u))
      _ -> return ()

  bestRef <- newTVarIO Nothing
  nthreadRef <- newIORef 0
  logRef <- newIORef Nothing
  showRef <- newIORef False
  updateRef <- newIORef (\_ _ -> return ())

  return $
    MIP
    { mipRootLP = lp2
    , mipIVs = ivs
    , mipBest = bestRef
    , mipNThread = nthreadRef
    , mipLogger = logRef
    , mipOnUpdateBestSolution = updateRef
    , mipShowRational = showRef
    }
-- | Solve the mixed-integer program.
--
-- The LP relaxation is checked for feasibility and then optimised.
--
-- * 'Unsat' is returned when the relaxation is infeasible or when the
--   search finds no integer-feasible point.
--
-- * When the relaxation is unbounded and there are no integer variables,
--   'Unbounded' is returned immediately.  Otherwise a clone of the LP with
--   a constant-zero objective is searched for any integer-feasible point;
--   if one is found, its value under the original objective is stored and
--   'Unbounded' is returned.
--
-- * When the relaxation has an optimum, branch-and-bound is run on the
--   root LP and 'Optimum' is returned if a solution was recorded.
optimize :: Solver -> IO OptResult
optimize solver = do
  onImprove <- readIORef (mipOnUpdateBestSolution solver)
  log solver "MIP: Solving LP relaxation..."
  feasible <- Simplex2.check rootLP
  if feasible
    then relaxationFeasible onImprove
    else return Unsat
  where
    rootLP = mipRootLP solver

    -- The relaxation is satisfiable: optimise it and dispatch on the result.
    relaxationFeasible onImprove = do
      s0 <- Simplex2.getObjValue rootLP >>= showValue solver
      log solver (printf "MIP: LP relaxation is satisfiable with obj = %s" s0)
      log solver "MIP: Optimizing LP relaxation"
      outcome <- Simplex2.optimize rootLP def
      case outcome of
        Optimum -> integerSearch onImprove
        Unbounded -> unboundedRelaxation onImprove
        Unsat -> error "should not happen"
        ObjLimit -> error "should not happen"

    -- The relaxation is unbounded: look for any integer-feasible point.
    unboundedRelaxation onImprove = do
      log solver "MIP: LP relaxation is unbounded"
      if IS.null (mipIVs solver)
        then return Unbounded
        else do
          origObj <- Simplex2.getObj rootLP
          feasLP <- Simplex2.cloneSolver rootLP
          Simplex2.clearLogger feasLP
          Simplex2.setObj feasLP (LA.constant 0)
          -- The search runs with a zero objective, so report the value of
          -- the original objective to the callback instead.
          branchAndBound solver feasLP $ \m _ ->
            onImprove m (LA.evalExpr m origObj)
          incumbent <- readTVarIO (mipBest solver)
          case incumbent of
            Nothing -> return Unsat
            Just nd -> do
              m <- Simplex2.getModel (ndLP nd)
              atomically $
                writeTVar (mipBest solver) (Just nd{ ndValue = LA.evalExpr m origObj })
              return Unbounded

    -- The relaxation has a finite optimum: run the integer search proper.
    integerSearch onImprove = do
      s1 <- Simplex2.getObjValue rootLP >>= showValue solver
      log solver $ "MIP: LP relaxation optimum is " ++ s1
      log solver "MIP: Integer optimization begins..."
      Simplex2.clearLogger rootLP
      branchAndBound solver rootLP onImprove
      fmap (maybe Unsat (const Optimum)) (readTVarIO (mipBest solver))
-- | Run a multi-threaded branch-and-bound search starting from @rootLP@.
--
-- Open nodes are kept in a shared queue.  Worker threads repeatedly take a
-- node, re-optimise it with the dual simplex method (using the value of
-- the current best solution as objective limit), and then either discard
-- it, report it as integer-feasible, strengthen it with a Gomory cut, or
-- split it on a fractional integer variable.
--
-- The calling thread collects reported solutions, stores improvements in
-- 'mipBest' and calls @update@ with the model and objective value of each
-- improvement.  Every two seconds without an event it logs a status line.
-- It returns once the queue is empty and no worker is processing a node.
-- An exception raised in a worker or in the calling thread is thrown to
-- all workers and re-thrown in the calling thread.
branchAndBound :: Solver -> Simplex2.Solver -> (Model -> Rational -> IO ()) -> IO ()
branchAndBound solver rootLP update = do
  dir <- Simplex2.getOptDir rootLP
  rootVal <- Simplex2.getObjValue rootLP
  let root = Node{ ndLP = rootLP, ndDepth = 0, ndValue = rootVal }

  -- Shared search state.
  pool <- newTVarIO (Seq.singleton root)  -- open nodes; appended right, taken left
  activeThreads <- newTVarIO (Map.empty)  -- worker thread -> node being processed
  visitedNodes <- newTVarIO 0             -- number of branching steps performed
  solchan <- newTChanIO                   -- integer-feasible nodes found by workers

  let addNode :: Node -> STM ()
      addNode nd = do
        modifyTVar pool (Seq.|> nd)

      -- Take the next open node, blocking while the queue is empty but
      -- some worker is still busy.  Returns Nothing once the search is over.
      pickNode :: IO (Maybe Node)
      pickNode = do
        self <- myThreadId
        -- This thread has finished its previous node (if any).
        atomically $ modifyTVar activeThreads (Map.delete self)
        atomically $ do
          s <- readTVar pool
          case Seq.viewl s of
            nd Seq.:< s2 -> do
              writeTVar pool s2
              modifyTVar activeThreads (Map.insert self nd)
              return (Just nd)
            Seq.EmptyL -> do
              ths <- readTVar activeThreads
              if Map.null ths
                then return Nothing
                else retry

      processNode :: Node -> IO ()
      processNode node = do
        let lp = ndLP node
        lim <- liftM (fmap ndValue) $ readTVarIO (mipBest solver)
        ret <- Simplex2.dualSimplex lp def{ Simplex2.objLimit = lim }

        case ret of
          Unbounded -> error "should not happen"
          Unsat -> return ()
          ObjLimit -> return ()
          Optimum -> do
            val <- Simplex2.getObjValue lp
            p <- prune solver val
            unless p $ do
              xs <- violated node (mipIVs solver)
              case xs of
                -- No integrality violation: this node is a candidate solution.
                [] -> atomically $ writeTChan solchan $ node { ndValue = val }
                _ -> do
                  -- Gomory cuts are only attempted at depths that are
                  -- multiples of 100 (which includes the root).
                  r <- if ndDepth node `mod` 100 /= 0
                       then return Nothing
                       else liftM listToMaybe $ filterM (canDeriveGomoryCut lp) $ map fst xs
                  case r of
                    Nothing -> do
                      -- Branch on the variable whose value is farthest
                      -- from its rounded value.  xs is non-empty here, so
                      -- maximumBy is safe.
                      let (v0,val0) = fst $ maximumBy (comparing snd)
                                        [ ((v,vval), abs (fromInteger (round vval) - vval))
                                        | (v,vval) <- xs ]
                      -- One child reuses the node's solver, the other a clone.
                      let lp1 = lp
                      lp2 <- Simplex2.cloneSolver lp
                      Simplex2.assertAtom lp1 (LA.var v0 .<=. LA.constant (fromInteger (floor val0)))
                      Simplex2.assertAtom lp2 (LA.var v0 .>=. LA.constant (fromInteger (ceiling val0)))
                      atomically $ do
                        addNode $ Node lp1 (ndDepth node + 1) val
                        addNode $ Node lp2 (ndDepth node + 1) val
                        modifyTVar' visitedNodes (+1)
                    Just v -> do
                      atom <- deriveGomoryCut lp (mipIVs solver) v
                      Simplex2.assertAtom lp atom
                      atomically $ do
                        addNode $ Node lp (ndDepth node + 1) val

  let isCompleted = do
        nodes <- readTVar pool
        threads <- readTVar activeThreads
        return $ Seq.null nodes && Map.null threads

  -- A configured thread count below 1 is treated as 1.
  nthreads <- liftM (max 1) $ readIORef (mipNThread solver)

  log solver $ printf "MIP: forking %d worker threads..." nthreads

  startCPU <- getCPUTime
  startWC <- getCurrentTime
  -- Holds the first exception raised by a worker thread.
  ex <- newEmptyTMVarIO

  let printStatus :: Seq.Seq Node -> Int -> IO ()
      printStatus nodes visited
        | Seq.null nodes = return ()
        | otherwise = do
            nowCPU <- getCPUTime
            nowWC <- getCurrentTime
            -- getCPUTime is in picoseconds.
            let spentCPU = (nowCPU - startCPU) `div` 10^(12::Int)
            let spentWC = round (nowWC `diffUTCTime` startWC) :: Int

            -- nodes is non-empty in this branch, so minimum/maximum are safe.
            let vs = map ndValue (F.toList nodes)
                dualBound =
                  case dir of
                    OptMin -> minimum vs
                    OptMax -> maximum vs

            primalBound <- do
              x <- readTVarIO (mipBest solver)
              return $ case x of
                Nothing -> Nothing
                Just node -> Just (ndValue node)

            (p,g) <- case primalBound of
                   Nothing -> return ("not yet found", "--")
                   Just val -> do
                     p <- showValue solver val
                     -- Relative gap between dual and primal bound, in percent.
                     let g = if val == 0
                             then "inf"
                             else printf "%.2f%%" (fromRational (abs (dualBound - val) * 100 / abs val) :: Double)
                     return (p, g)
            d <- showValue solver dualBound

            let range =
                  case dir of
                    OptMin -> p ++ " >= " ++ d
                    OptMax -> p ++ " <= " ++ d

            log solver $ printf "cpu time = %d sec; wc time = %d sec; active nodes = %d; visited nodes = %d; %s; gap = %s"
              spentCPU spentWC (Seq.length nodes) visited range g

  mask $ \(restore :: forall a. IO a -> IO a) -> do
    threads <- replicateM nthreads $ do
      forkIO $ do
        let loop = do
              m <- pickNode
              case m of
                Nothing -> return ()
                Just node -> processNode node >> loop
        ret <- try $ restore loop
        case ret of
          Left e -> atomically (putTMVar ex e)
          Right _ -> return ()

    let propagateException :: SomeException -> IO ()
        propagateException e = do
          mapM_ (\t -> throwTo t e) threads
          throwIO e

    let loop = do
          -- Wait up to two seconds for one of: a reported solution,
          -- completion of the search, or a worker exception.  The
          -- transaction returns the IO action to run next.
          ret <- try $ timeout (2*1000*1000) $ restore $ atomically $ msum
            [ do node <- readTChan solchan
                 improved <- do
                   old <- readTVar (mipBest solver)
                   case old of
                     Nothing -> do
                       writeTVar (mipBest solver) (Just node)
                       return True
                     Just best -> do
                       let isBetter = if dir==OptMin then ndValue node < ndValue best else ndValue node > ndValue best
                       when isBetter $ writeTVar (mipBest solver) (Just node)
                       return isBetter
                 return $ do
                   when improved $ do
                     let lp = ndLP node
                     m <- Simplex2.getModel lp
                     update m (ndValue node)
                   loop
            , do b <- isCompleted
                 guard b
                 return $ return ()
            , do e <- readTMVar ex
                 return $ propagateException e
            ]

          case ret of
            Left (e::SomeException) -> propagateException e
            Right (Just m) -> m
            Right Nothing -> do -- timeout
              (nodes, visited) <- atomically $ do
                nodes <- readTVar pool
                athreads <- readTVar activeThreads
                visited <- readTVar visitedNodes
                return (Seq.fromList (Map.elems athreads) Seq.>< nodes, visited)
              printStatus nodes visited
              loop

    loop
-- | The best solution found so far, as a pair of model and objective
-- value, or 'Nothing' if no solution has been recorded.
getBestSolution :: Solver -> IO (Maybe (Model, Rational))
getBestSolution solver = readTVarIO (mipBest solver) >>= maybe (return Nothing) fromNode
  where
    -- Read the model out of the node's LP solver and pair it with the
    -- stored objective value.
    fromNode best = do
      model <- Simplex2.getModel (ndLP best)
      return (Just (model, ndValue best))
-- | The model of the best solution found so far, if any.
getBestModel :: Solver -> IO (Maybe Model)
getBestModel solver = fmap (fmap fst) (getBestSolution solver)
-- | The objective value of the best solution found so far, if any.
getBestValue :: Solver -> IO (Maybe Rational)
getBestValue solver = fmap (fmap snd) (getBestSolution solver)
-- | List the variables from @ivs@ whose value in the node's current model
-- is not an integer, together with that value.
violated :: Node -> IS.IntSet -> IO [(Var, Rational)]
violated node ivs = do
  model <- Simplex2.getModel (ndLP node)
  return
    [ (v, x)
    | (v, x) <- IM.toList model
    , IS.member v ivs
    , not (isInteger x)
    ]
-- | Decide whether a node with relaxation value @lb@ can be discarded.
--
-- Returns 'False' while no solution is known.  Otherwise the node is
-- discarded when the best known value is at least as good as @lb@ with
-- respect to the optimisation direction of the root LP.
prune :: Solver -> Rational -> IO Bool
prune solver lb = readTVarIO (mipBest solver) >>= maybe (return False) dominatedBy
  where
    dominatedBy incumbent = do
      dir <- Simplex2.getOptDir (mipRootLP solver)
      return $ case dir of
        OptMin -> ndValue incumbent <= lb
        OptMax -> ndValue incumbent >= lb
-- | Format a value for log output via @Simplex2.showValue@, passing along
-- the solver's current 'setShowRational' flag.
showValue :: Solver -> Rational -> IO String
showValue solver v =
  fmap (\asRational -> Simplex2.showValue asRational v)
       (readIORef (mipShowRational solver))
-- | Set the flag that is passed to @Simplex2.showValue@ when objective
-- values are formatted for the log.
setShowRational :: Solver -> Bool -> IO ()
setShowRational = writeIORef . mipShowRational
-- | Set the number of worker threads used by the search.  The search
-- treats any value below 1 as 1.
setNThread :: Solver -> Int -> IO ()
setNThread = writeIORef . mipNThread
-- | Install a function that receives the solver's progress messages.
setLogger :: Solver -> (String -> IO ()) -> IO ()
setLogger solver = writeIORef (mipLogger solver) . Just
-- | Install the callback that is run with the model and objective value
-- each time the search replaces its best solution.
setOnUpdateBestSolution :: Solver -> (Model -> Rational -> IO ()) -> IO ()
setOnUpdateBestSolution = writeIORef . mipOnUpdateBestSolution
-- | Send a message to the installed logger, if there is one.
log :: Solver -> String -> IO ()
log solver = logIO solver . return
-- | Like 'log', but the message is produced by an action that is run only
-- when a logger is installed.
logIO :: Solver -> IO String -> IO ()
logIO solver action =
  readIORef (mipLogger solver) >>= maybe (return ()) (\logger -> action >>= logger)
-- | Derive a Gomory mixed-integer cut from the row of @xi@.
--
-- @xi@ is expected to have a fractional value (checked with 'assert') and
-- every non-fixed variable of its row is expected to sit at one of its
-- bounds; 'canDeriveGomoryCut' tests these conditions.  @ivs@ is the set
-- of integer variables.
--
-- With @f0@ the fractional part of the value of @xi@ and, for each row
-- term @aij * xj@, @fj@ the fractional part of @aij@, the cut is
--
-- > sum_j cj * (xj - lj)  +  sum_k ck * (uk - xk)  >=  1
--
-- where @j@ ranges over row variables at their lower bound @lj@, @k@ over
-- row variables at their upper bound @uk@, and all coefficients are
-- non-negative.
deriveGomoryCut :: Simplex2.Solver -> IS.IntSet -> Var -> IO (LA.Atom Rational)
deriveGomoryCut lp ivs xi = do
  v0 <- Simplex2.getValue lp xi
  let f0 = fracPart v0
  assert (0 < f0 && f0 < 1) $ return ()

  row <- Simplex2.getRow lp xi

  -- Drop fixed variables (lower bound equal to upper bound).
  let p (_,xj) = do
        lb <- Simplex2.getLB lp xj
        ub <- Simplex2.getUB lp xj
        case (lb,ub) of
          (Just l, Just u) -> return (l < u)
          _ -> return True
  ns <- filterM p $ LA.terms row

  -- Terms whose variable is at its lower bound.
  js <- flip filterM ns $ \(_, xj) -> do
    vj <- Simplex2.getValue lp xj
    lb <- Simplex2.getLB lp xj
    return $ Just vj == lb
  -- Terms whose variable is at its upper bound.
  ks <- flip filterM ns $ \(_, xj) -> do
    vj <- Simplex2.getValue lp xj
    ub <- Simplex2.getUB lp xj
    return $ Just vj == ub

  xs1 <- forM js $ \(aij, xj) -> do
    let fj = fracPart aij
    -- The filter above guarantees that a lower bound exists.
    Just lj <- Simplex2.getLB lp xj
    let c = if xj `IS.member` ivs
            then (if fj <= 1 - f0 then fj / (1 - f0) else (1 - fj) / f0)
            -- aij <= 0 in the else branch, so negate to keep c >= 0.
            else (if aij > 0 then aij / (1 - f0) else negate aij / f0)
    return $ c *^ (LA.var xj ^-^ LA.constant lj)
  xs2 <- forM ks $ \(aij, xj) -> do
    let fj = fracPart aij
    -- The filter above guarantees that an upper bound exists.
    Just uj <- Simplex2.getUB lp xj
    let c = if xj `IS.member` ivs
            then (if fj <= f0 then fj / f0 else (1 - fj) / (1 - f0))
            -- aij <= 0 in the else branch, so negate to keep c >= 0.
            else (if aij > 0 then aij / f0 else negate aij / (1 - f0))
    return $ c *^ (LA.constant uj ^-^ LA.var xj)

  return $ sumV xs1 ^+^ sumV xs2 .>=. LA.constant 1
-- | Check whether a Gomory cut can be derived from the row of @xi@.
--
-- This holds when @xi@ is a basic variable, its current value is not an
-- integer, and every variable in its row currently sits at its lower or
-- its upper bound.
canDeriveGomoryCut :: Simplex2.Solver -> Var -> IO Bool
canDeriveGomoryCut lp xi = do
  basic <- Simplex2.isBasicVariable lp xi
  if basic then checkValue else return False
  where
    checkValue = do
      val <- Simplex2.getValue lp xi
      if isInteger val then return False else checkRow

    -- Every row variable is inspected before the results are combined.
    checkRow = do
      row <- Simplex2.getRow lp xi
      liftM and $ mapM (atBound . snd) (LA.terms row)

    atBound xj = do
      vj <- Simplex2.getValue lp xj
      lb <- Simplex2.getLB lp xj
      ub <- Simplex2.getUB lp xj
      return (Just vj `elem` [lb, ub])