module Hans.Message.Dns (
DNSPacket(..)
, DNSHeader(..)
, OpCode(..)
, RespCode(..)
, Query(..)
, QClass(..)
, QType(..)
, RR(..)
, Type(..)
, Class(..)
, RData(..)
, Name
, parseDNSPacket, getDNSPacket
, renderDNSPacket, putDNSPacket
) where
import Hans.Address.IP4
import Hans.Utils (chunk)
import Control.Monad
import Data.Bits
import Data.Foldable ( traverse_, foldMap )
import Data.Int
import Data.Serialize ( Putter, runPut, putWord8, putWord16be, putWord32be
, putByteString )
import Data.Word
import MonadLib ( lift, StateT, runStateT, get, set )
import Numeric ( showHex )
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.Map.Strict as Map
import qualified Data.Serialize.Get as C
-- | A complete DNS message: the header plus the four record sections.
-- Section counts are not stored; they are recomputed from the list lengths
-- when the packet is rendered.
data DNSPacket = DNSPacket
  { dnsHeader            :: DNSHeader  -- ^ ID, flags, opcode and response code
  , dnsQuestions         :: [Query]    -- ^ question section
  , dnsAnswers           :: [RR]       -- ^ answer section
  , dnsAuthorityRecords  :: [RR]       -- ^ authority section
  , dnsAdditionalRecords :: [RR]       -- ^ additional section
  } deriving (Show)
-- | The decoded fixed header of a DNS message (the section counts are kept
-- implicitly in 'DNSPacket').
data DNSHeader = DNSHeader
  { dnsId     :: !Word16   -- ^ message identifier
  , dnsQuery  :: Bool      -- ^ 'True' when the QR bit (15) is clear
  , dnsOpCode :: OpCode    -- ^ kind of query
  , dnsAA     :: Bool      -- ^ authoritative answer (bit 10)
  , dnsTC     :: Bool      -- ^ truncated (bit 9)
  , dnsRD     :: Bool      -- ^ recursion desired (bit 8)
  , dnsRA     :: Bool      -- ^ recursion available (bit 7)
  , dnsRC     :: RespCode  -- ^ response code (low 4 bits)
  } deriving (Show)
-- | The 4-bit OPCODE header field.
data OpCode
  = OpQuery             -- ^ 0: standard query
  | OpIQuery            -- ^ 1: inverse query
  | OpStatus            -- ^ 2: server status request
  | OpReserved !Word16  -- ^ any other value; the parser masks it to 4 bits
  deriving (Show)
-- | The 4-bit RCODE header field.
data RespCode
  = RespNoError           -- ^ 0
  | RespFormatError       -- ^ 1
  | RespServerFailure     -- ^ 2
  | RespNameError         -- ^ 3
  | RespNotImplemented    -- ^ 4
  | RespRefused           -- ^ 5
  | RespReserved !Word16  -- ^ any other value; the parser masks it to 4 bits
  deriving (Eq, Show)
-- | A domain name as its labels in wire order. The zero-length root
-- terminator is not stored: decoding stops at it and encoding appends it,
-- so the empty list denotes the root name.
type Name = [S.ByteString]
-- | One entry of the question section.
data Query = Query
  { qName  :: Name    -- ^ name being asked about
  , qType  :: QType   -- ^ requested record type (QTYPE)
  , qClass :: QClass  -- ^ requested class (QCLASS)
  } deriving (Show)
-- | A resource record. There is no separate TYPE field: when parsing, the
-- wire TYPE selects the 'RData' decoder; when rendering, the TYPE is taken
-- from the 'RData' constructor.
data RR = RR
  { rrName  :: Name    -- ^ owner name
  , rrClass :: Class   -- ^ record class
  , rrTTL   :: !Int32  -- ^ time to live, read as a signed 32-bit value
  , rrRData :: RData   -- ^ decoded RDATA
  } deriving (Show)
