module Network.DNS.MDNSResponder.Client
(
Connection
, connect
, disconnect
, defaultAddr
, AsyncConnectionError (..)
, AsyncConnectionErrorHandler
, NullFreeByteString
, DNSServiceFlags
, DNSServiceErrorType
, kDNSServiceErr_NoError
, kDNSServiceErr_ServiceNotRunning
, kDNSServiceErr_ShortResponse
, InterfaceIndex
, kDNSServiceInterfaceIndexAny
, kDNSServiceInterfaceIndexLocalOnly
, Request (..)
, request
, AsyncResponseHandler
, ResponseHeader (..)
, Response (..)
, NTDResponse (..)
, ResolveResponse (..)
) where
import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.Bits
import Data.Int
import Data.IORef
import Data.Semigroup (Semigroup (..))
import Data.Typeable
import Data.Word
import Foreign.C.String
import Foreign.C.Types
import Foreign.Marshal.Alloc
import Foreign.Marshal.Utils
import Foreign.Ptr
import Foreign.Storable
import System.Environment
import System.IO.Error (isEOFError)

import Control.Monad.Trans.Class
import Control.Monad.Trans.Maybe
import Data.ByteString as BS
import Data.ByteString.Unsafe
import Data.Endian
import qualified Control.Concurrent.Map as CM
import qualified Network.Socket as S
import Network.Socket.Msg
-- | The 32-bit flags word carried in requests and replies.
--
-- Flags combine with bitwise OR. That is also the 'Semigroup' operation,
-- and 'mempty' is the empty flag set.
newtype DNSServiceFlags =
  DNSServiceFlags Word32 deriving (Eq, Bits)

-- Required as a superclass of 'Monoid' since GHC 8.4.
instance Semigroup DNSServiceFlags where
  (<>) = (.|.)

instance Monoid DNSServiceFlags where
  mempty = DNSServiceFlags 0
  mappend = (<>)
-- | Interface index sent in requests and reported in reply headers.
newtype InterfaceIndex = InterfaceIndex Word32

-- | Interface index 0 (by its name, the "any interface" selector).
kDNSServiceInterfaceIndexAny :: InterfaceIndex
kDNSServiceInterfaceIndexAny = InterfaceIndex minBound

-- | Interface index 0xFFFFFFFF, the largest 'Word32'.
kDNSServiceInterfaceIndexLocalOnly :: InterfaceIndex
kDNSServiceInterfaceIndexLocalOnly = InterfaceIndex maxBound
-- | Status code sent by the daemon. 0 means success.
newtype DNSServiceErrorType =
  DNSServiceErrorType Int32
  deriving (Eq, Show)

-- | Success.
kDNSServiceErr_NoError :: DNSServiceErrorType
kDNSServiceErr_NoError =
  DNSServiceErrorType 0

-- | The daemon is not running (code -65563).
--
-- This module also returns it when the peer closes the socket before a
-- complete 4-byte status has been received.
kDNSServiceErr_ServiceNotRunning :: DNSServiceErrorType
kDNSServiceErr_ServiceNotRunning =
  DNSServiceErrorType (-65563)

-- | Module-specific code: the reply was too short to decode.
kDNSServiceErr_ShortResponse :: DNSServiceErrorType
kDNSServiceErr_ShortResponse = DNSServiceErrorType 1
-- | A 'ByteString' that must not contain NUL bytes.
--
-- Such values are written NUL-terminated ('pokeBSNull'), and replies are
-- split by scanning for NUL ('findNull'). An embedded NUL would therefore
-- corrupt the field boundaries. Nothing in this module checks the invariant.
type NullFreeByteString = ByteString
-- | Fields decoded from the fixed-size header at the start of each reply
-- body. Replies whose header carries a non-zero status are reported as
-- errors instead.
data ResponseHeader = ResponseHeader
  { reshdr_flags :: !DNSServiceFlags
    -- ^ flags word of the reply
  , reshdr_ifi :: !InterfaceIndex
    -- ^ interface index of the reply
  }

