module Network.DNS.Resolver (
FileOrNumericHost(..), ResolvConf(..), defaultResolvConf
, ResolvSeed, makeResolvSeed
, Resolver(..), withResolver, withResolvers
, lookup
, lookupAuth
, lookupRaw
, lookupRawAD
, fromDNSMessage
, fromDNSFormat
) where
#if !defined(mingw32_HOST_OS)
#define POSIX
#else
#define WIN
#endif
#if __GLASGOW_HASKELL__ < 709
#define GHC708
#endif
import Control.Exception (bracket)
import Data.Char (isSpace)
import Data.List (isPrefixOf)
import Data.Maybe (fromMaybe)
import Network.BSD (getProtocolNumber)
import Network.DNS.Decode
import Network.DNS.Encode
import Network.DNS.Internal
import qualified Data.ByteString.Char8 as BS
import Network.Socket (HostName, Socket, SocketType(Stream, Datagram))
import Network.Socket (AddrInfoFlag(..), AddrInfo(..), SockAddr(..))
import Network.Socket (Family(AF_INET, AF_INET6), PortNumber(..))
import Network.Socket (close, socket, connect, getPeerName, getAddrInfo)
import Network.Socket (defaultHints, defaultProtocol)
import Prelude hiding (lookup)
import System.Random (getStdRandom, randomR)
import System.Timeout (timeout)
#ifdef GHC708
import Control.Applicative ((<$>), (<*>), pure)
#endif
#if defined(WIN) && defined(GHC708)
import Network.Socket (send)
import qualified Data.ByteString.Lazy.Char8 as LB
import Control.Monad (when)
#else
import Network.Socket.ByteString.Lazy (sendAll)
#endif
-- | Where the resolver finds the name server it talks to.
data FileOrNumericHost = RCFilePath FilePath
                         -- ^ A resolv.conf-style file; the first
                         -- @nameserver@ entry in it is used.
                       | RCHostName HostName
                         -- ^ A numeric host address (it is resolved with
                         -- 'AI_NUMERICHOST'), using the port of the
                         -- @domain@ service.
                       | RCHostPort HostName PortNumber
                         -- ^ A numeric host address with an explicit port
                         -- that overrides the @domain@ service port.
-- | User-facing resolver configuration; turned into a 'ResolvSeed' by
-- 'makeResolvSeed'.
data ResolvConf = ResolvConf {
    resolvInfo :: FileOrNumericHost
    -- ^ Where to find the name server.
  , resolvTimeout :: Int
    -- ^ Time to wait for each UDP reply, in microseconds (it ends up as
    -- the argument to 'System.Timeout.timeout').
  , resolvRetry :: Int
    -- ^ Maximum number of UDP send attempts per query.
  , resolvBufsize :: Integer
    -- ^ Copied into 'dnsBufsize'; its consumer is not visible here.
    -- Presumably the receive buffer size (TODO confirm).
  }
-- | Default configuration: take the name server from @\/etc\/resolv.conf@,
-- wait three seconds per attempt, make three attempts, and use a buffer
-- size of 512.
defaultResolvConf :: ResolvConf
defaultResolvConf = ResolvConf
    { resolvInfo    = RCFilePath "/etc/resolv.conf"
    , resolvTimeout = threeSeconds
    , resolvRetry   = 3
    , resolvBufsize = 512
    }
  where
    -- 'System.Timeout.timeout' counts in microseconds.
    threeSeconds = 3 * 1000000
-- | A 'ResolvConf' whose name server address has already been resolved.
-- Built by 'makeResolvSeed' and used to open resolvers via 'withResolver'
-- and 'withResolvers'.
data ResolvSeed = ResolvSeed {
    addrInfo :: AddrInfo
    -- ^ Name server address, looked up with a 'Datagram' socket hint.
  , rsTimeout :: Int
    -- ^ Copied from 'resolvTimeout'.
  , rsRetry :: Int
    -- ^ Copied from 'resolvRetry'.
  , rsBufsize :: Integer
    -- ^ Copied from 'resolvBufsize'.
  }
-- | A handle for issuing queries; created by 'withResolver' and
-- 'withResolvers' and passed to the lookup functions.
data Resolver = Resolver {
    genId :: IO Int
    -- ^ Produces the identifier placed in each outgoing query.
  , dnsSock :: Socket
    -- ^ Connected socket the queries are sent on.
  , dnsTimeout :: Int
    -- ^ Per-attempt reply timeout, in microseconds.
  , dnsRetry :: Int
    -- ^ Maximum number of UDP send attempts per query.
  , dnsBufsize :: Integer
    -- ^ Copied from the seed; not read in the code shown here.
  }