-- | Values of a question's QTYPE field: any ordinary 'Type', or one of the
-- question-only codes 252-255.
data QType
  = QType Type  -- ^ an ordinary record type
  | AFXR        -- ^ 252: zone transfer (standard mnemonic is AXFR)
  | MAILB       -- ^ 253
  | MAILA       -- ^ 254
  | QTAny       -- ^ 255: any type
  deriving (Show)
-- | Record TYPE codes understood by this module (wire value in comments).
data Type
  = A      -- ^ 1
  | NS     -- ^ 2
  | MD     -- ^ 3
  | MF     -- ^ 4
  | CNAME  -- ^ 5
  | SOA    -- ^ 6
  | MB     -- ^ 7
  | MG     -- ^ 8
  | MR     -- ^ 9
  | NULL   -- ^ 10
  | PTR    -- ^ 12
  | HINFO  -- ^ 13
  | MINFO  -- ^ 14
  | MX     -- ^ 15
  | AAAA   -- ^ 28; its RDATA is kept raw as 'RDUnknown'
  deriving (Show)
-- | Values of a question's QCLASS field: a concrete 'Class' or the
-- wildcard 255.
data QClass
  = QClass Class  -- ^ a concrete class
  | QAnyClass     -- ^ 255: any class
  deriving (Show)
-- | Record CLASS codes: IN = 1, CS = 2, CH = 3, HS = 4.
data Class
  = IN
  | CS
  | CH
  | HS
  deriving (Show, Eq)
-- | Decoded RDATA: one constructor per record type with a dedicated
-- decoder, plus 'RDUnknown' for the raw bytes of any other type.
data RData
  = RDA IP4                      -- ^ A: IPv4 address
  | RDNS Name                    -- ^ NS
  | RDMD Name                    -- ^ MD
  | RDMF Name                    -- ^ MF
  | RDCNAME Name                 -- ^ CNAME
  | RDSOA Name Name !Word32 !Int32 !Int32 !Int32 !Word32
    -- ^ SOA: mname, rname, serial, refresh, retry, expire, minimum
  | RDMB Name                    -- ^ MB
  | RDMG Name                    -- ^ MG
  | RDMR Name                    -- ^ MR
  | RDPTR Name                   -- ^ PTR
  | RDHINFO S.ByteString S.ByteString
    -- ^ HINFO: cpu and os character-strings
  | RDMINFO Name Name            -- ^ MINFO: rmailbx, emailbx
  | RDMX !Word16 Name            -- ^ MX: preference, exchange
  | RDNULL S.ByteString          -- ^ NULL: opaque bytes
  | RDUnknown Type S.ByteString  -- ^ raw RDATA of a type with no decoder
  deriving (Show)
-- | State threaded through 'Get' to support name-compression pointers.
data RW = RW
  { rwOffset :: !Int
    -- ^ bytes consumed since 'unGet' started (0 at its first byte)
  , rwLabels :: Map.Map Int Name
    -- ^ for each recorded offset, the name suffix that starts there
  } deriving (Show)
-- | Offset-tracking parser: cereal's 'C.Get' under a MonadLib state
-- transformer carrying 'RW' (bytes consumed and known name suffixes), so
-- compression pointers can be resolved.
type Get = StateT RW C.Get
-- | Run an offset-tracking parser as a plain cereal parser, starting at
-- offset 0 with no known labels and discarding the final state.
unGet :: Get a -> C.Get a
unGet m = fst `fmap` runStateT initial m
  where
    initial = RW { rwOffset = 0, rwLabels = Map.empty }
-- | The number of bytes consumed so far.
getOffset :: Get Int
getOffset =
  do st <- get
     return (rwOffset st)
-- | Advance the tracked byte offset by the given amount. The new state is
-- forced so the (strict) offset field does not accumulate thunks.
addOffset :: Int -> Get ()
addOffset delta =
  do st <- get
     let moved = st { rwOffset = rwOffset st + delta }
     set $! moved