-- | A successfully decoded reply: its header plus a body of type @a@.
data Response a = Response !ResponseHeader !a
-- | Reply body for 'ServiceRegister' and 'ServiceBrowse'.
--
-- Decoded from three consecutive NUL-terminated strings: name, type and
-- domain.
data NTDResponse = NTDResponse
  { ntd_name :: !NullFreeByteString
  , ntd_regtype :: !NullFreeByteString
  , ntd_domain :: !NullFreeByteString
  }
-- | Reply body for 'ServiceResolve'.
--
-- Wire layout, in order:
--
-- * two NUL-terminated strings (full name, then host target);
-- * the port, read with the 'Storable' instance of @S.PortNumber@;
-- * the TXT record: a big-endian 16-bit length followed by that many bytes.
data ResolveResponse = ResolveResponse
  { resolve_fullname :: !NullFreeByteString
  , resolve_hosttarget :: !NullFreeByteString
  , resolve_port :: !S.PortNumber
  , resolve_txt :: !ByteString
    -- ^ raw TXT record bytes; may contain NULs
  }
-- | A request to the daemon, indexed by the type of reply body it produces.
--
-- The arguments are listed in the order 'pokeBody' writes them.
data Request a where
  -- | IPC op 5.
  --
  -- Arguments: flags, interface index, name, type, domain, host, port, and
  -- TXT record bytes (sent with a big-endian 16-bit length prefix).
  ServiceRegister :: !DNSServiceFlags
                  -> !InterfaceIndex
                  -> !NullFreeByteString
                  -> !NullFreeByteString
                  -> !NullFreeByteString
                  -> !NullFreeByteString
                  -> !S.PortNumber
                  -> !ByteString
                  -> Request NTDResponse
  -- | IPC op 6.
  --
  -- Arguments: interface index, type, domain. The flags word is always
  -- sent as 0.
  ServiceBrowse :: !InterfaceIndex
                -> !NullFreeByteString
                -> !NullFreeByteString
                -> Request NTDResponse
  -- | IPC op 7.
  --
  -- Arguments: flags, interface index, name, type, domain.
  ServiceResolve :: !DNSServiceFlags
                 -> !InterfaceIndex
                 -> !NullFreeByteString
                 -> !NullFreeByteString
                 -> !NullFreeByteString
                 -> Request ResolveResponse
-- | An open connection to the daemon.
--
-- Created by 'connect' and torn down by 'disconnect'.
data Connection = Connection
  { sock :: !S.Socket
    -- ^ stream socket connected to the daemon
  , counter :: !(IORef Word64)
    -- ^ last request context id handed out; 'request' increments it
  , requestQueue :: !(Chan AnyRequestRegistration)
    -- ^ requests waiting to be written by the send thread
  , responseMap :: !(CM.Map Word64 AnyAsyncResponseHandler)
    -- ^ reply handlers keyed by request context id
  , recvThreadId :: !ThreadId
    -- ^ thread that reads replies from 'sock'
  , sendThreadId :: !ThreadId
    -- ^ thread that writes queued requests to 'sock'
  }
-- | Fatal errors detected by the connection's worker threads.
--
-- After one is reported, both workers only drain pending work.
data AsyncConnectionError
  = AsyncConnectionIOError !IOError
    -- ^ an I/O exception while reading or writing the daemon socket
  | AsyncConnectionClosedError
    -- ^ the daemon socket was closed while a reply was being read
  | AsyncConnectionBadDaemonVersionError !Word32
    -- ^ a reply header carried an IPC version other than 1
  deriving (Show, Typeable)

instance Exception AsyncConnectionError

-- | Callback for 'AsyncConnectionError's.
--
-- The worker threads invoke it in a freshly forked thread.
type AsyncConnectionErrorHandler = AsyncConnectionError -> IO ()
-- | Connect to the daemon's Unix-domain socket at @addr@.
--
-- Sends a connection request (IPC op 1, empty body) and waits for the
-- daemon's 4-byte status.
--
-- * On 'kDNSServiceErr_NoError', starts the send and receive worker threads
--   and returns the 'Connection'. Fatal errors in those threads are later
--   reported through @e_handler@.
-- * On any other status, closes the socket and returns the status as
--   @Left@.
--
-- If an exception escapes, the socket and any thread already started are
-- cleaned up before it propagates.
connect :: S.SockAddr
        -> AsyncConnectionErrorHandler
        -> IO (Either DNSServiceErrorType Connection)