-- | Resolve the name server described by a 'ResolvConf' into a
-- 'ResolvSeed'.
--
-- For 'RCFilePath' the file is read strictly and the address on the first
-- line whose first word is @nameserver@ is used (comment lines, lines such
-- as @nameservers@, and trailing text after the address are ignored).
--
-- Throws an 'IOError' if the file cannot be read, contains no usable
-- @nameserver@ entry, or the address cannot be resolved.
makeResolvSeed :: ResolvConf -> IO ResolvSeed
makeResolvSeed conf = ResolvSeed <$> addr
                                 <*> pure (resolvTimeout conf)
                                 <*> pure (resolvRetry conf)
                                 <*> pure (resolvBufsize conf)
  where
    addr = case resolvInfo conf of
        RCHostName numhost       -> makeAddrInfo numhost Nothing
        RCHostPort numhost mport -> makeAddrInfo numhost $ Just mport
        RCFilePath file          -> do
            -- Strict read: the handle is closed right away instead of
            -- lingering until a lazily-read String is fully consumed.
            contents <- BS.unpack <$> BS.readFile file
            host <- firstNameserver file contents
            makeAddrInfo host Nothing
    -- An explicit IOError replaces the former partial pattern, which
    -- crashed with an irrefutable-pattern failure on files without one.
    firstNameserver file cs =
        case [ host | ("nameserver" : host : _) <- map words (lines cs) ] of
            host : _ -> return host
            []       -> ioError $ userError $
                "makeResolvSeed: no nameserver entry found in " ++ file
-- | Look up a numeric host for UDP on the @domain@ service port. When a
-- port is supplied it replaces the looked-up port of an IPv4/IPv6 address.
-- Only the first result of 'getAddrInfo' is kept.
makeAddrInfo :: HostName -> Maybe PortNumber -> IO AddrInfo
makeAddrInfo addr mport = do
    udp <- getProtocolNumber "udp"
    let hints = defaultHints
          { addrFlags      = [AI_ADDRCONFIG, AI_NUMERICHOST, AI_PASSIVE]
          , addrSocketType = Datagram
          , addrProtocol   = udp
          }
    info : _ <- getAddrInfo (Just hints) (Just addr) (Just "domain")
    return info { addrAddress = overridePort (addrAddress info) }
  where
    -- Swap in the caller's port when one was given; other address kinds
    -- pass through untouched.
    overridePort (SockAddrInet port host) =
        SockAddrInet (fromMaybe port mport) host
    overridePort (SockAddrInet6 port flow host scope) =
        SockAddrInet6 (fromMaybe port mport) flow host scope
    overridePort other = other
-- | Open a socket for the seed, connect it to the name server, and run the
-- action with the resulting 'Resolver'. The socket is closed afterwards,
-- even if the action throws.
withResolver :: ResolvSeed -> (Resolver -> IO a) -> IO a
withResolver seed func = bracket (openSocket seed) close useSocket
  where
    useSocket sock = connectSocket sock seed >> func (makeResolver seed sock)
-- | Like 'withResolver', but for several seeds at once. Resolvers are
-- passed to the action in the same order as the seeds.
--
-- Each socket is acquired under its own 'bracket'. If opening or
-- connecting one socket fails, every socket opened before it is still
-- closed. Previously the whole batch was opened inside a single acquire
-- step, so a partial failure leaked the earlier sockets. Likewise, a
-- failing 'close' no longer prevents the remaining sockets from being
-- closed.
withResolvers :: [ResolvSeed] -> ([Resolver] -> IO a) -> IO a
withResolvers seeds func = go seeds []
  where
    -- Resolvers accumulate in reverse; restore seed order at the end.
    go [] acc = func (reverse acc)
    go (seed : rest) acc =
        bracket (openSocket seed) close $ \sock -> do
            connectSocket sock seed
            go rest (makeResolver seed sock : acc)
-- | Create an unconnected socket with the family, socket type and protocol
-- recorded in the seed's address info.
openSocket :: ResolvSeed -> IO Socket
openSocket seed =
    let info = addrInfo seed
    in socket (addrFamily info) (addrSocketType info) (addrProtocol info)
