module Rubik.Distances where
import Control.Monad
import Control.Monad.ST
import Control.Monad.Primitive
import Control.Monad.Ref
import Data.Foldable
import Data.Function
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG
import qualified Data.Vector.Generic.Mutable.Loops as MG
import qualified Data.MBitVector as MBV
-- | A graph node, identified by an 'Int'. In this module it is used directly
-- as an index into the distance vector and the companion bit vector.
type Coord = Int
-- | Pure entry point: compute the distance vector of a graph with @n@ nodes,
-- starting from @root@, where the edges are given by the @neighbors@ function.
--
-- This simply runs 'distancesM' inside 'ST', using @MG.iForM_@ as the
-- indexed loop over the mutable vector, and returns the frozen result.
distances
  :: (Traversable t, Eq a, Integral a, Show a, G.Vector v a)
  => Int -> Coord -> (Coord -> t Coord) -> v a
distances n root neighbors =
  runST $ distancesM MG.iForM_ n root neighbors
-- | Marker stored in the distance vector for nodes that have not been
-- reached yet.
--
-- It must never coincide with a real distance. Distances start at 0 and
-- grow by one per level, so @-1@ is safe for signed element types. For a
-- fixed-width unsigned type it wraps around to the largest value, which is
-- still fine as long as no real distance gets that large.
unvisited :: Num a => a
unvisited = -1

-- | Monadic worker computing the distance of every node from @root@.
--
-- * @forV@ is the indexed loop used to scan the whole distance vector.
-- * @n@ is the number of nodes; the vectors allocated here have length @n@.
-- * @root@ is the start node; it must be a valid index (@0 <= root < n@),
--   since it is read from the distance vector with a bounds-checked read.
-- * @neighbors@ lists the successors of a node.
--
-- The result holds, for each reached node, the depth at which it was first
-- reached. Nodes that were never reached keep the 'unvisited' marker.
distancesM :: forall a m t r v
  . ( Traversable t, Eq a, Integral a, Show a
    , G.Vector v a, PrimMonad m, MonadRef r m )
  => MG.ILoop m (G.Mutable v) a -> Int -> Coord -> (Coord -> t Coord) -> m (v a)
distancesM forV n root neighbors = do
  mv <- MG.replicate n unvisited
  mb <- MBV.replicate n False
  -- Number of nodes that have been assigned a distance so far.
  count <- newRef (0 :: Int)
  fill forV n root neighbors mv mb count 0
  -- The mutable vector is not touched after this point, so freezing it
  -- without a copy is fine.
  G.unsafeFreeze mv

-- | First phase: repeated depth-limited searches from @root@, one per depth
-- @d@, starting at the depth passed as the last argument.
--
-- After each search the visited counter is compared with its previous
-- value. The phase stops when nothing new was found or when all @n@ nodes
-- have been reached. Once at least a tenth of the nodes are known
-- (@count >= n `div` 10@) it hands over to 'fill'', which scans the whole
-- vector instead of re-walking the graph from the root.
fill forV n root neighbors mv mb count = fix $ \go d -> do
  before <- readRef count
  fillFrom neighbors mv mb count d 0 root
  after <- readRef count
  unless (before == after || after == n) $
    if after < n `div` 10
      then go (d + 1)
      else fill' forV n neighbors mv count d

-- | One depth-limited search for depth @d@. The recursive worker takes the
-- current depth @dx@ and the current node @x@.
--
-- * If @x@ is still 'unvisited', it is counted and recorded at distance
--   @d@, and its bit in @mb@ is set from the parity of @d@.
-- * Otherwise the search only continues through @x@ when @x@ was recorded at
--   exactly the current depth (@dx == dx'@) and its bit equals @even d@. The
--   bit is then flipped before recursing into the neighbours.
--
-- NOTE(review): the parity bit presumably prevents expanding the same node
-- twice within one pass -- confirm against "Data.MBitVector".
fillFrom neighbors mv mb count d = fix $ \go dx x -> do
  dx' <- MG.read mv x
  if dx' == unvisited
    then do
      modifyRef' count (+ 1)
      MG.unsafeWrite mv x d
      MBV.put mb x (fromIntegral $ d `mod` 2)
    else do
      test <- mb `MBV.test` x
      when (dx == dx' && test == even d) $ do
        mb `MBV.complement` x
        for_ (neighbors x) (go (dx + 1))

-- | Second phase: level-by-level expansion by scanning the whole vector.
--
-- For the current depth @d@, every node recorded at distance @d@ has its
-- still-'unvisited' neighbours recorded at distance @d + 1@. This repeats
-- with @d + 1@ until a scan finds nothing new or all @n@ nodes are reached.
fill' forV n neighbors mv count = fix $ \go d -> do
  before <- readRef count
  forV mv $ \x d' ->
    when (d' == d) $
      -- Test and mark each neighbour in one step, so that a node occurring
      -- several times among the neighbours of @x@ is counted only once.
      for_ (neighbors x) $ \y -> do
        dy <- MG.read mv y
        when (dy == unvisited) $ do
          modifyRef' count (+ 1)
          MG.unsafeWrite mv y (d + 1)
  after <- readRef count
  unless (before == after || after == n) $ go (d + 1)