connect addr e_handler = bracketOnError makeSocket S.close $ \s -> do
  S.connect s addr
  allocaBytes ipcMsgHdrSz $ \hdr -> do
    pokeHdr (IpcMsgHdr 0 1 0) hdr
    sendAll s (castPtr hdr) ipcMsgHdrSz
  err <- recvError s
  case err of
    DNSServiceErrorType 0 -> do
      chan <- newChan
      handlers <- CM.empty
      sTidVar <- newEmptyMVar
      rTidVar <- newEmptyMVar
      counter' <- newIORef 0
      bracketOnError
        (createSendThread s chan sTidVar rTidVar)
        killThread $ \sTid ->
          bracketOnError
            (createRecvThread s handlers sTidVar rTidVar)
            killThread $ \rTid ->
              return . Right $
                Connection s counter' chan handlers rTid sTid
    _ -> do
      -- The daemon refused us. bracketOnError only closes the socket on an
      -- exception, so release it explicitly before reporting the status.
      S.close s
      return $ Left err
  where
    makeSocket = S.socket S.AF_UNIX S.Stream S.defaultProtocol
    -- Fork under mask_ so each worker starts with async exceptions masked
    -- and unmasks only around its own work.
    createSendThread s chan sTidVar rTidVar =
      mask_ $ forkIOWithUnmask $
        sendThread s chan e_handler sTidVar rTidVar
    createRecvThread s handlers sTidVar rTidVar =
      mask_ $ forkIOWithUnmask $
        recvThread s handlers e_handler sTidVar rTidVar
-- | The daemon's socket address.
--
-- Taken from the @DNSSD_UDS_PATH@ environment variable when it is set,
-- otherwise @/var/run/mDNSResponder@.
defaultAddr :: IO S.SockAddr
defaultAddr = do
  override <- lookupEnv "DNSSD_UDS_PATH"
  return . S.SockAddrUnix $ maybe "/var/run/mDNSResponder" id override
-- | Stop the receive thread, then the send thread, and close the socket.
disconnect :: Connection
           -> IO ()
disconnect conn = do
  mapM_ killThread [recvThreadId conn, sendThreadId conn]
  S.close (sock conn)
-- | Callback for the replies to a 'request'.
--
-- It receives either:
--
-- * @Left@ with the non-zero status from a reply header, or
--   'kDNSServiceErr_ShortResponse' when a reply is too short to decode; or
-- * @Right@ with the decoded reply.
--
-- Each reply is delivered on its own forked thread, so invocations may run
-- concurrently.
type AsyncResponseHandler a = Either DNSServiceErrorType (Response a)
                            -> IO ()
-- | Read a 4-byte, big-endian status code from @s@.
--
-- If the peer closes the socket before all four bytes arrive, the result
-- is 'kDNSServiceErr_ServiceNotRunning'.
recvError :: S.Socket -> IO DNSServiceErrorType
recvError s = alloca $ \statusPtr -> do
  outcome <- recvAll s (castPtr statusPtr) 4
  case outcome of
    RecvAllClosed -> return kDNSServiceErr_ServiceNotRunning
    RecvAllOK -> do
      raw <- peek statusPtr
      return . DNSServiceErrorType $ fromBigEndian (raw :: Int32)
-- | Submit a request on a connection and return the daemon's initial
-- status.
--
-- Steps:
--
-- 1. Allocate a fresh context id and register @handler@ under it.
-- 2. Queue the request for the send thread, together with one end of a
--    private socketpair.
-- 3. Read the status from the other end of the socketpair.
--
-- The handler is registered before the request is queued, so no reply can
-- be missed. If the status is not 'kDNSServiceErr_NoError', the handler is
-- unregistered again. (This assumes the daemon sends no replies for a
-- rejected request.)
request :: PeekableResponse a
        => Connection
        -> Request a
        -> AsyncResponseHandler a
        -> IO DNSServiceErrorType