-- | Connect the socket to the name server address stored in the seed.
connectSocket :: Socket -> ResolvSeed -> IO ()
connectSocket sock = connect sock . addrAddress . addrInfo
-- | Package a socket and the seed's settings into a 'Resolver'. Query
-- identifiers come from 'getRandom'. Callers in this module pass a socket
-- they have already opened and connected for the same seed.
makeResolver :: ResolvSeed -> Socket -> Resolver
makeResolver seed sock = Resolver
    { dnsSock    = sock
    , genId      = getRandom
    , dnsRetry   = rsRetry seed
    , dnsTimeout = rsTimeout seed
    , dnsBufsize = rsBufsize seed
    }
-- | Draw a query identifier from the global standard generator. The range
-- covers every 16-bit value, the width of the DNS header ID field.
getRandom :: IO Int
getRandom = getStdRandom $ randomR idRange
  where
    idRange = (0, 0xffff)
-- | Run 'lookupRaw' and, if the reply has no error code, return the RData of
-- those records in the chosen section whose type equals the queried type.
-- Errors from the lookup or from the reply's rcode are passed through.
lookupSection :: (DNSMessage -> [ResourceRecord])
              -> Resolver
              -> Domain
              -> TYPE
              -> IO (Either DNSError [RData])
lookupSection section rlv dom typ =
    either Left (`fromDNSMessage` selectRData) <$> lookupRaw rlv dom typ
  where
    selectRData = map rdata . filter ((== typ) . rrtype) . section
-- | Apply the converter to a message whose header rcode is 'NoErr';
-- otherwise translate the rcode into the matching 'DNSError'.
fromDNSMessage :: DNSMessage -> (DNSMessage -> a) -> Either DNSError a
fromDNSMessage ans conv =
    maybe (Right (conv ans)) Left (rcodeError (rcode (flags (header ans))))
  where
    rcodeError NoErr     = Nothing
    rcodeError FormatErr = Just FormatError
    rcodeError ServFail  = Just ServerFailure
    rcodeError NameErr   = Just NameError
    rcodeError NotImpl   = Just NotImplemented
    rcodeError Refused   = Just OperationRefused
    rcodeError BadOpt    = Just BadOptRecord
-- | Identical to 'fromDNSMessage'; kept as an alternative name
-- (presumably for backward compatibility; confirm before removing).
fromDNSFormat :: DNSMessage -> (DNSMessage -> a) -> Either DNSError a
fromDNSFormat ans conv = fromDNSMessage ans conv
-- | Query a domain and return the RData of answer-section records of the
-- requested type.
lookup :: Resolver -> Domain -> TYPE -> IO (Either DNSError [RData])
lookup rlv dom typ = lookupSection answer rlv dom typ
-- | Query a domain and return the RData of authority-section records of
-- the requested type.
lookupAuth :: Resolver -> Domain -> TYPE -> IO (Either DNSError [RData])
lookupAuth rlv dom typ = lookupSection authority rlv dom typ
-- | Send a query built with 'composeQuery' and return the whole reply
-- message.
lookupRaw :: Resolver -> Domain -> TYPE -> IO (Either DNSError DNSMessage)
lookupRaw rlv dom typ = lookupRawInternal receive False rlv dom typ
-- | Like 'lookupRaw', but the query is built with 'composeQueryAD'.
lookupRawAD :: Resolver -> Domain -> TYPE -> IO (Either DNSError DNSMessage)
lookupRawAD rlv dom typ = lookupRawInternal receive True rlv dom typ
-- | Send a query on the resolver's socket and wait for the reply, falling
-- back to TCP (via 'tcpRetry') when the reply has the truncation flag set.
--
-- * @rcv@ reads one message from the socket.
-- * @ad@ selects 'composeQueryAD' instead of 'composeQuery'.
--
-- Domains rejected by 'isIllegal' yield 'IllegalDomain' without sending
-- anything. Each attempt waits 'dnsTimeout' microseconds. A timeout, or a
-- reply whose identifier differs from the query's, triggers another
-- attempt, up to 'dnsRetry' attempts in total. When the attempts run out,
-- the result is 'SequenceNumberMismatch' if the last attempt received a
-- mismatching reply, and 'RetryLimitExceeded' otherwise.
lookupRawInternal ::
    (Socket -> IO DNSMessage)
 -> Bool
 -> Resolver
 -> Domain
 -> TYPE
 -> IO (Either DNSError DNSMessage)
