module Network.DNS.Resolver (
FileOrNumericHost(..), ResolvConf(..), defaultResolvConf
, ResolvSeed, makeResolvSeed
, Resolver(..), withResolver, withResolvers
, lookup
, lookupAuth
, lookupRaw
, lookupRaw'
, fromDNSFormat
) where
import Control.Applicative ((<$>), (<*>), pure)
import Control.Exception (bracket)
import Data.Char (isSpace)
import Data.List (isPrefixOf)
import Data.Maybe (fromMaybe)
import Network.BSD (getProtocolNumber)
import Network.DNS.Decode
import Network.DNS.Encode
import Network.DNS.Internal
import qualified Data.ByteString.Char8 as BS
import Network.Socket (HostName, Socket, SocketType(Datagram), sClose, socket, connect)
import Network.Socket (AddrInfoFlag(..), AddrInfo(..), SockAddr(..), PortNumber(..), defaultHints, getAddrInfo)
import Prelude hiding (lookup)
import System.Random (getStdRandom, randomR)
import System.Timeout (timeout)
#if mingw32_HOST_OS == 1
import Network.Socket (send)
import qualified Data.ByteString.Lazy.Char8 as LB
import Control.Monad (when)
#else
import Network.Socket.ByteString.Lazy (sendAll)
#endif
-- | Where the resolver obtains the address of its name server.
data FileOrNumericHost
    = RCFilePath FilePath
      -- ^ A resolv.conf-style file; the host named on its first
      --   @nameserver@ line is used.
    | RCHostName HostName
      -- ^ A numeric host address, contacted on the @domain@ service port.
    | RCHostPort HostName PortNumber
      -- ^ A numeric host address together with an explicit port that
      --   replaces the @domain@ service port.
-- | User-facing resolver configuration. Start from 'defaultResolvConf'
-- and override individual fields as needed.
data ResolvConf = ResolvConf
    { resolvInfo    :: FileOrNumericHost
      -- ^ Source of the name server address.
    , resolvTimeout :: Int
      -- ^ Time to wait for each reply, in microseconds (the value is
      --   handed to 'System.Timeout.timeout').
    , resolvRetry   :: Int
      -- ^ Total number of times a query is sent before giving up.
    , resolvBufsize :: Integer
      -- ^ Buffer size carried through to 'dnsBufsize'.
      --   NOTE(review): presumably the maximum reply size; it is not read
      --   anywhere in this module — confirm against the receive code.
    }
-- | Default configuration: the name server comes from
-- @\/etc\/resolv.conf@, each attempt waits three seconds, a query is
-- sent at most three times, and the buffer size is 512.
defaultResolvConf :: ResolvConf
defaultResolvConf = ResolvConf
    { resolvInfo    = RCFilePath "/etc/resolv.conf"
    , resolvTimeout = threeSeconds
    , resolvRetry   = 3
    , resolvBufsize = 512
    }
  where
    -- 'System.Timeout.timeout' counts in microseconds.
    threeSeconds :: Int
    threeSeconds = 3 * 1000 * 1000
-- | Resolver settings whose name server address has already been
-- resolved. Built by 'makeResolvSeed' and consumed by 'withResolver'.
data ResolvSeed = ResolvSeed
    { addrInfo  :: AddrInfo  -- ^ Name server address (UDP, port already applied).
    , rsTimeout :: Int       -- ^ Copied from 'resolvTimeout'.
    , rsRetry   :: Int       -- ^ Copied from 'resolvRetry'.
    , rsBufsize :: Integer   -- ^ Copied from 'resolvBufsize'.
    }
-- | A live resolver handle. It is only usable inside the callback given to
-- 'withResolver' / 'withResolvers', which close 'dnsSock' on exit.
data Resolver = Resolver
    { genId      :: IO Int   -- ^ Produces the query identifier for each lookup.
    , dnsSock    :: Socket   -- ^ Datagram socket connected to the name server.
    , dnsTimeout :: Int      -- ^ Per-attempt reply timeout, in microseconds.
    , dnsRetry   :: Int      -- ^ Number of send attempts per query.
    , dnsBufsize :: Integer  -- ^ Buffer size setting carried from 'ResolvConf'.
    }