request (Connection {..}) req handler =
  bracket makeSocks closeSocks $ \(us, them) -> do
    ctx <- atomicModifyIORef' counter (\x -> (x + 1, x + 1))
    CM.insert ctx (AnyAsyncResponseHandler handler) responseMap
    writeChan requestQueue (AnyRequestRegistration ctx req them)
    err <- recvError us
    -- Without this, every failed request would leave an entry in
    -- responseMap for the lifetime of the connection.
    when (err /= kDNSServiceErr_NoError) $
      CM.delete ctx responseMap
    return err
  where
    makeSocks = S.socketPair S.AF_UNIX S.Stream S.defaultProtocol
    closeSocks (us, them) = S.close us >> S.close them
-- | A reply handler with its body type hidden.
--
-- This lets handlers for different request types share one map. The
-- 'PeekableResponse' dictionary travels with it so the receive thread can
-- decode the body.
data AnyAsyncResponseHandler = forall a . PeekableResponse a =>
  AnyAsyncResponseHandler !(AsyncResponseHandler a)

-- | A queued request: context id, the request itself, and one end of the
-- request's socketpair.
--
-- The send thread passes that socket's descriptor to the daemon in a
-- control message. 'request' reads the status from the other end.
data AnyRequestRegistration =
  forall a . AnyRequestRegistration !Word64 !(Request a) !S.Socket
-- | The IPC message header fields that this module sets or reads.
--
-- 'pokeHdr' and 'peekHdr' fix the wire layout:
--
-- * offset 0: version (always 1)
-- * offset 4: datalen
-- * offset 8: a flags word (written as 0)
-- * offset 12: op
-- * offset 16: context (8 bytes)
-- * offset 24: a trailing 32-bit word (written as 0)
data IpcMsgHdr = IpcMsgHdr
  { datalen :: !Word32
    -- ^ number of bytes following the header
  , op :: !Word32
    -- ^ operation code (see 'operation')
  , context :: !Word64
    -- ^ request id used to find the reply handler; stored without
    -- byte-order conversion
  }

-- | Size in bytes of the encoded IPC header.
ipcMsgHdrSz :: Int
ipcMsgHdrSz = (28)
-- | Encode an IPC header into the 'ipcMsgHdrSz' bytes at @hdr@.
--
-- Version, datalen and op are written big-endian. The context is copied
-- in host order. The flags word and the trailing word are zeroed.
pokeHdr :: IpcMsgHdr
        -> (Ptr IpcMsgHdr)
        -> IO ()
pokeHdr (IpcMsgHdr len opcode ctx) hdr = do
  store 0 (toBigEndian (1 :: Word32))  -- protocol version
  store 4 (toBigEndian len)
  store 8 (0 :: Word32)
  store 12 (toBigEndian opcode)
  store 16 ctx
  store 24 (0 :: Word32)
  where
    store :: Storable b => Int -> b -> IO ()
    store = pokeByteOff hdr
-- | Decode an IPC header received from the daemon.
--
-- Throws 'AsyncConnectionBadDaemonVersionError' if the version word is
-- not 1.
peekHdr :: Ptr IpcMsgHdr -> IO IpcMsgHdr
peekHdr hdr = do
  ver <- fromBigEndian <$> (peekByteOff hdr 0 :: IO Word32)
  if ver /= 1
    then throwIO $ AsyncConnectionBadDaemonVersionError ver
    else IpcMsgHdr
           <$> (fromBigEndian <$> peekByteOff hdr 4)
           <*> (fromBigEndian <$> peekByteOff hdr 12)
           <*> peekByteOff hdr 16
