module Hans.Message.Tcp where
import Hans.Address.IP4 (IP4)
import Hans.Message.Ip4 (mkIP4PseudoHeader,IP4Protocol(..))
import Hans.Utils (chunk)
import Hans.Utils.Checksum
import Control.Monad (unless,ap,replicateM_,replicateM)
import Data.Bits ((.&.),setBit,testBit,shiftL,shiftR)
import Data.List (find)
import Data.Monoid (Monoid(..))
import Data.Serialize
(Get,Put,Putter,getWord16be,putWord16be,getWord32be,putWord32be,getWord8
,putWord8,putByteString,getBytes,remaining,label,isolate,skip,runGet,runPut
,putLazyByteString)
import Data.Word (Word8,Word16,Word32)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString as S
import qualified Data.Foldable as F
-- | The IP4 protocol number under which TCP segments are carried (6).
tcpProtocol :: IP4Protocol
tcpProtocol = IP4Protocol 6
-- | A TCP port number: a thin wrapper around the raw 16-bit value.
newtype TcpPort = TcpPort
  { getPort :: Word16 -- ^ The wrapped 16-bit port value.
  } deriving (Eq,Ord,Read,Show,Num,Enum,Bounded)
-- | Serialise a port as a big-endian 16-bit word.
putTcpPort :: Putter TcpPort
putTcpPort = putWord16be . getPort
-- | Parse a port from a big-endian 16-bit word.
getTcpPort :: Get TcpPort
getTcpPort = fmap TcpPort getWord16be
-- | A TCP sequence number: a thin wrapper around the raw 32-bit value.  All
-- numeric instances are derived from the underlying 'Word32'.
newtype TcpSeqNum = TcpSeqNum
  { getSeqNum :: Word32 -- ^ The wrapped 32-bit sequence number.
  } deriving (Eq,Ord,Show,Num,Bounded,Enum,Real,Integral)
-- | Sequence numbers form a monoid under addition, with zero as the identity.
instance Monoid TcpSeqNum where
  mempty      = TcpSeqNum 0
  mappend a b = a + b
-- | Serialise a sequence number as a big-endian 32-bit word.
putTcpSeqNum :: Putter TcpSeqNum
putTcpSeqNum = putWord32be . getSeqNum
-- | Parse a sequence number from a big-endian 32-bit word.
getTcpSeqNum :: Get TcpSeqNum
getTcpSeqNum = fmap TcpSeqNum getWord32be
-- | Acknowledgement numbers share the representation of sequence numbers.
type TcpAckNum = TcpSeqNum

-- | Serialise an acknowledgement number; same encoding as 'putTcpSeqNum'.
putTcpAckNum :: Putter TcpAckNum
putTcpAckNum ackNum = putTcpSeqNum ackNum

-- | Parse an acknowledgement number; same encoding as 'getTcpSeqNum'.
getTcpAckNum :: Get TcpAckNum
getTcpAckNum = getTcpSeqNum
-- | The parsed form of a TCP header.  The eight 'Bool' fields correspond to
-- the bits of the control byte read and written by 'setTcpControl' and
-- 'putTcpControl'.
data TcpHeader = TcpHeader
  { tcpSourcePort    :: !TcpPort     -- ^ Source port.
  , tcpDestPort      :: !TcpPort     -- ^ Destination port.
  , tcpSeqNum        :: !TcpSeqNum   -- ^ Sequence number.
  , tcpAckNum        :: !TcpAckNum   -- ^ Acknowledgement number.
  , tcpCwr           :: !Bool        -- ^ Control bit 7.
  , tcpEce           :: !Bool        -- ^ Control bit 6.
  , tcpUrg           :: !Bool        -- ^ Control bit 5.
  , tcpAck           :: !Bool        -- ^ Control bit 4.
  , tcpPsh           :: !Bool        -- ^ Control bit 3.
  , tcpRst           :: !Bool        -- ^ Control bit 2.
  , tcpSyn           :: !Bool        -- ^ Control bit 1.
  , tcpFin           :: !Bool        -- ^ Control bit 0.
  , tcpWindow        :: !Word16      -- ^ Window field.
  , tcpChecksum      :: !Word16      -- ^ Checksum as parsed; 'putTcpHeader'
                                     --   always renders this field as zero.
  , tcpUrgentPointer :: !Word16      -- ^ Urgent pointer field.
  , tcpOptions       :: [TcpOption]  -- ^ Header options (the only lazy field).
  } deriving (Eq,Show)
