{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveGeneric #-}
module Data.IP.RouteTable.Internal where
import Control.Applicative hiding (empty)
import qualified Control.Applicative as A (empty)
import Control.Monad
import Data.Bits
import Data.Foldable (Foldable(..))
import Data.IP.Addr
import Data.IP.Op
import Data.IP.Range
import Data.IntMap (IntMap, (!))
import qualified Data.IntMap as IM (fromList)
import Data.Monoid hiding ((<>))
import Data.Semigroup
import Data.Traversable
import Data.Word
import GHC.Generics (Generic, Generic1)
import Prelude hiding (lookup)
-- | Address types that can serve as keys of an 'IPRTable'.
--
-- The two instances in this module ('IPv4' and 'IPv6') implement the
-- methods as described below.
class Addr a => Routable a where
    -- | Turn a prefix length into a \"test bit\" address. In this module's
    -- instances the result is read from a precomputed table in which
    -- entry @n@ has only the @n@-th bit (counted from the most significant
    -- bit, starting at 0) set, and the entry for the full address width is
    -- all zeros.
    intToTBit :: Int -> a
    -- | @isZero a b@: in this module's instances, 'True' exactly when
    -- @a \`masked\` b@ equals the all-zero address.
    isZero :: a -> a -> Bool
-- | 'IPv4' keys: test bits come from 'intToTBitIPv4', and the zero check
-- compares the masked address against @IP4 0@.
instance Routable IPv4 where
    intToTBit = intToTBitIPv4
    isZero ip tbit = masked ip tbit == IP4 0
-- | 'IPv6' keys: test bits come from 'intToTBitIPv6', and the zero check
-- compares the masked address against @IP6 (0,0,0,0)@.
instance Routable IPv6 where
    intToTBit = intToTBitIPv6
    isZero ip tbit = masked ip tbit == IP6 (0,0,0,0)
-- | Look up the IPv4 test bit for a prefix length in 'intToTBitsIPv4' and
-- wrap it in 'IP4'. The lookup uses the partial '(!)', so a length that is
-- not a key of the table raises an error.
intToTBitIPv4 :: Int -> IPv4
intToTBitIPv4 = IP4 . (intToTBitsIPv4 !)
-- | Look up the IPv6 test bit for a prefix length in 'intToTBitsIPv6' and
-- wrap it in 'IP6'. The lookup uses the partial '(!)', so a length that is
-- not a key of the table raises an error.
intToTBitIPv6 :: Int -> IPv6
intToTBitIPv6 = IP6 . (intToTBitsIPv6 !)
-- | Infinite list of single-bit 'Word32' values: it starts at the most
-- significant bit (@0x80000000@) and each following element is the previous
-- one shifted right by one bit. After 32 elements every further element
-- is 0.
intToTBitsWord32 :: [Word32]
intToTBitsWord32 = iterate (`shiftR` 1) 0x80000000
-- | Table from prefix length (0 through 32) to the matching IPv4 test bit,
-- built by pairing the lengths with the elements of 'intToTBitsWord32'.
intToTBitsIPv4 :: IntMap IPv4Addr
intToTBitsIPv4 = IM.fromList (zip prefixLengths intToTBitsWord32)
  where
    prefixLengths = [0..32]
-- | Table from prefix length (0 through 128) to the matching IPv6 test bit.
--
-- The first 32 entries place a single bit in the first word of the 4-tuple,
-- the next 32 in the second word, and so on; the final entry (length 128)
-- is the all-zero tuple.
intToTBitsIPv6 :: IntMap IPv6Addr
intToTBitsIPv6 = IM.fromList (zip [0..128] allBits)
  where
    -- The 32 single-bit values of one 32-bit word, most significant first.
    wordBits = take 32 intToTBitsWord32
    allBits = concat
        [ [ (w, 0, 0, 0) | w <- wordBits ]
        , [ (0, w, 0, 0) | w <- wordBits ]
        , [ (0, 0, w, 0) | w <- wordBits ]
        , [ (0, 0, 0, w) | w <- wordBits ]
        , [ (0, 0, 0, 0) ]
        ]
-- | A routing table keyed by address ranges of type @'AddrRange' k@ and
-- holding values of type @a@.
--
-- 'Nil' is the empty table. A 'Node' carries, in order:
--
-- * the address range of the node;
-- * the test bit for that range (as built by 'insert' and 'link', this is
--   'keyToTestBit' of the range);
-- * an optional value ('link' creates nodes with 'Nothing' here);
-- * two subtrees; 'insert' places a range in the first subtree when
--   'isLeft' holds for it against the node's test bit, otherwise in the
--   second.
--
-- All fields are strict.
data IPRTable k a =
    Nil
  | Node !(AddrRange k) !k !(Maybe a) !(IPRTable k a) !(IPRTable k a)
  deriving (Eq, Generic, Generic1, Show)
-- | The table that contains no entries (the 'Nil' constructor).
empty :: Routable k => IPRTable k a
empty = Nil
-- | Apply a function to every stored value, leaving ranges, test bits and
-- the tree shape untouched.
instance Functor (IPRTable k) where
    fmap f = go
      where
        go Nil = Nil
        go (Node range tbit val lt rt) =
            Node range tbit (fmap f val) (go lt) (go rt)
-- | Fold over the stored values: a node's own value first, then its first
-- subtree, then its second subtree.
instance Foldable (IPRTable k) where
    foldMap f = go
      where
        go Nil = mempty
        go (Node _ _ val lt rt) =
            foldMap f val `mappend` go lt `mappend` go rt
-- | Traverse the stored values with an applicative action, visiting a
-- node's own value first, then its first subtree, then its second subtree,
-- and rebuilding a table of the same shape.
instance Traversable (IPRTable k) where
    traverse f = go
      where
        go Nil = pure Nil
        go (Node range tbit val lt rt) =
            Node range tbit <$> traverse f val <*> go lt <*> go rt
-- | Combine two tables by inserting every entry of the right-hand table
-- into the left-hand one. Because 'insert' replaces the value of an
-- existing key, entries of the right-hand table win on key collisions.
instance Routable k => Semigroup (IPRTable k a) where
    base <> extra = foldlWithKey addEntry base extra
      where
        addEntry table range val = insert range val table
    stimes = stimesIdempotent
-- | The identity element is the empty table; combination is the
-- 'Semigroup' operation.
instance Routable k => Monoid (IPRTable k a) where
    mempty = Nil
    mappend = (<>)
-- | Insert a value under an address range, replacing the value of an
-- existing node with the same range.
--
-- Walking down from the root, for each node:
--
-- * same range: the node is rebuilt with the new value, keeping its children;
-- * the node's range '>:>' the new range: recurse into the child chosen by
--   'isLeft' against the node's test bit;
-- * the new range '>:>' the node's range: a new node takes the old subtree
--   as one child, on the side chosen by 'isLeft' against the new test bit;
-- * otherwise: the new leaf and the old subtree are joined with 'link'.
insert :: (Routable k) => AddrRange k -> a -> IPRTable k a -> IPRTable k a
insert newKey newVal = go
  where
    newBit = keyToTestBit newKey
    -- The node for the new entry, still waiting for its two children.
    fresh = Node newKey newBit (Just newVal)

    go Nil = fresh Nil Nil
    go cur@(Node key tbit val lt rt)
      | newKey == key = fresh lt rt
      | key >:> newKey =
          if isLeft newKey tbit
            then Node key tbit val (go lt) rt
            else Node key tbit val lt (go rt)
      | newKey >:> key =
          if isLeft key newBit
            then fresh cur Nil
            else fresh Nil cur
      | otherwise = link (fresh Nil Nil) cur
-- | Join two non-empty tables under a newly created value-less node.
--
-- The new node's range is computed by 'glue' from the ranges of the two
-- roots and its test bit by 'keyToTestBit'. The first table becomes the
-- first child when 'isLeft' holds for its root range against that test bit;
-- otherwise the two children are swapped.
--
-- Calling this with 'Nil' on either side is a programming error and raises
-- an exception; within this module 'insert' only calls it with two 'Node's.
link :: Routable k => IPRTable k a -> IPRTable k a -> IPRTable k a
link s1@(Node k1 _ _ _ _) s2@(Node k2 _ _ _ _)
  | isLeft k1 tbg = Node kg tbg Nothing s1 s2
  | otherwise     = Node kg tbg Nothing s2 s1
  where
    kg = glue 0 k1 k2
    tbg = keyToTestBit kg
-- Name the module and the broken invariant so a failure can be traced.
link _ _ = error "Data.IP.RouteTable.Internal.link: both tables must be non-empty"
-- | Starting from mask length @n@, keep increasing the length while the
-- addresses of both ranges are equal under @'intToMask' n@. At the first
-- length where they differ, return the range made from the first address
-- and the previous length (@n - 1@).
--
-- NOTE(review): this recursion only stops if the two addresses differ under
-- some mask; 'link' appears to call it only for ranges where that holds.
glue :: (Routable k) => Int -> AddrRange k -> AddrRange k -> AddrRange k
glue len k1 k2
  | samePrefix = glue (len + 1) k1 k2
  | otherwise  = makeAddrRange (addr k1) (len - 1)
  where
    prefixMask = intToMask len
    samePrefix = (addr k1 `masked` prefixMask) == (addr k2 `masked` prefixMask)
-- | The test bit of a range: 'intToTBit' applied to the range's mask length.
keyToTestBit :: Routable k => AddrRange k -> k
keyToTestBit range = intToTBit (mlen range)
-- | Decide which side a range belongs on relative to a test bit: 'True'
-- when 'isZero' holds for the range's address and the test bit.
isLeft :: Routable k => AddrRange k -> k -> Bool
isLeft range tbit = isZero (addr range) tbit
-- | Remove the value stored under exactly the given range.
--
-- The node with that range loses its value; on the way back up, nodes are
-- rebuilt with 'node', which drops value-less nodes that have an empty
-- child. A table without the range is returned unchanged.
delete :: (Routable k) => AddrRange k -> IPRTable k a -> IPRTable k a
delete target = go
  where
    go Nil = Nil
    go cur@(Node key tbit val lt rt)
      | target == key = node key tbit Nothing lt rt
      | key >:> target =
          if isLeft target tbit
            then node key tbit val (go lt) rt
            else node key tbit val lt (go rt)
      | otherwise = cur
-- | Smart constructor used by 'delete': a node that has no value and whose
-- first child is 'Nil' collapses to its second child; one with no value
-- and whose second child is 'Nil' collapses to its first child. Anything
-- else becomes an ordinary 'Node'.
node :: (Routable k) => AddrRange k -> k -> Maybe a -> IPRTable k a -> IPRTable k a -> IPRTable k a
node range tbit val lt rt = case (val, lt, rt) of
    (Nothing, Nil, _) -> rt
    (Nothing, _, Nil) -> lt
    _                 -> Node range tbit val lt rt
-- | Find the value for a range: the value component of what 'search'
-- returns when started with no candidate, or 'Nothing' if it finds none.
lookup :: Routable k => AddrRange k -> IPRTable k a -> Maybe a
lookup range table = snd <$> search range table Nothing
-- | Like 'lookup', but also return the range of the matching entry: the
-- result of 'search' started with no candidate.
lookupKeyValue :: Routable k => AddrRange k -> IPRTable k a -> Maybe (AddrRange k, a)
lookupKeyValue range table = search range table Nothing
-- | Descend through the table looking for the given range, carrying the
-- best candidate found so far as the third argument.
--
-- At each node:
--
-- * same range: return the node's entry if it has a value, else the
--   candidate;
-- * the node's range '>:>' the target: the node's entry (if it has a value)
--   becomes the new candidate, and the search continues in the child chosen
--   by 'isLeft';
-- * otherwise, or on reaching 'Nil': return the candidate.
search :: Routable k => AddrRange k
       -> IPRTable k a
       -> Maybe (AddrRange k, a)
       -> Maybe (AddrRange k, a)
search target = go
  where
    go Nil best = best
    go (Node key tbit val lt rt) best
      | target == key  = maybe best (\v -> Just (target, v)) val
      | key >:> target = go (if isLeft target tbit then lt else rt) improved
      | otherwise      = best
      where
        -- The candidate after taking this node's value into account.
        improved = maybe best (\v -> Just (key, v)) val
-- | Collect every entry met on the path towards the given range: nodes
-- whose range equals the target or satisfies @nodeRange '>:>' target@ and
-- that carry a value.
--
-- Entries are consed onto an accumulator during the descent, so the entry
-- found deepest in the tree comes first in the result.
lookupAll :: Routable k => AddrRange k -> IPRTable k a -> [(AddrRange k, a)]
lookupAll range = go []
  where
    go acc Nil = acc
    go acc (Node key tbit val lt rt)
      | key == range  = extended
      | key >:> range = go extended (if isLeft range tbit then lt else rt)
      | otherwise     = acc
      where
        -- The accumulator with this node's entry added, if it has a value.
        extended = maybe acc (\v -> (key, v) : acc) val
-- | Produce, in any 'Alternative', the entries whose range satisfies
-- @target '>:>' entryRange@.
--
-- A node whose range is related to the target by '>:>' in either direction
-- has both of its children searched; a node with a value whose range
-- satisfies @target '>:>' nodeRange@ contributes its own entry before its
-- children. Unrelated nodes and 'Nil' contribute nothing.
findMatch :: Alternative m => Routable k => AddrRange k -> IPRTable k a -> m (AddrRange k, a)
findMatch target = go
  where
    go Nil = A.empty
    go (Node key _ val lt rt)
      | target >:> key = case val of
          Nothing -> go lt <|> go rt
          Just v  -> pure (key, v) <|> go lt <|> go rt
      | key >:> target = go lt <|> go rt
      | otherwise = A.empty
-- | Build a table by inserting the pairs of a list from left to right,
-- starting from the empty table. A later pair with the same range
-- overrides an earlier one.
fromList :: Routable k => [(AddrRange k, a)] -> IPRTable k a
fromList = foldl' step empty
  where
    step table (range, val) = insert range val table
-- | List every (range, value) entry of the table, using 'foldt' to visit
-- the nodes. Nodes without a value are skipped. Each entry is consed onto
-- the accumulator, so nodes visited later by 'foldt' appear earlier in the
-- result.
toList :: Routable k => IPRTable k a -> [(AddrRange k, a)]
toList = foldt collect []
  where
    collect (Node range _ (Just val) _ _) acc = (range, val) : acc
    collect _ acc = acc
-- | Fold over the nodes (as whole subtrees) of a table. The function is
-- applied to a node first, then the fold continues through the node's
-- first child and finally its second child. 'Nil' leaves the accumulator
-- unchanged.
foldt :: (IPRTable k a -> b -> b) -> b -> IPRTable k a -> b
foldt visit = go
  where
    go acc Nil = acc
    go acc t@(Node _ _ _ lt rt) = go (go (visit t acc) lt) rt
-- | Left fold over the entries together with their ranges. For each node
-- the first subtree is folded first, then the node's own entry (if it has a
-- value), then the second subtree.
foldlWithKey :: (b -> AddrRange k -> a -> b) -> b -> IPRTable k a -> b
foldlWithKey f zr = walk zr
  where
    walk acc Nil = acc
    walk acc (Node range _ val lt rt) = walk (visit (walk acc lt)) rt
      where
        -- Feed this node's entry to @f@ when it carries a value.
        visit z = maybe z (f z range) val
{-# INLINE foldlWithKey #-}
-- | Right fold over the entries together with their ranges. For each node
-- the second subtree is folded first, then the node's own entry (if it has
-- a value), then the first subtree.
foldrWithKey :: (AddrRange k -> a -> b -> b) -> b -> IPRTable k a -> b
foldrWithKey f zr = walk zr
  where
    walk acc Nil = acc
    walk acc (Node range _ val lt rt) = walk (visit (walk acc rt)) lt
      where
        -- Feed this node's entry to @f@ when it carries a value.
        visit z = maybe z (\v -> f range v z) val
{-# INLINE foldrWithKey #-}