-- | Number of bytes 'pokeBody' writes for a request.
--
-- Excludes the IPC header and the single NUL byte that precedes the body.
size :: Request a -> Int
size req = flagsAndIfi + variablePart req
  where
    -- 4-byte flags word followed by a 4-byte interface index
    flagsAndIfi :: Int
    flagsAndIfi = 4 + 4

    -- a string plus its NUL terminator
    cstr :: ByteString -> Int
    cstr bs = BS.length bs + 1

    variablePart :: Request b -> Int
    variablePart (ServiceRegister _ _ name ty domain host _ txt) =
      -- strings, then 2-byte port, 2-byte TXT length, and the TXT bytes
      sum (fmap cstr [name, ty, domain, host]) + 2 + 2 + BS.length txt
    variablePart (ServiceBrowse _ ty domain) =
      cstr ty + cstr domain
    variablePart (ServiceResolve _ _ name regtype domain) =
      sum (fmap cstr [name, regtype, domain])
-- | IPC operation code written into the header for a request.
operation :: Request a -> Word32
operation req = case req of
  ServiceRegister {} -> 5
  ServiceBrowse {} -> 6
  ServiceResolve {} -> 7
-- | A field writer: an action that writes at a given address, paired with
-- the number of bytes the field occupies.
type Poke =
  (Ptr Word8 -> IO (), Int)

-- | Run the writers in order.
--
-- After each one, the destination is advanced by that field's width.
runPokes :: Ptr Word8 -> [ Poke ] -> IO ()
runPokes start = foldM_ step start
  where
    step :: Ptr Word8 -> Poke -> IO (Ptr Word8)
    step dst (write, width) = write dst >> return (plusPtr dst width)
-- | Copy the first @sz@ bytes of @bs@ to @ptr@, then write a NUL at
-- @ptr + sz@.
--
-- The caller must ensure @sz@ does not exceed the length of @bs@.
pokeBSNull :: Int -> ByteString -> Ptr Word8 -> IO ()
pokeBSNull sz bs ptr =
  pokeBS sz bs ptr >> pokeByteOff ptr sz (0 :: Word8)

-- | Copy the first @sz@ bytes of @bs@ to @ptr@, with no terminator.
--
-- The caller must ensure @sz@ does not exceed the length of @bs@.
pokeBS :: Int -> ByteString -> Ptr Word8 -> IO ()
pokeBS sz bs ptr =
  unsafeUseAsCString bs (\src -> copyBytes ptr (castPtr src) sz)
-- | Serialise a request body to @ptr@.
--
-- The buffer must hold at least @'size' req@ bytes. Multi-byte integers are
-- written big-endian, except the port, which uses the 'Storable' instance
-- of @S.PortNumber@. Strings are NUL-terminated.
pokeBody :: Request a -> Ptr (Request a) -> IO ()
pokeBody req ptr = runPokes (castPtr ptr) (fieldsOf req)
  where
    fieldsOf :: Request b -> [Poke]
    fieldsOf (ServiceRegister (DNSServiceFlags flags) (InterfaceIndex ifi)
                              name ty domain host port txt) =
      [ be32 flags
      , be32 ifi
      , cstring name
      , cstring ty
      , cstring domain
      , cstring host
      , (\dst -> poke (castPtr dst) port, 2)
      , be16 (fromIntegral txtLen)
      , (pokeBS txtLen txt, txtLen)
      ]
      where
        txtLen = BS.length txt
    fieldsOf (ServiceBrowse (InterfaceIndex ifi) ty domain) =
      -- browse requests carry no flags; a zero word is sent in their place
      [ (\dst -> poke (castPtr dst) (0 :: Word32), 4)
      , be32 ifi
      , cstring ty
      , cstring domain
      ]
    fieldsOf (ServiceResolve (DNSServiceFlags flags) (InterfaceIndex ifi)
                             name ty domain) =
      [ be32 flags
      , be32 ifi
      , cstring name
      , cstring ty
      , cstring domain
      ]

    be32 :: Word32 -> Poke
    be32 w = (\dst -> poke (castPtr dst) (toBigEndian w), 4)

    be16 :: Word16 -> Poke
    be16 w = (\dst -> poke (castPtr dst) (toBigEndian w), 2)

    cstring :: ByteString -> Poke
    cstring bs = (pokeBSNull n bs, n + 1)
      where
        n = BS.length bs