-- | Build a 'ResolvSeed' from a 'ResolvConf' by resolving the name server
-- address described by 'resolvInfo':
--
-- * 'RCHostName' \/ 'RCHostPort': the given numeric host, on the @domain@
--   service port unless an explicit port is supplied.
--
-- * 'RCFilePath': the host named on the first @nameserver@ line of the
--   file.
--
-- Throws an 'IOError' (built with 'userError') when the file has no
-- @nameserver@ line that names a host.
makeResolvSeed :: ResolvConf -> IO ResolvSeed
makeResolvSeed conf = ResolvSeed <$> addr
                                 <*> pure (resolvTimeout conf)
                                 <*> pure (resolvRetry conf)
                                 <*> pure (resolvBufsize conf)
  where
    addr = case resolvInfo conf of
        RCHostName numhost       -> makeAddrInfo numhost Nothing
        RCHostPort numhost mport -> makeAddrInfo numhost $ Just mport
        RCFilePath file          -> do
            -- Strict read, so the file handle is closed right away instead
            -- of lingering as it would with lazy 'readFile'.
            contents <- BS.readFile file
            host <- nameserverOf file contents
            makeAddrInfo host Nothing

    -- Each line is split into whitespace-separated words. Surrounding
    -- blanks (including a CRLF '\r') therefore do not leak into the host
    -- name, and anything after the address on the same line is ignored.
    -- Only a first word of exactly "nameserver" counts, so look-alikes
    -- such as "nameservers" are skipped.
    nameserverOf file contents =
        case [ BS.unpack host
             | (key : host : _) <- map BS.words (BS.lines contents)
             , key == BS.pack "nameserver"
             ] of
            host : _ -> return host
            []       -> ioError $ userError $
                "makeResolvSeed: no nameserver entry found in " ++ file
-- | Resolve a numeric host into a UDP 'AddrInfo' for the @domain@ service.
-- When a port is supplied it replaces the service port in the resulting
-- IPv4/IPv6 address. Any other address kind is returned unchanged.
makeAddrInfo :: HostName -> Maybe PortNumber -> IO AddrInfo
makeAddrInfo addr mport = do
    udp <- getProtocolNumber "udp"
    let hints = defaultHints
            { addrFlags      = [AI_ADDRCONFIG, AI_NUMERICHOST, AI_PASSIVE]
            , addrSocketType = Datagram
            , addrProtocol   = udp
            }
    found : _ <- getAddrInfo (Just hints) (Just addr) (Just "domain")
    return (found { addrAddress = applyPort (addrAddress found) })
  where
    pick pn = fromMaybe pn mport
    applyPort (SockAddrInet pn ha)         = SockAddrInet (pick pn) ha
    applyPort (SockAddrInet6 pn fi ha sid) = SockAddrInet6 (pick pn) fi ha sid
    applyPort other                        = other
-- | Open a socket for the seed, connect it to the name server, and run the
-- callback with a 'Resolver' built on it. The socket is closed when the
-- callback returns or throws.
withResolver :: ResolvSeed -> (Resolver -> IO a) -> IO a
withResolver seed body =
    bracket (openSocket seed) sClose $ \sock ->
        connectSocket sock seed >> body (makeResolver seed sock)
-- | Like 'withResolver', but for several seeds at once. The callback
-- receives one 'Resolver' per seed, in the same order as @seeds@.
--
-- Each socket is opened and connected inside its own 'withResolver'
-- bracket. Every socket that was successfully opened is therefore closed
-- again, even when opening or connecting a later one throws, the callback
-- throws, or closing another socket throws.
withResolvers :: [ResolvSeed] -> ([Resolver] -> IO a) -> IO a
withResolvers seeds func = go seeds []
  where
    -- @acquired@ holds the resolvers opened so far, most recent first.
    go []            acquired = func (reverse acquired)
    go (seed : rest) acquired =
        withResolver seed $ \rlv -> go rest (rlv : acquired)
-- | Create an unconnected socket whose family, type and protocol match the
-- seed's resolved address.
openSocket :: ResolvSeed -> IO Socket
openSocket seed =
    let info = addrInfo seed
    in  socket (addrFamily info) (addrSocketType info) (addrProtocol info)
-- | Connect the socket to the name server address stored in the seed.
connectSocket :: Socket -> ResolvSeed -> IO ()
connectSocket sock = connect sock . addrAddress . addrInfo
-- | Assemble a 'Resolver' from a seed and an already-created socket.
-- Query identifiers come from 'getRandom'.
makeResolver :: ResolvSeed -> Socket -> Resolver
makeResolver seed sock = Resolver
    { dnsSock    = sock
    , genId      = getRandom
    , dnsTimeout = rsTimeout seed
    , dnsRetry   = rsRetry seed
    , dnsBufsize = rsBufsize seed
    }
-- | Draw a query identifier in the 16-bit range @0 .. 0xFFFF@ from the
-- global standard generator.
getRandom :: IO Int
getRandom = getStdRandom $ randomR (0, 0xFFFF)
-- | Query the name server and return the 'rdata' of the records in the
-- chosen section whose type equals the requested 'TYPE'. Transport errors
-- from 'lookupRaw' are passed through, and an error rcode in the reply is
-- mapped by 'fromDNSFormat'.
lookupSection :: (DNSFormat -> [ResourceRecord])
              -> Resolver
              -> Domain
              -> TYPE
              -> IO (Either DNSError [RDATA])