-- | Option lookup and update on a header delegate to its 'tcpOptions' list.
instance HasTcpOptions TcpHeader where
  findTcpOption tag = findTcpOption tag . tcpOptions

  setTcpOption opt hdr = hdr { tcpOptions = updated }
    where
      updated = setTcpOption opt (tcpOptions hdr)
-- | A header with every numeric field zero, every control flag cleared and no
-- options.  Intended as a base value for record updates.
emptyTcpHeader :: TcpHeader
emptyTcpHeader = TcpHeader
  { tcpSourcePort    = 0
  , tcpDestPort      = 0
  , tcpSeqNum        = TcpSeqNum 0
  , tcpAckNum        = TcpSeqNum 0
  , tcpCwr           = False
  , tcpEce           = False
  , tcpUrg           = False
  , tcpAck           = False
  , tcpPsh           = False
  , tcpRst           = False
  , tcpSyn           = False
  , tcpFin           = False
  , tcpWindow        = 0
  , tcpChecksum      = 0
  , tcpUrgentPointer = 0
  , tcpOptions       = []
  }
-- | Length of the option-free part of the TCP header, counted in 4-byte
-- words (so 5 words = 20 bytes).
tcpFixedHeaderLength :: Int
tcpFixedHeaderLength = 5
-- | Render a TCP header, including its options and any trailing padding.
--
-- The checksum field is always written as zero, regardless of 'tcpChecksum';
-- 'computeTcpChecksumIP4' fills it in afterwards.  Options are padded up to a
-- 4-byte boundary with end-of-options tags.
putTcpHeader :: Putter TcpHeader
putTcpHeader hdr = do
  putTcpPort   (tcpSourcePort hdr)
  putTcpPort   (tcpDestPort hdr)
  putTcpSeqNum (tcpSeqNum hdr)
  putTcpAckNum (tcpAckNum hdr)
  putWord8 offsetByte
  putTcpControl hdr
  putWord16be (tcpWindow hdr)
  putWord16be 0 -- checksum placeholder
  putWord16be (tcpUrgentPointer hdr)
  mapM_ putTcpOption opts
  replicateM_ padBytes (putTcpOptionTag OptTagEndOfOptions)
  where
    opts                = tcpOptions hdr
    (optWords,padBytes) = tcpOptionsLength opts
    -- Header length in 4-byte words, stored in the upper four bits.
    offsetByte =
      fromIntegral ((tcpFixedHeaderLength + optWords) `shiftL` 4)
-- | Parse a TCP header.
--
-- Returns the header together with its total length in bytes (the data
-- offset field multiplied by four).  Parsing fails when the data offset is
-- smaller than 'tcpFixedHeaderLength', or when the options do not parse
-- within the space the data offset allots to them.  End-of-options markers
-- are dropped from the returned option list.
getTcpHeader :: Get (TcpHeader,Int)
getTcpHeader = label "TcpHeader" $ do
  src    <- getTcpPort
  dst    <- getTcpPort
  seqNum <- getTcpSeqNum
  ackNum <- getTcpAckNum
  b      <- getWord8
  -- The data offset (header length in 4-byte words) is the upper nibble.
  let len = fromIntegral ((b `shiftR` 4) .&. 0xf)
  cont   <- getWord8
  win    <- getWord16be
  cs     <- getWord16be
  urgent <- getWord16be
  -- A data offset below the fixed header size would yield a negative option
  -- length; reject it explicitly.
  unless (len >= tcpFixedHeaderLength)
    (fail ("Invalid data offset: " ++ show len))
  let optsLen = len - tcpFixedHeaderLength
  opts   <- label "options" (isolate (optsLen `shiftL` 2) getTcpOptions)
  let hdr = setTcpControl cont emptyTcpHeader
        { tcpSourcePort    = src
        , tcpDestPort      = dst
        , tcpSeqNum        = seqNum
        , tcpAckNum        = ackNum
        , tcpWindow        = win
        , tcpChecksum      = cs
        , tcpUrgentPointer = urgent
        , tcpOptions       = filter (/= OptEndOfOptions) opts
        }
  return (hdr,len * 4)