-- | Outcome of 'recvAll'.
data RecvAllResult
  = RecvAllOK
    -- ^ every requested byte was received
  | RecvAllClosed
    -- ^ the peer closed the stream before that

-- | Receive exactly @n@ bytes from @s@ into the buffer at @ptr@.
--
-- A request for zero (or fewer) bytes succeeds immediately; 'S.recvBuf'
-- rejects non-positive lengths. End of stream is reported as
-- 'RecvAllClosed' in both forms it can take:
--
-- * a zero-byte read (network 3.x), or
-- * an EOF 'IOError' (network 2.x).
--
-- Other I/O errors propagate.
recvAll :: S.Socket -> Ptr Word8 -> Int -> IO RecvAllResult
recvAll s ptr0 n0 = loop ptr0 n0 `catch` onEOF
  where
    loop ptr i
      | i <= 0 = return RecvAllOK
      | otherwise = do
          cnt <- S.recvBuf s ptr i
          if cnt == 0
            then return RecvAllClosed
            else loop (plusPtr ptr cnt) (i - cnt)

    onEOF :: IOError -> IO RecvAllResult
    onEOF e
      | isEOFError e = return RecvAllClosed
      | otherwise = throwIO e
-- | Send all @n@ bytes at @ptr@ on @s@.
--
-- Loops over partial writes until everything has been sent. A zero-length
-- buffer makes no system call.
sendAll :: S.Socket -> Ptr Word8 -> Int -> IO ()
sendAll s = loop
  where
    loop ptr' i'
      | i' <= 0 = return ()
      | otherwise = do
          cnt <- S.sendBuf s ptr' i'
          loop (plusPtr ptr' cnt) (i' - cnt)
-- | Internal signal between the two worker threads.
--
-- When one worker fails, it throws this to the other with 'throwTo'. The
-- receiving worker catches it and switches to draining.
data SockEx = SockEx deriving (Show, Typeable)

instance Exception SockEx
-- | Worker that writes queued requests to the daemon socket.
--
-- For each 'AnyRequestRegistration' it sends three things:
--
-- 1. the IPC header;
-- 2. a single NUL byte;
-- 3. the request body. The final byte goes out via 'sendMsg' together
--    with a control message carrying the descriptor of the request's
--    socketpair end.
--
-- On failure, it does the following in order:
--
-- 1. sends 'SockEx' to the receive thread;
-- 2. reports the 'IOError' to @e_handler@ on a forked thread;
-- 3. drains the queue forever, closing each request's socket.
--
-- If the receive thread fails first, it throws 'SockEx' here, and this
-- thread goes straight to draining.
--
-- Each request's socket is closed whether or not sending succeeds.
-- Otherwise the caller waiting in 'request' would never see the other
-- end close and would block forever.
sendThread :: S.Socket
           -> Chan AnyRequestRegistration
           -> AsyncConnectionErrorHandler
           -> MVar ThreadId
           -> MVar ThreadId
           -> (forall a. IO a -> IO a)
           -> IO ()
sendThread sock chan e_handler sTidVar rTidVar unmask = do
  _ <- (try :: IO () -> IO (Either SockEx ())) . unmask $ do
    myThreadId >>= putMVar sTidVar
    -- loop never returns normally, so only the Left case is reachable
    Left e <- (try loop) :: IO (Either IOError ())
    takeMVar rTidVar >>= flip throwTo SockEx
    _ <- forkIO . e_handler $ AsyncConnectionIOError e
    drain
  unmask drain
  where
    loop = do
      -- bracket closes the request's socket on success, on an I/O error,
      -- and when SockEx interrupts a send.
      bracket (readChan chan) release $
        \(AnyRequestRegistration ctx req them) -> sendRequest ctx req them
      loop

    release (AnyRequestRegistration _ _ them) = S.close them

    sendRequest :: Word64 -> Request b -> S.Socket -> IO ()
    sendRequest ctx req them =
      -- one extra byte for the NUL that precedes the body
      allocaBytes (full_sz + 1) $ \reqptr -> do
        pokeHdr (IpcMsgHdr (fromIntegral $ sz + 1) (operation req) ctx) reqptr
        poke (plusPtr reqptr ipcMsgHdrSz) (0 :: CChar)
        pokeBody req (plusPtr reqptr (ipcMsgHdrSz + 1))
        sendAll sock (castPtr reqptr) full_sz
        -- The last byte travels with the descriptor of `them` attached.
        alloca $ \cmsgptr -> do
          poke cmsgptr (S.fdSocket them)
          cmsg <- unsafePackCStringLen (castPtr cmsgptr, 4)
          body <- unsafePackCStringLen (plusPtr reqptr full_sz, 1)
          sendMsg sock body (S.SockAddrUnix "")
            [ CMsg 1 1 cmsg ]
      where
        sz = size req
        full_sz = ipcMsgHdrSz + sz

    drain = do
      (AnyRequestRegistration _ _ them) <- readChan chan
      S.close them
      drain