lookupSection section rlv dom typ = either Left extract <$> lookupRaw rlv dom typ
  where
    extract msg = fromDNSFormat msg (map rdata . filter ((== typ) . rrtype) . section)
-- | Apply @conv@ to a reply whose response code is 'NoErr'. Any other
-- response code is turned into the corresponding 'DNSError'.
fromDNSFormat :: DNSFormat -> (DNSFormat -> a) -> Either DNSError a
fromDNSFormat ans conv = maybe (Right (conv ans)) Left failure
  where
    failure = case rcode (flags (header ans)) of
        NoErr     -> Nothing
        FormatErr -> Just FormatError
        ServFail  -> Just ServerFailure
        NameErr   -> Just NameError
        NotImpl   -> Just NotImplemented
        Refused   -> Just OperationRefused
-- | Look up records of the given type in the answer section of the reply.
lookup :: Resolver -> Domain -> TYPE -> IO (Either DNSError [RDATA])
lookup rlv dom typ = lookupSection answer rlv dom typ
-- | Look up records of the given type in the authority section of the reply.
lookupAuth :: Resolver -> Domain -> TYPE -> IO (Either DNSError [RDATA])
lookupAuth rlv dom typ = lookupSection authority rlv dom typ
-- | Send a query and return the whole reply, decoded with 'receive'.
lookupRaw :: Resolver -> Domain -> TYPE -> IO (Either DNSError DNSFormat)
lookupRaw rlv dom typ = lookupRawInternal receive rlv dom typ
-- | Send a query and return the whole reply, decoded with 'receive''.
lookupRaw' :: Resolver -> Domain -> TYPE -> IO (Either DNSError (DNSMessage (RD BS.ByteString)))
lookupRaw' rlv dom typ = lookupRawInternal receive' rlv dom typ
-- | Shared implementation of 'lookupRaw' and 'lookupRaw''.
--
-- Names rejected by 'isIllegal' yield 'IllegalDomain' without touching the
-- network. Otherwise a single-question query is built with an identifier
-- from 'genId'. It is sent on 'dnsSock', and the function waits up to
-- 'dnsTimeout' microseconds for a reply, which is read with @rcv@. A reply
-- whose identifier differs from the query's is discarded and the query is
-- sent again.
--
-- After 'dnsRetry' attempts without a matching reply, the result is
-- 'SequenceNumberMismatch' if the last attempt received a reply with the
-- wrong identifier, and 'TimeoutExpired' if it timed out.
lookupRawInternal ::
    (Socket -> IO (DNSMessage a))
    -> Resolver
    -> Domain
    -> TYPE
    -> IO (Either DNSError (DNSMessage a))
lookupRawInternal _ _ dom _
  | isIllegal dom = return $ Left IllegalDomain
lookupRawInternal rcv rlv dom typ = do
    seqno <- genId rlv
    let query = composeQuery seqno [makeQuestion dom typ]
        isReply res = identifier (header res) == seqno
        -- '>=' rather than '==' so a non-positive retry count terminates
        -- instead of retrying forever.
        loop cnt mismatch
            | cnt >= retry = return $ Left $
                if mismatch then SequenceNumberMismatch else TimeoutExpired
            | otherwise = do
                sendAll sock query
                response <- timeout tm (rcv sock)
                case response of
                    Nothing -> loop (cnt + 1) False
                    Just res
                        | isReply res -> return $ Right res
                        -- Remember the mismatch so it is reported if this
                        -- was the final attempt.
                        | otherwise   -> loop (cnt + 1) True
    loop 0 False
  where
    sock  = dnsSock rlv
    tm    = dnsTimeout rlv
    retry = dnsRetry rlv
#if mingw32_HOST_OS == 1
    -- Windows fallback: keep sending until the whole lazy ByteString is out.
    sendAll s bs = do
        sent <- send s (LB.unpack bs)
        when (sent < fromIntegral (LB.length bs)) $ sendAll s (LB.drop (fromIntegral sent) bs)
#endif
-- | Cheap syntactic screen applied before sending a query. A name is
-- rejected when it is empty, contains no dot, contains @:@ or @/@, is
-- longer than 253 bytes, or has a dot-separated label longer than 63 bytes.
isIllegal :: Domain -> Bool
isIllegal dom =
       BS.null dom
    || '.' `BS.notElem` dom
    || ':' `BS.elem` dom
    || '/' `BS.elem` dom
    || BS.length dom > 253
    || any ((> 63) . BS.length) (BS.split '.' dom)