-- | Render the control byte of a header: one bit per flag, from 'tcpCwr' in
-- bit 7 down to 'tcpFin' in bit 0.
putTcpControl :: Putter TcpHeader
putTcpControl hdr = putWord8 (foldr applyFlag 0 flagBits)
  where
    -- Bit position paired with the projection that decides whether it is set.
    flagBits =
      [ (7,tcpCwr), (6,tcpEce), (5,tcpUrg), (4,tcpAck)
      , (3,tcpPsh), (2,tcpRst), (1,tcpSyn), (0,tcpFin)
      ]

    applyFlag (bitIx,isSet) acc
      | isSet hdr = setBit acc bitIx
      | otherwise = acc
-- | Overwrite all eight control flags of a header from a control byte; the
-- bit layout is the inverse of 'putTcpControl'.
setTcpControl :: Word8 -> TcpHeader -> TcpHeader
setTcpControl w hdr = hdr
  { tcpCwr = flag 7
  , tcpEce = flag 6
  , tcpUrg = flag 5
  , tcpAck = flag 4
  , tcpPsh = flag 3
  , tcpRst = flag 2
  , tcpSyn = flag 1
  , tcpFin = flag 0
  }
  where
    flag = testBit w
-- | Containers of TCP options that support lookup and update keyed by the
-- option's tag.
class HasTcpOptions a where
  -- | Look up an option by its tag.
  findTcpOption :: TcpOptionTag -> a -> Maybe TcpOption

  -- | Store an option in the container.
  setTcpOption :: TcpOption -> a -> a
-- | Set several options at once.  The list is folded from the right, so its
-- last element is applied first.
setTcpOptions :: HasTcpOptions a => [TcpOption] -> a -> a
setTcpOptions = flip (foldr setTcpOption)
-- | The kind of a TCP option, without its payload.  The numeric codes noted
-- on each constructor are the ones used by 'getTcpOptionTag' and
-- 'putTcpOptionTag'.
data TcpOptionTag
  = OptTagEndOfOptions     -- ^ Code 0.
  | OptTagNoOption         -- ^ Code 1.
  | OptTagMaxSegmentSize   -- ^ Code 2.
  | OptTagWindowScaling    -- ^ Code 3.
  | OptTagSackPermitted    -- ^ Code 4.
  | OptTagSack             -- ^ Code 5.
  | OptTagTimestamp        -- ^ Code 8.
  | OptTagUnknown !Word8   -- ^ Any other code, kept verbatim.
    deriving (Eq,Show)
-- | Parse a one-byte option kind.  Codes without a dedicated constructor are
-- preserved in 'OptTagUnknown'.  The result is forced before being returned.
getTcpOptionTag :: Get TcpOptionTag
getTcpOptionTag = do
  ty <- getWord8
  return $! decodeTag ty
  where
    decodeTag :: Word8 -> TcpOptionTag
    decodeTag 0  = OptTagEndOfOptions
    decodeTag 1  = OptTagNoOption
    decodeTag 2  = OptTagMaxSegmentSize
    decodeTag 3  = OptTagWindowScaling
    decodeTag 4  = OptTagSackPermitted
    decodeTag 5  = OptTagSack
    decodeTag 8  = OptTagTimestamp
    decodeTag ty = OptTagUnknown ty
-- | Render an option kind as its one-byte code; the inverse of
-- 'getTcpOptionTag'.
putTcpOptionTag :: Putter TcpOptionTag
putTcpOptionTag = putWord8 . tagCode
  where
    tagCode OptTagEndOfOptions   = 0
    tagCode OptTagNoOption       = 1
    tagCode OptTagMaxSegmentSize = 2
    tagCode OptTagWindowScaling  = 3
    tagCode OptTagSackPermitted  = 4
    tagCode OptTagSack           = 5
    tagCode OptTagTimestamp      = 8
    tagCode (OptTagUnknown ty)   = ty
-- | A plain list of options.  'findTcpOption' returns the first option with
-- the requested tag.  'setTcpOption' replaces the first option that has the
-- same tag as the new one, or appends the new option if none matches.
instance HasTcpOptions [TcpOption] where
  findTcpOption tag = find ((tag ==) . tcpOptionTag)

  setTcpOption opt = go
    where
      wanted = tcpOptionTag opt

      go opts = case opts of
        []                           -> [opt]
        o : rest
          | tcpOptionTag o == wanted -> opt : rest
          | otherwise                -> o : go rest