lookupRawInternal _ _ _ dom _
  | isIllegal dom = return $ Left IllegalDomain
lookupRawInternal rcv ad rlv dom typ = do
    seqno <- genId rlv
    let query = (if ad then composeQueryAD else composeQuery) seqno [q]
        checkSeqno = check seqno
    loop query checkSeqno 0 False
  where
    -- @mismatch@ is True when the previous attempt got a reply carrying the
    -- wrong identifier, so the final error can report that case.
    loop query checkSeqno cnt mismatch
      -- '>=' rather than '==': with a negative retry count the counter,
      -- which starts at 0, would otherwise never reach the limit.
      | cnt >= retry = return $ Left $
          if mismatch then SequenceNumberMismatch else RetryLimitExceeded
      | otherwise = do
          sendAll sock query
          response <- timeout tm (rcv sock)
          case response of
            Nothing -> loop query checkSeqno (cnt + 1) False
            Just res
              -- Record the mismatch; passing False here made
              -- SequenceNumberMismatch unreachable.
              | not (checkSeqno res) -> loop query checkSeqno (cnt + 1) True
              | trunCation $ flags $ header res -> tcpRetry query sock tm
              | otherwise -> return $ Right res
    sock = dnsSock rlv
    tm = dnsTimeout rlv
    retry = dnsRetry rlv
    q = makeQuestion dom typ
    check seqno res = identifier (header res) == seqno
-- | Repeat a query over a fresh TCP connection to the same peer the socket
-- is connected to. The TCP socket, if one could be created, is closed
-- afterwards even on exceptions.
tcpRetry ::
    Query
 -> Socket
 -> Int
 -> IO (Either DNSError DNSMessage)
tcpRetry query sock tm = getPeerName sock >>= viaTcp
  where
    viaTcp peer = bracket (tcpOpen peer) closeIfOpen (tcpLookup query peer tm)
    closeIfOpen = maybe (return ()) close
-- | Create a stream socket matching the peer's address family. Returns
-- 'Nothing' for address kinds other than IPv4/IPv6.
tcpOpen :: SockAddr -> IO (Maybe Socket)
tcpOpen peer = case streamFamily peer of
    Nothing  -> return Nothing
    Just fam -> Just <$> socket fam Stream defaultProtocol
  where
    streamFamily SockAddrInet{}  = Just AF_INET
    streamFamily SockAddrInet6{} = Just AF_INET6
    streamFamily _               = Nothing
-- | Connect the stream socket to the peer, send the query in its
-- 'encodeVC' framing, and read one reply, all within @tm@ microseconds.
-- Yields 'ServerFailure' when no socket could be created and
-- 'TimeoutExpired' when the time limit is hit.
tcpLookup ::
    Query
 -> SockAddr
 -> Int
 -> Maybe Socket
 -> IO (Either DNSError DNSMessage)
tcpLookup _ _ _ Nothing = return $ Left ServerFailure
tcpLookup query peer tm (Just vc) =
    maybe (Left TimeoutExpired) Right <$> timeout tm exchange
  where
    exchange = do
        connect vc peer
        sendAll vc (encodeVC query)
        receiveVC vc
#if defined(WIN) && defined(GHC708)
-- | Send an entire lazy ByteString using the 'String'-based 'send', looping
-- on the unsent suffix until everything has been written.
--
-- It takes 'LB.ByteString' to match the lazy 'sendAll' from
-- "Network.Socket.ByteString.Lazy" used on the other platforms. The former
-- signature named the strict 'BS.ByteString', which does not type-check
-- against the 'LB' operations below.
sendAll :: Socket -> LB.ByteString -> IO ()
sendAll sock bs = do
    sent <- send sock (LB.unpack bs)
    when (sent < fromIntegral (LB.length bs)) $
        sendAll sock (LB.drop (fromIntegral sent) bs)
#endif
-- | Reject domain names the resolver will not query: empty names, names
-- without a dot, names containing @:@ or @/@, names longer than 253
-- characters, and names with any dot-separated label longer than 63
-- characters.
isIllegal :: Domain -> Bool
isIllegal dom = or
    [ BS.null dom
    , '.' `BS.notElem` dom
    , ':' `BS.elem` dom
    , '/' `BS.elem` dom
    , BS.length dom > 253
    , any ((> 63) . BS.length) (BS.split '.' dom)
    ]