-- | Resolve a compression pointer to a previously recorded name suffix.
--
-- Fails if the target is not strictly before the current offset (this rules
-- out forward and self-referencing pointers), or if no name was recorded at
-- that offset.
lookupPtr :: Int -> Get Name
lookupPtr target =
  do st <- get
     when (target >= rwOffset st) (fail "Invalid offset in pointer")
     let known = rwLabels st
         missing = fail ("Unknown label for offset: " ++ showHex target "\n"
                         ++ show known)
     maybe missing return (Map.lookup target known)
-- | One wire element of a name, tagged with the offset of its first byte.
data Label
  = Label Int S.ByteString  -- ^ a literal label
  | Ptr Int Name            -- ^ a compression pointer and the name it resolved to
  deriving (Show)
-- | Flatten wire labels into a 'Name', splicing in the names that pointers
-- resolved to.
labelsToName :: [Label] -> Name
labelsToName = concatMap expand
  where
    expand lbl = case lbl of
      Label _ bytes -> [bytes]
      Ptr _ name    -> name
-- | Record, for every literal label and for a terminating pointer, the name
-- suffix starting at its offset. Later pointers can then be resolved by
-- 'lookupPtr'. New entries take precedence over existing ones.
addLabels :: [Label] -> Get ()
addLabels labels =
  do st <- get
     let known = rwLabels st
     set $! st { rwLabels = Map.union (Map.fromList fresh) known }
  where
    fresh = pairUp labels (labelsToName labels)

    -- Walk labels and suffixes together; a pointer ends the name.
    pairUp [] _ = []
    pairUp (lbl : rest) suffix =
      case (lbl, suffix) of
        (Label off _, _ : shorter) -> (off, suffix) : pairUp rest shorter
        (Ptr off _, _)             -> [(off, suffix)]
        _                          -> []
-- | Run a cereal parser that consumes exactly @n@ bytes, advancing the
-- tracked offset by @n@.
liftGet :: Int -> C.Get a -> Get a
liftGet n m = addOffset n >> lift m
-- | Read one byte.
getWord8 :: Get Word8
getWord8 = liftGet 1 C.getWord8

-- | Read a big-endian 16-bit word.
getWord16be :: Get Word16
getWord16be = liftGet 2 C.getWord16be

-- | Read a big-endian 32-bit word.
getWord32be :: Get Word32
getWord32be = liftGet 4 C.getWord32be

-- | Read a big-endian 32-bit word and reinterpret it as signed.
getInt32be :: Get Int32
getInt32be = fromIntegral <$> getWord32be