-- | A TCP option together with its payload.
data TcpOption
  = OptEndOfOptions
    -- ^ Terminates option parsing; also emitted as padding by 'putTcpHeader'.
  | OptNoOption
    -- ^ A single-byte option without payload.
  | OptMaxSegmentSize !Word16
    -- ^ 16-bit payload.
  | OptWindowScaling !Word8
    -- ^ 8-bit payload.
  | OptSackPermitted
    -- ^ No payload beyond the length byte.
  | OptSack [SackBlock]
    -- ^ Zero or more blocks of eight bytes each.
  | OptTimestamp !Word32 !Word32
    -- ^ Two 32-bit words, in wire order.
  | OptUnknown !Word8 !Word8 !S.ByteString
    -- ^ Unrecognised option: kind code, length byte as read, payload bytes.
    deriving (Show,Eq)
-- | One block of a SACK option: a pair of sequence numbers.
data SackBlock = SackBlock
  { sbLeft  :: !TcpSeqNum -- ^ First sequence number on the wire.
  , sbRight :: !TcpSeqNum -- ^ Second sequence number on the wire.
  } deriving (Show,Eq)
-- | The tag that identifies the kind of an option, ignoring its payload.
tcpOptionTag :: TcpOption -> TcpOptionTag
tcpOptionTag OptEndOfOptions     = OptTagEndOfOptions
tcpOptionTag OptNoOption         = OptTagNoOption
tcpOptionTag OptMaxSegmentSize{} = OptTagMaxSegmentSize
tcpOptionTag OptWindowScaling{}  = OptTagWindowScaling
tcpOptionTag OptSackPermitted    = OptTagSackPermitted
tcpOptionTag OptSack{}           = OptTagSack
tcpOptionTag OptTimestamp{}      = OptTagTimestamp
tcpOptionTag (OptUnknown ty _ _) = OptTagUnknown ty
-- | Compute the space taken by a list of options once rendered.
--
-- Returns @(words, padding)@: the number of 4-byte words the options occupy
-- after rounding up, and the number of padding bytes needed to reach that
-- boundary (0 when the options already end on a word boundary).
tcpOptionsLength :: [TcpOption] -> (Int,Int)
tcpOptionsLength opts
  | left == 0 = (len,0)
  | otherwise = (len + 1,4 - left)
  where
    -- Whole words, and the leftover bytes spilling into a partial word.
    (len,left) = F.sum (fmap tcpOptionLength opts) `quotRem` 4
-- | Rendered size of a single option in bytes, counting its kind and length
-- bytes.  For 'OptUnknown' the stored length byte is trusted as-is.
tcpOptionLength :: TcpOption -> Int
tcpOptionLength OptEndOfOptions      = 1
tcpOptionLength OptNoOption          = 1
tcpOptionLength OptMaxSegmentSize{}  = 4
tcpOptionLength OptWindowScaling{}   = 3
tcpOptionLength OptSackPermitted     = 2
tcpOptionLength (OptSack bs)         = sackLength bs
tcpOptionLength OptTimestamp{}       = 10
tcpOptionLength (OptUnknown _ len _) = fromIntegral len
-- | Render one option: its kind byte, then the length and payload for kinds
-- that have them.
putTcpOption :: Putter TcpOption
putTcpOption opt = putTcpOptionTag (tcpOptionTag opt) >> putPayload
  where
    putPayload = case opt of
      OptEndOfOptions       -> return ()
      OptNoOption           -> return ()
      OptMaxSegmentSize mss -> putMaxSegmentSize mss
      OptWindowScaling w    -> putWindowScaling w
      OptSackPermitted      -> putSackPermitted
      OptSack bs            -> putSack bs
      OptTimestamp v r      -> putTimestamp v r
      OptUnknown _ len bs   -> putUnknown len bs
-- | Parse options until the input is exhausted or an end-of-options marker is
-- read.  On end-of-options the rest of the input is skipped, and the marker
-- itself is not included in the result.
getTcpOptions :: Get [TcpOption]
getTcpOptions = label "Tcp Options" go
  where
    go = do
      left <- remaining
      if left > 0 then next else return []

    next = getTcpOption >>= \opt -> case opt of
      OptEndOfOptions -> remaining >>= skip >> return []
      _               -> fmap (opt :) go
-- | Parse a single option: read its kind byte, then hand over to the parser
-- for that kind.
getTcpOption :: Get TcpOption
getTcpOption = getTcpOptionTag >>= parseBody
  where
    parseBody tag = case tag of
      OptTagEndOfOptions   -> return OptEndOfOptions
      OptTagNoOption       -> return OptNoOption
      OptTagMaxSegmentSize -> getMaxSegmentSize
      OptTagWindowScaling  -> getWindowScaling
      OptTagSackPermitted  -> getSackPermitted
      OptTagSack           -> getSack
      OptTagTimestamp      -> getTimestamp
      OptTagUnknown ty     -> getUnknown ty
