module Algorithms.FloydWarshall
( mkIndex
, mkGraph
, floydWarshall
) where
import Control.Monad (forM_, when)
import Control.Monad.ST (ST)
import Data.Vector.Unboxed.Mutable as V (MVector, length, replicate, unsafeRead,
unsafeWrite, Unbox)
-- | Run the Floyd–Warshall all-pairs shortest-path relaxation in place.
--
-- @graph@ must contain exactly @n * n@ cells. Cell (@i@, @j@) is addressed
-- through 'mkIndex' (offset @i * n + j@) and holds @(distance, nextHop)@.
-- After pass @k@, cell (@i@, @j@) is replaced by
-- @(dist(i,k) + dist(k,j), nextHop(i,k))@ when all of these hold:
--
-- * the detour through @k@ is shorter than the current distance by more
--   than a relative tolerance of @1e-10@;
-- * the current distance is strictly greater than each leg of the detour.
--
-- NOTE(review): the relative tolerance only behaves as "strictly shorter"
-- for non-negative distances; negative weights are presumably unsupported.
-- Confirm with callers.
--
-- Calls 'error' with @"Bad bounds"@ if the vector length is not @n * n@.
floydWarshall :: (Unbox a, Fractional a, Ord a) => Int -> MVector s (a, Int) -> ST s ()
floydWarshall n graph = do
  -- All reads/writes below are unchecked; this size test keeps them in range.
  when (n * n /= V.length graph) $ error "Bad bounds"
  forM_ [0 .. n - 1] $ \k ->
    forM_ [0 .. n - 1] $ \i -> do
      -- Column k is never rewritten during pass k (for j == k the guard
      -- `distIJ > distIK` is false), so (i, k) only needs reading once per row.
      (distIK, pathIK) <- access (i, k)
      forM_ [0 .. n - 1] $ \j -> do
        (distIJ, _) <- access (i, j)
        (distKJ, _) <- access (k, j)
        let indirectDist = distIK + distKJ
        when (distIJ > indirectDist + indirectDist * eps
                && distIJ > distIK
                && distIJ > distKJ) $
          -- The next hop towards j is now the next hop towards k.
          put (i, j) (indirectDist, pathIK)
  where
    access idx = V.unsafeRead graph (mkIndex n idx)
    put idx e = V.unsafeWrite graph (mkIndex n idx) e
    -- Relative tolerance that stops float rounding noise from causing updates.
    eps = 1e-10
-- | Flatten a (row, column) pair into the offset of that cell in a
-- row-major matrix with @n@ columns: @mkIndex n (i, j) == i * n + j@.
mkIndex :: Num a => a -> (a, a) -> a
mkIndex n (i, j) = i * n + j
-- | Allocate the @n * n@ distance/next-hop matrix of an undirected graph,
-- in the layout expected by 'floydWarshall'.
--
-- Cell (@i@, @j@) is addressed through 'mkIndex' and holds
-- @(distance, nextHop)@:
--
-- * every cell starts as @(maxValue, maxBound)@, meaning "no known path";
-- * each edge @(i, j, cost)@ is stored in both directions, as @(cost, j)@
--   at (@i@, @j@) and @(cost, i)@ at (@j@, @i@); if the same pair occurs
--   more than once, the later edge wins;
-- * the diagonal is set to @(0, v)@. It is written after the edges so that
--   a self-loop cannot replace a vertex's zero distance to itself.
--
-- Calls 'error' if an edge endpoint lies outside @[0, n)@. The writes are
-- unchecked, so without this test such an edge would corrupt memory.
mkGraph :: (Unbox a, Num a) => Int -> a -> [(Int, Int, a)] -> ST s (MVector s (a, Int))
mkGraph n maxValue edges = do
  graph <- V.replicate (n * n) (maxValue, maxBound)
  forM_ edges $ \(i, j, cost) -> do
    when (outOfRange i || outOfRange j) $
      error ("mkGraph: edge (" ++ show i ++ ", " ++ show j
               ++ ") has an endpoint outside [0, " ++ show n ++ ")")
    unsafeWrite graph (mkIndex n (i, j)) (cost, j)
    unsafeWrite graph (mkIndex n (j, i)) (cost, i)
  forM_ [0 .. n - 1] $ \v ->
    unsafeWrite graph (mkIndex n (v, v)) (0, v)
  return graph
  where
    outOfRange v = v < 0 || v >= n