-- | Reply bodies that can be decoded from raw bytes sent by the daemon.
class PeekableResponse a where
  -- | @peekResponseBody buf sz@ decodes the @sz@ bytes at @buf@.
  --
  -- The receive thread passes the bytes that follow the reply header. The
  -- instances in this module return 'Nothing' when the bytes run out or
  -- an expected NUL terminator is missing.
  peekResponseBody :: Ptr a -> Int -> IO (Maybe a)
-- | Decodes three consecutive NUL-terminated strings: name, type, domain.
--
-- Each string must end inside the @sz@ bytes, otherwise the result is
-- 'Nothing'.
instance PeekableResponse NTDResponse where
  peekResponseBody buf sz = runMaybeT $ do
    name_null <- findNull (castPtr buf) sz
    name <- lift $ packCStringLen (castPtr buf, name_null)
    -- skip the string and its terminator; the remaining length shrinks
    -- by the same amount
    let buf' = plusPtr buf (name_null + 1)
        sz' = sz - (name_null + 1)
    regtype_null <- findNull buf' sz'
    regtype <- lift $ packCStringLen (buf', regtype_null)
    let buf'' = plusPtr buf' (regtype_null + 1)
        sz'' = sz' - (regtype_null + 1)
    domain_null <- findNull buf'' sz''
    domain <- lift $ packCStringLen (buf'', domain_null)
    return $ NTDResponse name regtype domain
-- | Decodes a resolve reply. Fields, in order:
--
-- * full name (NUL-terminated);
-- * host target (NUL-terminated);
-- * port (via the 'Storable' instance of @S.PortNumber@);
-- * a big-endian 16-bit TXT length, then that many TXT bytes.
--
-- Returns 'Nothing' if any part would run past the @sz@ bytes available.
instance PeekableResponse ResolveResponse where
  peekResponseBody buf sz = runMaybeT $ do
    name_null <- findNull (castPtr buf) sz
    name <- lift $ packCStringLen (castPtr buf, name_null)
    let buf' = plusPtr buf (name_null + 1)
        sz' = sz - (name_null + 1)
    target_null <- findNull buf' sz'
    target <- lift $ packCStringLen (buf', target_null)
    let buf'' = plusPtr buf' (target_null + 1)
        sz'' = sz' - (target_null + 1)
        port_sz = sizeOf (undefined :: S.PortNumber)
    when (sz'' < port_sz) mzero
    port <- lift $ peek buf''
    let buf''' = plusPtr buf'' port_sz
        sz''' = sz'' - port_sz
        len_sz = 2
    when (sz''' < len_sz) mzero
    len <- (lift $ fromBigEndian <$> peek buf''') :: MaybeT IO Word16
    let buf'''' = plusPtr buf''' len_sz
        sz'''' = sz''' - len_sz
        len' = fromIntegral len
    -- the declared TXT length must fit in what is left of the body
    when (sz'''' < len') mzero
    txt <- lift $ packCStringLen (buf'''', len')
    return $ ResolveResponse name target port txt
-- | Offset of the first NUL byte among the @n@ bytes at @ptr@.
--
-- Fails ('mzero') if there is no NUL within those bytes. A zero or
-- negative @n@ fails immediately, so no memory outside the given range is
-- ever read.
findNull :: CString -> Int -> MaybeT IO Int
findNull = go 0
  where
    go :: Int -> CString -> Int -> MaybeT IO Int
    go acc ptr n
      | n <= 0 = mzero
      | otherwise = do
          c <- lift $ peek ptr :: MaybeT IO CChar
          if c == 0
            then return acc
            else go (acc + 1) (plusPtr ptr 1) (n - 1)
-- | Worker that reads replies from the daemon socket.
--
-- For each reply it:
--
-- 1. reads the IPC header and then @datalen@ payload bytes;
-- 2. looks up the handler registered under the header's context;
-- 3. if there is one, decodes the reply header and body and runs the
--    handler on a forked thread.
--
-- Bytes for replies with no registered handler are read and discarded.
--
-- On failure, it does the following in order:
--
-- 1. sends 'SockEx' to the send thread;
-- 2. reports the error to @e_handler@ on a forked thread;
-- 3. drains forever, periodically deleting every registered handler.
--
-- If the send thread fails first, it throws 'SockEx' here, and this
-- thread goes straight to draining.
recvThread :: S.Socket
           -> CM.Map Word64 AnyAsyncResponseHandler
           -> AsyncConnectionErrorHandler
           -> MVar ThreadId
           -> MVar ThreadId
           -> (forall a. IO a -> IO a)
           -> IO ()
recvThread sock handlers e_handler sTidVar rTidVar unmask = do
  _ <- (try :: IO () -> IO (Either SockEx ())) $ do
    myThreadId >>= putMVar rTidVar
    -- Inner try: connection errors raised by loop.
    -- Outer try: I/O errors.
    err <- try $ do
      Left ex <- try $ unmask loop
      return ex
    readMVar sTidVar >>= flip throwTo SockEx
    unmask $ do
      _ <- forkIO . e_handler $ case err of
        Left e -> AsyncConnectionIOError e
        Right e -> e
      drain
  unmask drain
  where
    loop = do
      hdr <- allocaBytes ipcMsgHdrSz $ \buf -> do
        res <- recvAll sock (castPtr buf) ipcMsgHdrSz
        case res of
          RecvAllClosed -> throwIO AsyncConnectionClosedError
          RecvAllOK -> peekHdr buf
      m_handler <- CM.lookup (context hdr) handlers
      let len = fromIntegral $ datalen hdr
      allocaBytes len $ \buf -> do
        -- always consume the payload, even when nobody is listening,
        -- so the next header is read from the right position
        res <- recvAll sock buf len
        case res of
          RecvAllClosed -> throwIO AsyncConnectionClosedError
          RecvAllOK -> return ()
        case m_handler of
          Nothing -> return ()
          Just (AnyAsyncResponseHandler handler) -> do
            e_r_hdr <- peekResponseHeader buf len
            case e_r_hdr of
              Left err -> void . forkIO . handler $ Left err
              Right r_hdr -> do
                m_r_body <- peekResponseBody
                  (plusPtr buf responseHdrSz)
                  (len - responseHdrSz)
                let response = case m_r_body of
                      Just r_body -> Right $ Response r_hdr r_body
                      Nothing -> Left $ kDNSServiceErr_ShortResponse
                void . forkIO $ handler response
      loop

    drain = do
      entries <- CM.unsafeToList handlers
      mapM_ (\(ctx, _) -> CM.delete ctx handlers) entries
      threadDelay 5000
      drain

    -- reply header: flags, interface index, status (4 bytes each)
    responseHdrSz =
      (4) +
      4 +
      (4)

    peekResponseHeader buf len = if len < responseHdrSz
      then return $ Left kDNSServiceErr_ShortResponse
      else do
        flags <- DNSServiceFlags . fromBigEndian <$> peek (castPtr buf)
        ifi <- InterfaceIndex . fromBigEndian <$>
          peek (plusPtr buf (4))
        err <- DNSServiceErrorType . fromBigEndian <$>
          peek (plusPtr buf ((4) + 4))
        return $ case err of
          DNSServiceErrorType 0 ->
            Right (ResponseHeader flags ifi)
          _ -> Left err