-- | Parse the remainder of a max-segment-size option (the kind byte has
-- already been consumed): a length byte that must be 4, then a 16-bit value.
getMaxSegmentSize :: Get TcpOption
getMaxSegmentSize = label "Max Segment Size" (isolate 3 body)
  where
    body = do
      len <- getWord8
      unless (len == 4) (fail ("Unexpected length: " ++ show len))
      fmap OptMaxSegmentSize getWord16be
-- | Render the length byte (4) and 16-bit value of a max-segment-size option.
-- The kind byte is written by the caller.
putMaxSegmentSize :: Putter Word16
putMaxSegmentSize w16 = putWord8 4 >> putWord16be w16
-- | Parse the remainder of a SACK-permitted option (the kind byte has already
-- been consumed): a single length byte that must be 2.
getSackPermitted :: Get TcpOption
getSackPermitted = label "Sack Permitted" (isolate 1 body)
  where
    body = do
      len <- getWord8
      unless (len == 2) (fail ("Unexpected length: " ++ show len))
      return OptSackPermitted
-- | Render the length byte (2) of a SACK-permitted option.  The kind byte is
-- written by the caller.
putSackPermitted :: Put
putSackPermitted = putWord8 2
-- | Parse the remainder of a SACK option (the kind byte has already been
-- consumed): a length byte, then as many 8-byte blocks as fit in the
-- remaining @length - 2@ bytes.
--
-- Fails when the length byte is below 2, since it has to cover at least the
-- kind and length bytes.  A payload that is not a multiple of eight bytes
-- also fails, because the isolated region is then not fully consumed.
getSack :: Get TcpOption
getSack = label "Sack" $ do
  len <- getWord8
  unless (len >= 2) (fail ("Unexpected length: " ++ show len))
  -- Bytes left for the blocks once the kind and length bytes are discounted.
  let edgeLen = fromIntegral len - 2
  OptSack `fmap` isolate edgeLen (replicateM (edgeLen `shiftR` 3) getSackBlock)
-- | Render the length byte and blocks of a SACK option.  The kind byte is
-- written by the caller; the length byte is 'sackLength' narrowed to a byte.
putSack :: Putter [SackBlock]
putSack bs = putWord8 lenByte >> mapM_ putSackBlock bs
  where
    lenByte = fromIntegral (sackLength bs)
-- | Parse one SACK block: two consecutive sequence numbers, 'sbLeft' first.
-- The block is forced before being returned.
getSackBlock :: Get SackBlock
getSackBlock = do
  leftEdge  <- getTcpSeqNum
  rightEdge <- getTcpSeqNum
  return $! SackBlock leftEdge rightEdge
-- | Render one SACK block: 'sbLeft' followed by 'sbRight'.
putSackBlock :: Putter SackBlock
putSackBlock (SackBlock leftEdge rightEdge) =
  putTcpSeqNum leftEdge >> putTcpSeqNum rightEdge
-- | Rendered size in bytes of a SACK option: two bytes for kind and length,
-- plus eight bytes per block.
sackLength :: [SackBlock] -> Int
sackLength bs = 2 + 8 * length bs
-- | Parse the remainder of a window-scaling option (the kind byte has already
-- been consumed): a length byte that must be 3, then a one-byte value.
getWindowScaling :: Get TcpOption
getWindowScaling = label "Window Scaling" (isolate 2 body)
  where
    body = do
      len <- getWord8
      unless (len == 3) (fail ("Unexpected length: " ++ show len))
      fmap OptWindowScaling getWord8
-- | Render the length byte (3) and one-byte value of a window-scaling option.
-- The kind byte is written by the caller.
putWindowScaling :: Putter Word8
putWindowScaling w = putWord8 3 >> putWord8 w
-- | Parse the remainder of a timestamp option (the kind byte has already been
-- consumed): a length byte that must be 10, then two 32-bit words.
getTimestamp :: Get TcpOption
getTimestamp = label "Timestamp" (isolate 9 body)
  where
    body = do
      len <- getWord8
      unless (len == 10) (fail ("Unexpected length: " ++ show len))
      firstWord  <- getWord32be
      secondWord <- getWord32be
      return (OptTimestamp firstWord secondWord)