-- | Read exactly @len@ raw bytes.
getBytes :: Int -> Get S.ByteString
getBytes len = liftGet len $ C.getBytes len
-- | Run a 'Get' action through a transformation of the underlying cereal
-- parser. The current 'RW' state is passed in and the final state is
-- written back.
throughGet :: (C.Get (a, RW) -> C.Get (a, RW)) -> Get a -> Get a
throughGet wrap body =
  do st <- get
     (result, st') <- lift (wrap (runStateT st body))
     set st'
     return result

-- | Restrict a parser to the next @n@ bytes (cereal's 'C.isolate').
isolate :: Int -> Get a -> Get a
isolate n = throughGet (C.isolate n)

-- | Attach a context label to errors raised by a parser (cereal's
-- 'C.label').
label :: String -> Get a -> Get a
label str = throughGet (C.label str)
-- | Write a signed 32-bit value as its big-endian two's-complement word.
putInt32be :: Putter Int32
putInt32be = putWord32be . fromIntegral
-- | Parse a complete DNS message. On failure, returns cereal's error
-- message, which includes the labels attached during parsing.
parseDNSPacket :: S.ByteString -> Either String DNSPacket
parseDNSPacket bytes = C.runGet getDNSPacket bytes
-- | Cereal parser for a complete DNS message.
--
-- It reads the header and the four section counts, then that many
-- questions and records. Offsets (and thus compression pointers) are
-- measured from the first byte this parser consumes.
getDNSPacket :: C.Get DNSPacket
getDNSPacket = unGet $ label "DNSPacket" $
  do dnsHeader <- getDNSHeader
     qdCount   <- getWord16be
     anCount   <- getWord16be
     nsCount   <- getWord16be
     arCount   <- getWord16be
     dnsQuestions         <- section "Questions" qdCount getQuery
     dnsAnswers           <- section "Answers" anCount getRR
     dnsAuthorityRecords  <- section "Authority Records" nsCount getRR
     dnsAdditionalRecords <- section "Additional Records" arCount getRR
     return DNSPacket { .. }
  where
    -- Parse @count@ consecutive items under a cereal error label.
    section name count item = label name (replicateM (fromIntegral count) item)
-- | Decode the message ID and the 16-bit flags word of a DNS header. The
-- four section counts that follow are read by 'getDNSPacket'.
--
-- Fails with "Z not zero" if any of bits 6-4 is set.
-- NOTE(review): DNSSEC (RFC 4035) assigns bits 5 and 4 to AD/CD, so
-- responses from validating resolvers are rejected here; confirm this is
-- intended.
getDNSHeader :: Get DNSHeader
getDNSHeader = label "DNS Header" $
  do dnsId <- getWord16be
     flags <- getWord16be
     let dnsQuery  = not (flags `testBit` 15)
         -- OPCODE occupies bits 14-11. Mask after shifting so the QR bit
         -- (bit 15) does not leak into bit 4 of the result. Without the mask,
         -- every response (QR = 1) decoded as 'OpReserved 0' rather than
         -- 'OpQuery'.
         dnsOpCode = parseOpCode ((flags `shiftR` 11) .&. 0xf)
         dnsAA     = flags `testBit` 10
         dnsTC     = flags `testBit` 9
         dnsRD     = flags `testBit` 8
         dnsRA     = flags `testBit` 7
         dnsZ      = (flags `shiftR` 4) .&. 0x7
         dnsRC     = parseRespCode (flags .&. 0xf)
     unless (dnsZ == 0) (fail ("Z not zero"))
     return DNSHeader { .. }
-- | Decode the 4-bit OPCODE field. Only the low 4 bits of the argument are
-- considered, so an unmasked value cannot turn a standard opcode into
-- 'OpReserved'. Previously, 0x10 decoded as 'OpReserved 0' instead of
-- 'OpQuery'.
parseOpCode :: Word16 -> OpCode
parseOpCode raw =
  case raw .&. 0xf of
    0 -> OpQuery
    1 -> OpIQuery
    2 -> OpStatus
    c -> OpReserved c
-- | Decode the 4-bit RCODE field. Only the low 4 bits of the argument are
-- considered, so the known codes always map to their named constructors.
-- Previously, an unmasked 0x10 became 'RespReserved 0', which compares
-- unequal to 'RespNoError'.
parseRespCode :: Word16 -> RespCode
parseRespCode raw =
  case raw .&. 0xf of
    0 -> RespNoError
    1 -> RespFormatError
    2 -> RespServerFailure
    3 -> RespNameError
    4 -> RespNotImplemented
    5 -> RespRefused
    c -> RespReserved c
-- | Decode one question: QNAME, QTYPE and QCLASS, in wire order.
getQuery :: Get Query
getQuery = label "Question" $
  Query <$> getName
        <*> label "QTYPE" getQType
        <*> label "QCLASS" getQClass
-- | Decode one resource record. The TYPE read from the wire selects the
-- RDATA decoder and is not stored separately.
getRR :: Get RR
getRR = label "RR" $
  do owner <- getName
     ty    <- getType
     cls   <- getClass
     ttl   <- getInt32be
     rdata <- getRData ty
     return RR { rrName = owner, rrClass = cls, rrTTL = ttl, rrRData = rdata }
-- | Decode a record TYPE. The question-only codes (252-255) are rejected.
getType :: Get Type
getType = getQType >>= \qt ->
  case qt of
    QType ty -> return ty
    _        -> fail ("Invalid TYPE: " ++ show qt)
-- | Decode a 16-bit QTYPE. Codes not listed in the table fail.
getQType :: Get QType
getQType =
  do tag <- getWord16be
     maybe (fail ("Invalid TYPE: " ++ show tag)) return (lookup tag qtypeCodes)
  where
    qtypeCodes :: [(Word16, QType)]
    qtypeCodes =
      [ (1,   QType A)
      , (2,   QType NS)
      , (3,   QType MD)
      , (4,   QType MF)
      , (5,   QType CNAME)
      , (6,   QType SOA)
      , (7,   QType MB)
      , (8,   QType MG)
      , (9,   QType MR)
      , (10,  QType NULL)
      , (12,  QType PTR)
      , (13,  QType HINFO)
      , (14,  QType MINFO)
      , (15,  QType MX)
      , (28,  QType AAAA)
      , (252, AFXR)
      , (253, MAILB)
      , (254, MAILA)
      , (255, QTAny)
      ]
-- | Decode a 16-bit QCLASS: 1-4 are concrete classes and 255 is "any".
getQClass :: Get QClass
getQClass =
  do tag <- getWord16be
     maybe (fail ("Invalid CLASS: " ++ show tag)) return (lookup tag classCodes)
  where
    classCodes :: [(Word16, QClass)]
    classCodes =
      [ (1,   QClass IN)
      , (2,   QClass CS)
      , (3,   QClass CH)
      , (4,   QClass HS)
      , (255, QAnyClass)
      ]
-- | Decode a possibly-compressed domain name (RFC 1035 section 4.1.4).
--
-- Literal labels are read until the zero terminator or a compression
-- pointer; a pointer ends the name and is resolved with 'lookupPtr'.
-- Afterwards, 'addLabels' records the suffix starting at each label's
-- offset so later pointers can refer to it.
getName :: Get Name
getName =
  do labels <- readLabels
     addLabels labels
     return (labelsToName labels)
  where
    readLabels =
      do off <- getOffset
         len <- getWord8
         case len .&. 0xc0 of
           -- 11xxxxxx: a 14-bit pointer. Widen to Int *before* shifting:
           -- shifting a Word8 left by 8 always yields 0, which discarded
           -- the high six bits of every pointer offset.
           0xc0 ->
             do lo <- getWord8
                let ptr = (fromIntegral (len .&. 0x3f) `shiftL` 8)
                          .|. fromIntegral lo
                ns <- lookupPtr ptr
                return [Ptr off ns]
           -- 00xxxxxx: the root terminator, or a literal label of <= 63 bytes.
           0x00
             | len == 0  -> return []
             | otherwise ->
                 do l  <- getBytes (fromIntegral len)
                    ls <- readLabels
                    return (Label off l : ls)
           -- 01xxxxxx and 10xxxxxx are reserved; do not misread them as
           -- lengths of up to 191 bytes.
           _ -> fail ("Unsupported label type: 0x" ++ showHex len "")
-- | Decode a record CLASS. The question-only "any" class (255) is rejected.
getClass :: Get Class
getClass = label "CLASS" (getQClass >>= concrete)
  where
    concrete (QClass c) = return c
    concrete QAnyClass  = fail "Invalid CLASS"
-- | Decode RDATA for the given type. The 16-bit RDLENGTH is read first and
-- the decoder is isolated to exactly that many bytes. Types without a
-- dedicated decoder keep their raw bytes in 'RDUnknown'.
getRData :: Type -> Get RData
getRData ty = label (show ty) $
  do rdLength <- getWord16be
     isolate (fromIntegral rdLength) decode
  where
    -- Everything left in the isolated RDATA window.
    remainingBytes = getBytes =<< lift C.remaining

    -- One length-prefixed <character-string>.
    charString = getWord8 >>= getBytes . fromIntegral

    decode = case ty of
      A     -> RDA <$> liftGet 4 parseIP4
      NS    -> RDNS <$> getName
      MD    -> RDMD <$> getName
      MF    -> RDMF <$> getName
      CNAME -> RDCNAME <$> getName
      SOA   -> RDSOA <$> getName <*> getName <*> getWord32be
                     <*> getInt32be <*> getInt32be <*> getInt32be
                     <*> getWord32be
      MB    -> RDMB <$> getName
      MG    -> RDMG <$> getName
      MR    -> RDMR <$> getName
      NULL  -> RDNULL <$> remainingBytes
      PTR   -> RDPTR <$> getName
      HINFO -> RDHINFO <$> charString <*> charString
      MINFO -> RDMINFO <$> getName <*> getName
      MX    -> RDMX <$> getWord16be <*> getName
      _     -> RDUnknown ty <$> remainingBytes
-- | Serialise a DNS message. The strict output of 'runPut' is converted to
-- a lazy ByteString with 'chunk'.
renderDNSPacket :: DNSPacket -> L.ByteString
renderDNSPacket = chunk . runPut . putDNSPacket
-- | Encode a DNS message: header, the four section counts (taken from the
-- list lengths), then the sections themselves. Names are written without
-- compression.
putDNSPacket :: Putter DNSPacket
putDNSPacket DNSPacket{ .. } =
  do putDNSHeader dnsHeader
     traverse_ (putWord16be . fromIntegral)
       [ length dnsQuestions
       , length dnsAnswers
       , length dnsAuthorityRecords
       , length dnsAdditionalRecords
       ]
     traverse_ putQuery dnsQuestions
     traverse_ (traverse_ putRR)
       [dnsAnswers, dnsAuthorityRecords, dnsAdditionalRecords]
-- | Encode the message ID and the flags word of a DNS header. The section
-- counts are written separately by 'putDNSPacket'.
putDNSHeader :: Putter DNSHeader
putDNSHeader DNSHeader { .. } =
  do putWord16be dnsId
     putWord16be flagWord
  where
    -- @bit i@ when the condition holds, otherwise no bits.
    bitWhen i cond = if cond then bit i else 0

    -- The Z bits (6-4) stay clear: every component below lies outside them
    -- (opcode is masked to 4 bits at 14-11; rcode to 4 bits at 3-0).
    flagWord = foldr (.|.) 0
      [ bitWhen 15 (not dnsQuery)
      , renderOpCode dnsOpCode `shiftL` 11
      , bitWhen 10 dnsAA
      , bitWhen 9 dnsTC
      , bitWhen 8 dnsRD
      , bitWhen 7 dnsRA
      , renderRespCode dnsRC
      ]
-- | The 4-bit wire value of an opcode (reserved values are masked to 4 bits).
renderOpCode :: OpCode -> Word16
renderOpCode op =
  case op of
    OpQuery      -> 0
    OpIQuery     -> 1
    OpStatus     -> 2
    OpReserved c -> c .&. 0xf
-- | The 4-bit wire value of a response code (reserved values are masked to
-- 4 bits).
renderRespCode :: RespCode -> Word16
renderRespCode rc =
  case rc of
    RespNoError        -> 0
    RespFormatError    -> 1
    RespServerFailure  -> 2
    RespNameError      -> 3
    RespNotImplemented -> 4
    RespRefused        -> 5
    RespReserved c     -> c .&. 0xf
-- | Encode a name as length-prefixed labels followed by the zero
-- terminator, without compression.
--
-- A zero-length label ends the name early, and any labels after it are
-- not written. A label longer than 63 bytes is a programming error and
-- calls 'error'.
putName :: Putter Name
putName name =
  do traverse_ putLabel (takeWhile (not . S.null) name)
     putWord8 0
  where
    putLabel l
      | S.length l > 63 = error "Label too big"
      | otherwise       = do putWord8 (fromIntegral (S.length l))
                             putByteString l
-- | Encode one question: QNAME, QTYPE, QCLASS.
putQuery :: Putter Query
putQuery q =
  putName (qName q) >> putQType (qType q) >> putQClass (qClass q)
-- | Encode a record TYPE as its 16-bit code.
putType :: Putter Type
putType = putWord16be . code
  where
    code ty = case ty of
      A     -> 1
      NS    -> 2
      MD    -> 3
      MF    -> 4
      CNAME -> 5
      SOA   -> 6
      MB    -> 7
      MG    -> 8
      MR    -> 9
      NULL  -> 10
      PTR   -> 12
      HINFO -> 13
      MINFO -> 14
      MX    -> 15
      AAAA  -> 28
-- | Encode a QTYPE: an ordinary 'Type', or one of the codes 252-255.
putQType :: Putter QType
putQType qt =
  case qt of
    QType ty -> putType ty
    AFXR     -> putWord16be 252
    MAILB    -> putWord16be 253
    MAILA    -> putWord16be 254
    QTAny    -> putWord16be 255
-- | Encode a QCLASS: a concrete 'Class', or 255 for "any".
putQClass :: Putter QClass
putQClass qc =
  case qc of
    QClass c  -> putClass c
    QAnyClass -> putWord16be 255
-- | Encode a resource record. The TYPE and the RDATA bytes both come from
-- 'putRData'; RDLENGTH is the length of those bytes.
putRR :: Putter RR
putRR RR { .. } =
  do putName rrName
     putType rdType
     putClass rrClass
     putInt32be rrTTL
     putWord16be (fromIntegral (S.length encoded))
     putByteString encoded
  where
    (rdType, encoded) = putRData rrRData
-- | Encode a record CLASS as its 16-bit code.
putClass :: Putter Class
putClass cls =
  putWord16be $ case cls of
    IN -> 1
    CS -> 2
    CH -> 3
    HS -> 4
-- | Encode RDATA, returning the record TYPE it implies together with the
-- serialised bytes (without the RDLENGTH prefix, which 'putRR' adds).
-- Names inside RDATA are written without compression.
putRData :: RData -> (Type,S.ByteString)
putRData rd = case rd of
  RDA addr       -> rdata A (renderIP4 addr)
  RDNS name      -> rdata NS (putName name)
  RDMD name      -> rdata MD (putName name)
  RDMF name      -> rdata MF (putName name)
  RDCNAME name   -> rdata CNAME (putName name)
  RDSOA m r s f t ex ttl ->
    rdata SOA $ do putName m
                   putName r
                   putWord32be s
                   putInt32be f
                   putInt32be t
                   putInt32be ex
                   putWord32be ttl
  RDMB name      -> rdata MB (putName name)
  RDMG name      -> rdata MG (putName name)
  RDMR name      -> rdata MR (putName name)
  -- NULL RDATA is opaque (RFC 1035 3.3.10), so the bytes are written as-is.
  -- The old extra length byte did not round-trip with 'getRData', which
  -- reads the whole RDATA window as the payload.
  RDNULL bytes   -> rdata NULL (putByteString bytes)
  RDPTR name     -> rdata PTR (putName name)
  RDHINFO cpu os -> rdata HINFO $ do putCharString cpu
                                     putCharString os
  RDMINFO rm em  -> rdata MINFO $ do putName rm
                                     putName em
  RDMX pref ex   -> rdata MX $ do putWord16be pref
                                  putName ex
  RDUnknown ty bytes -> (ty,bytes)
  where
    rdata tag m = (tag,runPut m)

    -- A <character-string>: one length octet, then the bytes. Longer input
    -- would silently wrap the length octet and corrupt the record, so it is
    -- rejected the same way 'putName' rejects oversized labels.
    putCharString s
      | S.length s > 255 = error "Character string too big"
      | otherwise        = do putWord8 (fromIntegral (S.length s))
                              putByteString s