{-# LANGUAGE LambdaCase #-}
module Algorithms.Geometry.WellSeparatedPairDecomposition.WSPD where
import Algorithms.Geometry.WellSeparatedPairDecomposition.Types
import Control.Lens hiding (Level, levels)
import Control.Monad.Reader
import Control.Monad.ST (ST,runST)
import Data.BinaryTree
import Data.Ext
import qualified Data.Foldable as F
import Data.Geometry.Box
import Data.Geometry.Transformation
import Data.Geometry.Properties
import Data.Geometry.Point
import Data.Geometry.Vector
import qualified Data.Geometry.Vector as GV
import qualified Data.List as L
import qualified Data.List.NonEmpty as NonEmpty
import Data.Maybe
import Data.Ord (comparing)
import Data.Range
import qualified Data.Range as Range
import Data.Semigroup
import qualified Data.Seq2 as S2
import qualified Data.Sequence as S
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import GHC.TypeLits
import qualified Data.IntMap.Strict as IntMap
import Debug.Trace
-- | Builds a split tree for a non-empty collection of points.
--
-- The points are sorted on their first coordinate and each point is tagged
-- with its rank in that order. One copy of the tagged points is then made per
-- dimension, the copy at (0-based) position @i@ being sorted on coordinate
-- @i + 1@. The leaf tree produced by 'fairSplitTree'' is finally folded
-- bottom-up so that every internal node stores its split dimension together
-- with the combination ('<>') of the bounding boxes of its two subtrees.
fairSplitTree :: ( Fractional r, Ord r, Arity d, 1 <= d
                 , Show r, Show p
                 )
              => NonEmpty.NonEmpty (Point d r :+ p) -> SplitTree d p r ()
fairSplitTree pts = foldUp mkNode Leaf (fairSplitTree' size perDim)
  where
    -- sort a non-empty list of points on the given (1-based) coordinate
    byCoord c = NonEmpty.sortWith (^.core.unsafeCoord c)

    -- the input sorted on the first coordinate, every point tagged with its rank
    indexed = NonEmpty.zipWith tag (NonEmpty.fromList [0..]) (byCoord 1 pts)
    tag idx (p :+ e) = p :+ (idx :+ e)

    -- one sorted copy of the tagged points for every dimension
    perDim = GV.imap (\i -> S2.viewL1FromNonEmpty . byCoord (i + 1)) (pure indexed)
    size   = length $ perDim^.GV.element (C :: C 0)

    -- internal node: split dimension plus the merged bounding box of both children
    mkNode lt dim rt = Node lt (NodeData dim (bbOf lt <> bbOf rt) ()) rt
-- | Collects the pairs reported by 'findPairs' for every internal node of the
-- split tree, using @s@ as the separation parameter.
--
-- For an internal node the pairs between its left and right subtree come
-- first, followed by the pairs found recursively inside the left subtree and
-- then inside the right subtree. A leaf yields no pairs.
wellSeparatedPairs :: (Floating r, Ord r, Arity d, Arity (d + 1))
                   => r -> SplitTree d p r a -> [WSP d p r a]
wellSeparatedPairs s = go
  where
    go t = case t of
      Leaf _     -> []
      Node l _ r -> concat [findPairs s l r, go l, go r]
-- | Recursive worker behind 'fairSplitTree'.
--
-- Takes the number of points @n@ and, per dimension, the points in sorted
-- order (each tagged with an index). With at most one point the result is a
-- leaf holding the first point of the first dimension, with its index
-- dropped. Otherwise:
--
-- 1. 'assignLevels' is run in 'ST' over a fresh table of @n@ empty entries,
--    yielding the frozen table and the chain of levels it produced;
-- 2. 'distributePoints' splits the points into one group per level (one more
--    than the number stored in the newest level);
-- 3. each group is re-indexed and turned into a subtree recursively;
-- 4. the subtrees are chained with 'foldr': the last subtree is the initial
--    right child, and every earlier subtree becomes the left child of a node
--    labelled with the widest dimension recorded in its level.
fairSplitTree' :: ( Fractional r, Ord r, Arity d, 1 <= d
                  , Show r, Show p
                  )
               => Int -> GV.Vector d (PointSeq d (Idx :+ p) r)
               -> BinLeafTree Int (Point d r :+ p)
fairSplitTree' n pts
  | n <= 1    = Leaf (dropIdx onlyPt)
  | otherwise = foldr attach (V.last subtrees) (V.zip splitLevels (V.init subtrees))
  where
    -- lazily bound; only inspected in the base case
    (onlyPt S2.:< _) = pts^.GV.element (C :: C 0)

    -- level table (indexed by point index) and the chain of levels, newest first
    (levels, lvlChain@(newestLvl NonEmpty.:| _)) = runST $ do
      table <- MV.replicate n Nothing
      chain <- runReaderT (assignLevels (n `div` 2) 0 pts (Level 0 Nothing) []) table
      flip (,) chain <$> V.unsafeFreeze table

    -- the same levels, oldest first
    splitLevels = V.fromList (L.reverse (NonEmpty.toList lvlChain))

    subtrees = build <$> distributePoints (1 + newestLvl^.unLevel) levels pts
    build grp = fairSplitTree' (length $ grp^.GV.element (C :: C 0))
                               (reIndexPoints grp)

    -- hang a subtree to the left of what has been built so far
    attach (lvl, lc) rc =
      maybe (error "Unknown widest dimension") (\j -> Node lc j rc)
            (lvl^?widestDim._Just)
-- | Splits the per-dimension point sequences into @k@ groups.
--
-- Every dimension's sequence is split with 'distributePoints'', and the
-- result is then transposed, so the outer vector ranges over the groups and
-- each element holds that group's points for every dimension.
distributePoints :: (Arity d , Show r, Show p)
                 => Int -> V.Vector (Maybe Level)
                 -> GV.Vector d (PointSeq d (Idx :+ p) r)
                 -> V.Vector (GV.Vector d (PointSeq d (Idx :+ p) r))
distributePoints k levels ptsPerDim =
  transpose (distributePoints' k levels <$> ptsPerDim)
-- | Turns a fixed-size vector of vectors inside out: the i-th element of the
-- result collects the i-th entries of all input vectors. Goes through plain
-- lists and rebuilds each fixed-size vector with 'GV.vectorFromListUnsafe'.
--
-- NOTE(review): presumably all inner vectors are expected to have the same
-- length, so that every transposed column has exactly @d@ entries — confirm.
transpose :: Arity d => GV.Vector d (V.Vector a) -> V.Vector (GV.Vector d a)
transpose vs = V.fromList [ GV.vectorFromListUnsafe column
                          | column <- L.transpose rows
                          ]
  where
    rows = [ V.toList v | v <- F.toList vs ]
-- | Splits a single point sequence over @k@ buckets.
--
-- A point's bucket is the number stored in the level recorded for its index
-- in @levels@; points without a recorded level go into the last bucket
-- (@k - 1@). Points are appended to their bucket in the order in which they
-- occur in the input. Each bucket is converted with 'fromSeqUnsafe', so
-- inspecting an empty bucket raises that function's error.
distributePoints' :: Int
                  -> V.Vector (Maybe Level)
                  -> PointSeq d (Idx :+ p) r
                  -> V.Vector (PointSeq d (Idx :+ p) r)
distributePoints' k levels pts = fromSeqUnsafe <$> buckets
  where
    buckets = V.create $ do
      acc <- MV.replicate k mempty
      forM_ pts $ \p -> do
        let i = bucketOf p
        soFar <- MV.read acc i
        MV.write acc i (soFar S.|> p)
      pure acc

    -- bucket number of a point; unlabelled points end up in the last bucket
    bucketOf p = maybe (k - 1) _unLevel (levels V.! (p^.extra.core))
-- | Renumbers the indices attached to the points so that they become
-- @0, 1, 2, ...@.
--
-- The new numbering follows the order of the points in the sequence of the
-- first dimension; the same old-to-new mapping is then applied to the
-- sequences of all dimensions.
--
-- NOTE(review): the mapping is built with 'IntMap.fromAscList', which needs
-- ascending keys; this assumes the first dimension lists the points in
-- increasing index order — confirm against the callers. An index that is
-- missing from the mapping makes 'fromJust' fail.
reIndexPoints :: (Arity d, 1 <= d)
              => GV.Vector d (PointSeq d (Idx :+ p) r)
              -> GV.Vector d (PointSeq d (Idx :+ p) r)
reIndexPoints ptsV = fmap (fmap renumber) ptsV
  where
    firstDim = ptsV^.GV.element (C :: C 0)
    oldIdxs  = [ p^.extra.core | p <- F.toList firstDim ]
    newIdxOf = IntMap.fromAscList (zip oldIdxs [0..])

    renumber p = p & extra.core %~ \old -> fromJust (IntMap.lookup old newIdxOf)
-- | Monad used while assigning levels: a 'ReaderT' whose environment is a
-- mutable vector of optional 'Level's, layered on top of 'ST'.
type RST s = ReaderT (MV.MVector s (Maybe Level)) (ST s)
-- | Repeatedly splits off points and records a level for them, until at
-- least @h@ points have been labelled.
--
-- Arguments, in order: the target count @h@, the number @m@ of points
-- labelled so far, the per-dimension point sequences, the current level, and
-- the levels used in earlier rounds (most recent first).
--
-- One round:
--
-- * trims labelled points from both ends of every sequence ('compactEnds');
-- * picks the widest dimension @j@ and the mid point of its extent;
-- * splits the sequence of dimension @j@ with 'findAndCompact', keeping the
--   first component as the new sequence for that dimension;
-- * records the current level, annotated with @j@ as widest dimension, for
--   every point in the second component;
-- * recurses with the next level.
--
-- The result is the current (not yet annotated) level followed by the
-- annotated levels of all rounds, most recent first.
assignLevels :: ( Fractional r, Ord r, Arity d, KnownNat d
                , Show r, Show p
                )
             => Int
             -> Int
             -> GV.Vector d (PointSeq d (Idx :+ p) r)
             -> Level
             -> [Level]
             -> RST s (NonEmpty.NonEmpty Level)
assignLevels h m pts l prevLvls
  | m >= h    = pure (l NonEmpty.:| prevLvls)
  | otherwise = do
      trimmed <- compactEnds pts
      let dim    = widestDimension trimmed   -- 1-based dimension number
          slot   = dim - 1                   -- 0-based position in the vector
          middle = midPoint (extends trimmed ^. ix' slot)
      (kept, removed) <- findAndCompact dim (trimmed^.ix' slot) middle
      let tagged = l & widestDim .~ Just dim
      mapM_ (`assignLevel` tagged) removed
      assignLevels h (m + length removed) (trimmed & ix' slot .~ kept)
                   (nextLevel l) (tagged : prevLvls)
-- | Applies 'compactEnds'' to the point sequence of every dimension.
compactEnds :: Arity d
            => GV.Vector d (PointSeq d (Idx :+ p) r)
            -> RST s (GV.Vector d (PointSeq d (Idx :+ p) r))
compactEnds ptsPerDim = traverse compactEnds' ptsPerDim
-- | Records the given level for a point: writes @Just l@ into the level
-- table at the position given by the point's index.
assignLevel :: (c :+ (Idx :+ p)) -> Level -> RST s ()
assignLevel p l = do
  table <- ask
  lift $ MV.write table (p^.extra.core) (Just l)
-- | Reads the level-table entry stored at the point's index; 'Nothing' means
-- no level has been recorded for that index.
levelOf :: (c :+ (Idx :+ p)) -> RST s (Maybe Level)
levelOf p = do
  table <- ask
  lift $ MV.read table (p^.extra.core)
-- | Tests whether a level has already been recorded for the point, i.e.
-- whether 'levelOf' yields a 'Just'.
hasLevel :: c :+ (Idx :+ p) -> RST s Bool
hasLevel p = isJust <$> levelOf p
-- | Drops the points that already have a level from both ends of the
-- sequence.
--
-- Labelled points are first removed from the front until an unlabelled point
-- is found, then from the back in the same way. Labelled points in the
-- interior of the sequence are left untouched.
--
-- Precondition: at least one point in the sequence has no level yet. If the
-- precondition is violated, no non-empty result exists, and the function
-- fails with an explicit error naming the violated precondition (rather than
-- with an anonymous pattern-match failure).
compactEnds' :: PointSeq d (Idx :+ p) r
             -> RST s (PointSeq d (Idx :+ p) r)
compactEnds' (l0 S2.:< s0) = fmap fromSeqUnsafe . dropFront $ l0 S.<| s0
  where
    -- strip labelled points from the front, then switch to the back
    dropFront s = case S.viewl s of
      S.EmptyL  -> exhausted
      l S.:< s' -> hasLevel l >>= \case
                     False -> dropBack s
                     True  -> dropFront s'

    -- strip labelled points from the back
    dropBack s = case S.viewr s of
      S.EmptyR  -> exhausted
      s' S.:> r -> hasLevel r >>= \case
                     False -> pure s
                     True  -> dropBack s'

    exhausted = error "compactEnds': all points already have a level"
-- | Splits a point sequence around the value @m@, looking only at
-- coordinate @j@, while dropping points that already have a level.
--
-- The sequence is consumed alternately from the front and from the back.
-- Points that already have a level are skipped. An unlabelled front point
-- with coordinate @<= m@ is added to the left part and the scan continues at
-- the back; an unlabelled back point with coordinate @>= m@ is added to the
-- right part and the scan continues at the front. The scan stops at the
-- first unlabelled point that fails its test (or when nothing is left); the
-- unscanned remainder then goes to the opposite part, and the side on which
-- the scan stopped is remembered.
--
-- The result is a pair whose second component is the part on the side where
-- the scan stopped, and whose first component is the other part. Both are
-- converted with 'fromSeqUnsafe'.
findAndCompact :: ( Ord r, Arity d
                  , Show r, Show p
                  )
               => Int
               -> PointSeq d (Idx :+ p) r
               -> r
               -> RST s ( PointSeq d (Idx :+ p) r
                        , PointSeq d (Idx :+ p) r
                        )
findAndCompact j (l0 S2.:< s0) m = select <$> fromLeft (l0 S.<| s0)
  where
    coordJ q = q^.core.unsafeCoord j

    -- examine the front-most point
    fromLeft s = case S.viewl s of
      S.EmptyL  -> pure (FAC mempty mempty L)
      l S.:< s' -> do
        assigned <- hasLevel l
        case (assigned, coordJ l <= m) of
          (True , _    ) -> fromLeft s'
          (False, True ) -> addL l <$> fromRight s'
          (False, False) -> pure (FAC mempty s L)

    -- examine the back-most point
    fromRight s = case S.viewr s of
      S.EmptyR  -> pure (FAC mempty mempty R)
      s' S.:> r -> do
        assigned <- hasLevel r
        case (assigned, coordJ r >= m) of
          (True , _    ) -> fromRight s'
          (False, True ) -> addR r <$> fromLeft s'
          (False, False) -> pure (FAC s mempty R)

    addL l = leftPart  %~ (l S.<|)
    addR r = rightPart %~ (S.|> r)

    select fac = over both fromSeqUnsafe (orient fac)

    -- put the part on the side where the scan stopped second
    orient (FAC lft rgt side) = case side of
      L -> (rgt, lft)
      R -> (lft, rgt)
-- | The (1-based) number of the dimension with the largest width, as
-- computed by 'widths'. Ties are resolved the way 'L.maximumBy' resolves
-- them.
widestDimension :: (Num r, Ord r, Arity d) => GV.Vector d (PointSeq d p r) -> Int
widestDimension pts = fst (L.maximumBy (comparing snd) numbered)
  where
    numbered = zip [1..] (F.toList (widths pts))
-- | For every dimension, the width ('Range.width') of the range computed by
-- 'extends'.
widths :: (Num r, Arity d) => GV.Vector d (PointSeq d p r) -> GV.Vector d r
widths pts = Range.width <$> extends pts
-- | For every dimension, the closed range running from the relevant
-- coordinate of the first point of that dimension's sequence to the same
-- coordinate of its last point. The sequence at (0-based) position @i@ is
-- measured along coordinate @i + 1@.
--
-- NOTE(review): this is the true extent only if each sequence is sorted on
-- its own coordinate — presumably guaranteed by the callers; confirm.
extends :: Arity d => GV.Vector d (PointSeq d p r) -> GV.Vector d (Range r)
extends = GV.imap spanOf
  where
    spanOf i pts@(firstPt S2.:< _) = ClosedRange (coordOf firstPt) (coordOf lastPt)
      where
        (_ S2.:> lastPt) = S2.viewL1toR1 pts
        coordOf q        = q^.core.unsafeCoord (i + 1)
-- | Reports pairs of subtrees drawn from the two given trees.
--
-- If the two trees pass the 'areWellSeparated'' test they form the single
-- pair @(l, r)@. Otherwise the tree with the larger 'maxWidth' is replaced by
-- its children (the right tree when the widths are equal) and the search
-- continues with each child. Note that when the left tree is the one being
-- split, the recursive call takes @r@ as its first tree, so pairs found
-- below that point have their components in swapped order.
findPairs :: (Floating r, Ord r, Arity d, Arity (d + 1))
          => r -> SplitTree d p r a -> SplitTree d p r a
          -> [WSP d p r a]
findPairs s l r
  | areWellSeparated' s l r  = [(l, r)]
  | maxWidth l <= maxWidth r = [ pr | c <- children' r, pr <- findPairs s l c ]
  | otherwise                = [ pr | c <- children' l, pr <- findPairs s r c ]
-- | Separation test based on 'boxBox': two leaves always count as well
-- separated; in every other case the bounding boxes of the two trees are
-- handed to 'boxBox' together with the separation parameter @s@.
areWellSeparated :: (Arity d, Arity (d + 1), Fractional r, Ord r)
                 => r
                 -> SplitTree d p r a
                 -> SplitTree d p r a -> Bool
areWellSeparated s l r = case (l, r) of
  (Leaf _, Leaf _) -> True
  _                -> boxBox s (bbOf l) (bbOf r)
-- | Symmetric box-against-scaled-box test.
--
-- Each box in turn is scaled uniformly by the factor @s@ about its own
-- center (move the center to the origin, scale, move back). The test
-- succeeds when neither scaled box intersects the other, unscaled, box.
boxBox :: (Fractional r, Ord r, Arity d, Arity (d + 1))
       => r -> Box d p r -> Box d p r -> Bool
boxBox s lb rb = misses lb rb && misses rb lb
  where
    -- does 'other' stay clear of 'b' after 'b' is scaled about its center?
    misses other b = not (other `intersects` grown)
      where
        c     = centerPoint b ^. vector
        grown = translateBy c (scaleUniformlyBy s (translateBy ((-1) *^ c) b))
-- | Separation test based on 'boxBox1': two leaves always count as well
-- separated; in every other case the bounding boxes of the two trees are
-- handed to 'boxBox1' together with the separation parameter @s@.
areWellSeparated' :: (Floating r, Ord r, Arity d)
                  => r
                  -> SplitTree d p r a
                  -> SplitTree d p r a
                  -> Bool
areWellSeparated' s l r = case (l, r) of
  (Leaf _, Leaf _) -> True
  _                -> boxBox1 s (bbOf l) (bbOf r)
-- | Distance-based separation test on two boxes: succeeds when the Euclidean
-- distance between the two box centers is at least @(s + 1)@ times the
-- larger of the two box diagonals. A box's diagonal is measured as the
-- distance between its minimum and its maximum corner.
boxBox1 :: (Floating r, Ord r, Arity d) => r -> Box d p r -> Box d p r -> Bool
boxBox1 s lb rb = centerDist >= (s + 1) * largestDiag
  where
    centerDist  = euclideanDist (centerPoint lb) (centerPoint rb)
    diagonal b  = euclideanDist (b^.minP.core.cwMin) (b^.maxP.core.cwMax)
    largestDiag = max (diagonal lb) (diagonal rb)
-- | Width associated with a split-tree node: @0@ for a leaf; for an internal
-- node the width of its stored box in its stored dimension, obtained from
-- 'widthIn''. Fails (via 'fromJust') if 'widthIn'' yields 'Nothing'.
maxWidth :: (Arity d, KnownNat d, Num r)
         => SplitTree d p r a -> r
maxWidth t = case t of
  Leaf _                        -> 0
  Node _ (NodeData dim box _) _ -> fromJust (widthIn' dim box)
-- | Bounding box of a split tree: for a leaf, the bounding box of its single
-- point; for an internal node, the box stored in its node data.
bbOf :: Ord r => SplitTree d p r a -> Box d () r
bbOf t = case t of
  Leaf p                      -> boundingBox (p^.core)
  Node _ (NodeData _ box _) _ -> box
-- | The direct subtrees of a tree: none for a leaf, and the left followed by
-- the right subtree for an internal node.
children' :: BinLeafTree v a -> [BinLeafTree v a]
children' t = case t of
  Leaf _     -> []
  Node l _ r -> [l, r]
-- | Converts a sequence into a non-empty left view by splitting off its
-- first element. Calls 'error' when the sequence is empty.
fromSeqUnsafe :: S.Seq a -> S2.ViewL1 a
fromSeqUnsafe xs = case S.viewl xs of
  x S.:< rest -> x S2.:< rest
  S.EmptyL    -> error "fromSeqUnsafe: Empty seq"
-- | Lens onto the element at the given position of a fixed-size vector,
-- obtained by wrapping 'GV.element'' in 'singular'.
--
-- NOTE(review): presumably partial for an out-of-range position, since
-- 'singular' needs the traversal to hit a target — confirm.
ix' :: (Arity d, KnownNat d) => Int -> Lens' (GV.Vector d a) a
ix' pos = singular $ GV.element' pos
-- | Removes the extra tag that sits between the core value and the user's
-- extra data, keeping only the core and the inner extra.
dropIdx :: core :+ (t :+ extra) -> core :+ extra
dropIdx tagged = case tagged of
  c :+ (_ :+ e) -> c :+ e