-- | Render the length byte (10) and the two 32-bit words of a timestamp
-- option, in argument order.  The kind byte is written by the caller.
putTimestamp :: Word32 -> Word32 -> Put
putTimestamp v r = putWord8 10 >> putWord32be v >> putWord32be r
-- | Parse the remainder of an unrecognised option (the kind byte @ty@ has
-- already been consumed): a length byte, then @length - 2@ payload bytes.
--
-- Fails when the length byte is below 2, since it has to cover at least the
-- kind and length bytes.  The length byte is stored unchanged in the result.
getUnknown :: Word8 -> Get TcpOption
getUnknown ty = do
  len  <- getWord8
  unless (len >= 2) (fail ("Unexpected length: " ++ show len))
  body <- isolate (fromIntegral len - 2) (getBytes =<< remaining)
  return (OptUnknown ty len body)
-- | Render the length byte and payload of an unrecognised option.  The kind
-- byte is written by the caller; the length byte is emitted as given and is
-- not checked against the payload size.
putUnknown :: Word8 -> S.ByteString -> Put
putUnknown len body = putWord8 len >> putByteString body
-- | Parse a TCP segment into its header and payload, or return the parser's
-- error message.
parseTcpPacket :: S.ByteString -> Either String (TcpHeader,S.ByteString)
parseTcpPacket = runGet getTcpPacket
-- | Parse a TCP segment: the header, then every byte that follows it as the
-- payload.
getTcpPacket :: Get (TcpHeader,S.ByteString)
getTcpPacket = do
  pktLen       <- remaining
  (hdr,hdrLen) <- getTcpHeader
  -- The payload is whatever is left once the header bytes are discounted.
  body         <- getBytes (pktLen - hdrLen)
  return (hdr,body)
-- | Render a header followed by a lazy payload.  As with 'putTcpHeader', the
-- checksum field is written as zero.
putTcpPacket :: TcpHeader -> L.ByteString -> Put
putTcpPacket hdr body = putTcpHeader hdr >> putLazyByteString body
-- | Render a complete TCP segment for the given source and destination
-- addresses: the header bytes produced by 'computeTcpChecksumIP4', followed
-- by the payload.
renderWithTcpChecksumIP4 :: IP4 -> IP4 -> TcpHeader -> L.ByteString
                         -> L.ByteString
renderWithTcpChecksumIP4 src dst hdr body = L.append (chunk hdrBytes) body
  where
    hdrBytes = fst (computeTcpChecksumIP4 src dst hdr body)
-- | Compute the checksum of a TCP segment over the IP4 pseudo header, the
-- rendered header (with its checksum field zeroed) and the payload.
--
-- Returns the rendered header bytes after 'pokeChecksum' has been applied at
-- byte offset 16 (where 'putTcpHeader' writes its checksum placeholder),
-- paired with the checksum itself.
--
-- NOTE(review): 'pokeChecksum' runs under 'unsafePerformIO'; presumably it
-- writes the checksum into the header buffer -- confirm in
-- "Hans.Utils.Checksum".  The checksum is forced first so it is fully
-- computed from the zeroed header.
computeTcpChecksumIP4 :: IP4 -> IP4 -> TcpHeader -> L.ByteString
                      -> (S.ByteString,Word16)
computeTcpChecksumIP4 src dst hdr body =
  (cs `seq` unsafePerformIO (pokeChecksum cs hdrbs 16), cs)
  where
    hdrbs     = runPut (putTcpHeader hdr { tcpChecksum = 0 })
    -- Total segment length (header plus payload) for the pseudo header.
    segLen    = S.length hdrbs + fromIntegral (L.length body)
    pseudoSum = computePartialChecksum emptyPartialChecksum
                  (mkIP4PseudoHeader src dst tcpProtocol segLen)
    headerSum = computePartialChecksum pseudoSum hdrbs
    cs        = finalizeChecksum (computePartialChecksumLazy headerSum body)
-- | Check the checksum of a received TCP segment.  The checksum is computed
-- over the IP4 pseudo header and the whole segment, including the checksum
-- field it carries; the segment is accepted when the result is zero.
validateTcpChecksumIP4 :: IP4 -> IP4 -> S.ByteString -> Bool
validateTcpChecksumIP4 src dst bytes = finalizeChecksum segmentSum == 0
  where
    pseudoHdr  = mkIP4PseudoHeader src dst tcpProtocol (S.length bytes)
    pseudoSum  = computePartialChecksum emptyPartialChecksum pseudoHdr
    segmentSum = computePartialChecksum pseudoSum bytes