{-# LANGUAGE RecordWildCards #-}

module Hans.IP4.RoutingTable (
    Route(..), RouteType(..),
    routeSource, routeNextHop,
    RoutingTable,
    empty,
    addRule,
    deleteRule,
    lookupRoute,
    isLocal,
    getRoutes,
    routesForDev,
  ) where

import Hans.Device.Types (Device)
import Hans.IP4.Packet

import Control.Monad (guard)
import Data.Bits ((.&.))
import Data.List (insertBy)
import Data.Maybe (mapMaybe)
import Data.Word (Word32)


-- | How traffic matching a route is forwarded.  'routeNextHop' turns this
-- into a concrete next-hop address.
data RouteType = Direct
                 -- ^ The next hop is the destination address itself.
               | Indirect !IP4
                 -- ^ The next hop is the carried address.
               | Loopback
                 -- ^ The next hop is the route's own source address
                 -- (see 'routeSource').

-- | A single routing entry: the network it covers, how traffic for that
-- network is forwarded, and the 'Device' attached to the route.
data Route = Route { routeNetwork :: {-# UNPACK #-} !IP4Mask
                     -- ^ Network address and prefix length.  'routeSource'
                     -- returns the address part.
                   , routeType    ::                !RouteType
                     -- ^ Forwarding behaviour; interpreted by 'routeNextHop'.
                   , routeDevice  ::                !Device
                     -- ^ Device recorded for this route; 'routesForDev'
                     -- filters on it.
                   }

-- | The address component of a route's network mask.  'routeNextHop' uses it
-- as the next hop for 'Loopback' routes.
routeSource :: Route -> IP4
routeSource route =
  case routeNetwork route of
    IP4Mask addr _ -> addr

-- | Choose the address a packet bound for @dest@ should be handed to when it
-- is sent via the given route.
routeNextHop :: IP4 -> Route -> IP4
routeNextHop dest route = viaType (routeType route)
  where
  -- Direct routes deliver straight to the destination.
  viaType Direct             = dest
  -- Indirect routes go through the address stored in the route type.
  viaType (Indirect gateway) = gateway
  -- Loopback routes use the route's own source address.
  viaType Loopback           = routeSource route


-- | A 'Route' paired with values precomputed by 'mkRule', so that matching a
-- destination in 'routesTo' is a single mask-and-compare.
data Rule = Rule { ruleMask   :: {-# UNPACK #-} !Word32
                   -- ^ Netmask derived from the route's prefix length.
                 , rulePrefix :: {-# UNPACK #-} !Word32
                   -- ^ The route's network address with 'ruleMask' applied.
                 , ruleRoute  ::                !Route
                   -- ^ The route this rule was built from.
                 }

-- | Prefix length of the network covered by the rule's route.  'addRule'
-- orders the rule list by this value.
ruleMaskLen :: Rule -> Int
ruleMaskLen = maskBits . routeNetwork . ruleRoute

-- | Address part of the network mask of the rule's route.
ruleSource :: Rule -> IP4
ruleSource = maskAddr . routeNetwork . ruleRoute

-- | The 'Device' recorded on the rule's route.
ruleDevice :: Rule -> Device
ruleDevice Rule { ruleRoute = route } = routeDevice route

-- | Build a 'Rule' from a 'Route', precomputing the netmask for the route's
-- prefix length and the network address with that mask applied.
mkRule :: Route -> Rule
mkRule route =
  Rule { ruleMask   = mask
       , rulePrefix = mask .&. network
       , ruleRoute  = route
       }
  where
  IP4Mask (IP4 network) prefixLen = routeNetwork route

  mask                            = netmask prefixLen

-- | Does the address fall inside the rule's network?  True when the address,
-- masked with the rule's netmask, equals the rule's stored prefix.
routesTo :: Rule -> IP4 -> Bool
routesTo rule (IP4 addr) = masked == rulePrefix rule
  where
  masked = addr .&. ruleMask rule

-- | Simple routing table: a list of rules that is searched in order, plus an
-- optional default route that 'lookupRoute' falls back to when no rule
-- matches.
data RoutingTable = RoutingTable { rtRules :: [Rule]
                                   -- ^ Insertions must keep this list ordered
                                   -- by the network prefix length, descending,
                                   -- so that the first matching rule found by
                                   -- 'lookupRoute' has the longest prefix.

                                 , rtDefault :: !(Maybe Route)
                                   -- ^ Optional default route, returned by
                                   -- 'lookupRoute' when no rule matches.
                                 }

-- | A routing table with no rules and no default route.
empty :: RoutingTable
empty =
  RoutingTable
    { rtRules   = []
    , rtDefault = Nothing
    }

-- | The route of every rule in the table, in rule order.
getRoutes :: RoutingTable -> [Route]
getRoutes = map ruleRoute . rtRules

-- | Add a route to the table.
--
-- The route is always inserted into the rule list, keeping the list ordered
-- by prefix length, longest first.  When the flag is 'True' the route also
-- becomes the table's default route, replacing any previous default; the
-- rule of a previous default stays in the list.
addRule :: Bool -> Route -> RoutingTable -> RoutingTable
addRule isDefault route RoutingTable { rtRules = rules, rtDefault = oldDefault } =
  newRule `seq`
    RoutingTable { rtRules   = insertBy longestFirst newRule rules
                 , rtDefault = newDefault
                 }

  where

  newRule = mkRule route

  newDefault
    | isDefault = Just route
    | otherwise = oldDefault

  -- arguments are swapped in the comparison to get descending order
  longestFirst x y = compare (ruleMaskLen y) (ruleMaskLen x)

-- | Remove every rule whose route network equals @mask@.  If the default
-- route is for that same network, it is cleared as well.
deleteRule :: IP4Mask -> RoutingTable -> RoutingTable
deleteRule mask RoutingTable { rtRules = rules, rtDefault = oldDefault } =
  kept `seq` newDefault `seq`
    RoutingTable { rtRules = kept, rtDefault = newDefault }

  where

  -- list monad: a rule survives only when its network differs from the mask
  kept = rules >>= \rule ->
    guard (routeNetwork (ruleRoute rule) /= mask) >> [rule]

  newDefault
    | Just route <- oldDefault, routeNetwork route == mask = Nothing
    | otherwise                                            = oldDefault

-- | Find the route for a destination address.  Rules are tried in list order
-- and the first one that matches wins; when none matches, the table's default
-- route (if any) is returned.
lookupRoute :: IP4 -> RoutingTable -> Maybe Route
lookupRoute dest table = search (rtRules table)
  where
  search []            = rtDefault table
  search (rule : rest)
    | rule `routesTo` dest = Just (ruleRoute rule)
    | otherwise            = search rest

-- | If the address given is the source address of a rule in the table, return
-- the 'Route' of the first such rule; otherwise 'Nothing'.  The default route
-- is not consulted separately.
isLocal :: IP4 -> RoutingTable -> Maybe Route
isLocal addr table =
  case [ ruleRoute rule | rule <- rtRules table, ruleSource rule == addr ] of
    route : _ -> Just route
    []        -> Nothing

-- | All routes in the table whose device equals @dev@, in rule order.
routesForDev :: Device -> RoutingTable -> [Route]
routesForDev dev table = mapMaybe routeIfOnDev (rtRules table)
  where
  -- 'Maybe' monad: yields the rule's route only when the device matches
  routeIfOnDev rule = do
    guard (ruleDevice rule == dev)
    return (ruleRoute rule)