{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE TypeFamilies #-}
module DBus.Transport
(
Transport(..)
, TransportOpen(..)
, TransportListen(..)
, TransportError
, transportError
, transportErrorMessage
, transportErrorAddress
, SocketTransport
, socketTransportOptionBacklog
, socketTransportCredentials
) where
import Control.Exception
import qualified Data.ByteString
import Data.ByteString (ByteString)
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as Lazy
import qualified Data.Map as Map
import Data.Monoid
import Data.Typeable (Typeable)
import Foreign.C (CUInt)
import Network.Socket hiding (recv)
import Network.Socket.ByteString (sendAll, recv)
import qualified System.Info
import Prelude
import DBus
-- | An exception raised by transport operations in this module.
--
-- I/O failures on sockets are converted into this type, carrying the
-- textual rendering of the underlying 'IOException' and, when known, the
-- address that was being used.
data TransportError = TransportError
  { transportErrorMessage :: String
    -- ^ Human-readable description of what went wrong.
  , transportErrorAddress :: Maybe Address
    -- ^ The address involved in the failing operation, if any.
  }
  deriving (Eq, Show, Typeable)

instance Exception TransportError

-- | Build a 'TransportError' with the given message and no address.
-- Callers attach an address afterwards with record-update syntax.
transportError :: String -> TransportError
transportError message = TransportError
  { transportErrorMessage = message
  , transportErrorAddress = Nothing
  }
-- | Operations shared by every transport: sending raw bytes, receiving raw
-- bytes and closing the connection.
class Transport t where
  -- | Transport-specific configuration, passed to 'transportOpen' and
  -- 'transportListen'. Each instance chooses its own representation.
  data TransportOptions t :: *
  -- | The options to use when the caller has no particular preference.
  transportDefaultOptions :: TransportOptions t
  -- | Send the given bytes over the transport.
  transportPut :: t -> ByteString -> IO ()
  -- | Receive bytes, given the number of bytes requested. Whether a
  -- shorter result is possible is instance-specific (the socket instance
  -- returns what it has if the peer closes the connection early).
  transportGet :: t -> Int -> IO ByteString
  -- | Close the transport.
  transportClose :: t -> IO ()
-- | Transports that can actively connect to an 'Address'.
class Transport t => TransportOpen t where
  -- | Connect to the given address using the given options.
  transportOpen :: TransportOptions t -> Address -> IO t
-- | Transports that can listen on an 'Address' and accept incoming
-- connections.
class Transport t => TransportListen t where
  -- | A listening endpoint; each instance chooses its own representation.
  data TransportListener t :: *
  -- | Start listening on the given address using the given options.
  transportListen :: TransportOptions t -> Address -> IO (TransportListener t)
  -- | Accept the next incoming connection on the listener.
  transportAccept :: TransportListener t -> IO t
  -- | Close the listener.
  transportListenerClose :: TransportListener t -> IO ()
  -- | The address associated with the listener. For the socket instance
  -- this is the address rebuilt after binding, which carries a @guid@
  -- parameter (and, for TCP, the port actually bound).
  transportListenerAddress :: TransportListener t -> Address
  -- | The UUID associated with the listener.
  transportListenerUUID :: TransportListener t -> UUID
-- | A 'Transport' backed by a stream socket (Unix-domain or TCP).
--
-- The optional 'Address' is attached to any 'TransportError' raised by
-- operations on the socket. Connections opened via 'transportOpen' record
-- the address they were opened with; sockets returned by
-- 'transportAccept' carry 'Nothing'.
data SocketTransport = SocketTransport (Maybe Address) Socket
-- | Socket-based transport. Every operation converts 'IOException's into
-- 'TransportError's tagged with the socket's address (if known).
instance Transport SocketTransport where
  data TransportOptions SocketTransport = SocketTransportOptions
    { socketTransportOptionBacklog :: Int
      -- ^ Backlog passed to @listen@ when this transport is used to listen.
    }

  transportDefaultOptions = SocketTransportOptions
    { socketTransportOptionBacklog = 30
    }

  transportPut (SocketTransport addr sock) = catchIOException addr . sendAll sock

  transportGet (SocketTransport addr sock) = catchIOException addr . recvLoop sock

  transportClose (SocketTransport addr sock) = catchIOException addr (close sock)
-- | Receive @n@ bytes from the socket, issuing as many 'recv' calls as
-- needed (each asking for at most 4096 bytes).
--
-- If the peer closes the connection first ('recv' returns an empty
-- chunk), the bytes received so far are returned, so the result may be
-- shorter than @n@. A request for zero (or fewer) bytes returns an empty
-- string without touching the socket.
recvLoop :: Socket -> Int -> IO ByteString
recvLoop s n = Lazy.toStrict `fmap` loop mempty n where
  chunkSize = 4096
  loop acc remaining
    | remaining <= 0 = return (Builder.toLazyByteString acc)
    | otherwise = do
        chunk <- recv s (min remaining chunkSize)
        -- EOF must be detected on every read, not only the final one:
        -- otherwise a peer closing mid-way through a large read makes
        -- 'remaining' stop shrinking and the loop spins forever.
        if Data.ByteString.null chunk
          then return (Builder.toLazyByteString acc)
          else loop (mappend acc (Builder.byteString chunk))
                    (remaining - Data.ByteString.length chunk)
-- | Connect a socket transport, choosing Unix-domain or TCP according to
-- the address method. Any other method raises a 'TransportError'.
instance TransportOpen SocketTransport where
  transportOpen _ a = openBy (addressMethod a)
    where
      openBy "unix" = openUnix a
      openBy "tcp" = openTcp a
      openBy method = throwIO (transportError ("Unknown address method: " ++ show method))
        { transportErrorAddress = Just a
        }
-- | Listen on a Unix-domain or TCP socket. Each listener gets a fresh
-- random UUID, which is also embedded (as @guid@) in its address.
instance TransportListen SocketTransport where
  data TransportListener SocketTransport = SocketTransportListener Address UUID Socket

  transportListen opts a = do
    uuid <- randomUUID
    (bound, sock) <- listenBy uuid (addressMethod a)
    return (SocketTransportListener bound uuid sock)
    where
      listenBy uuid "unix" = listenUnix uuid a opts
      listenBy uuid "tcp" = listenTcp uuid a opts
      listenBy _ method = throwIO (transportError ("Unknown address method: " ++ show method))
        { transportErrorAddress = Just a
        }

  -- Accepted connections carry no address of their own.
  transportAccept (SocketTransportListener a _ s) = catchIOException (Just a)
    (accept s >>= \(conn, _) -> return (SocketTransport Nothing conn))

  transportListenerClose (SocketTransportListener a _ s) =
    catchIOException (Just a) (close s)

  transportListenerAddress (SocketTransportListener addr _ _) = addr

  transportListenerUUID (SocketTransportListener _ guid _) = guid
-- | Query the peer's credentials via 'getPeerCredential'; the triple is
-- returned unchanged. 'IOException's become 'TransportError's tagged with
-- the transport's address.
socketTransportCredentials :: SocketTransport -> IO (Maybe CUInt, Maybe CUInt, Maybe CUInt)
socketTransportCredentials transport = case transport of
  SocketTransport addr sock -> catchIOException addr (getPeerCredential sock)
-- | Connect to a @unix:@ address.
--
-- Exactly one of the @path@ (filesystem socket) or @abstract@ (name
-- prefixed with a NUL byte) parameters must be present; otherwise a
-- 'TransportError' is thrown. The socket is closed if @connect@ fails.
openUnix :: Address -> IO SocketTransport
openUnix transportAddr = either failWith connectTo target
  where
    param key = Map.lookup key (addressParameters transportAddr)

    target = case (param "path", param "abstract") of
      (Just file, Nothing) -> Right file
      (Nothing, Just name) -> Right ('\x00' : name)
      (Nothing, Nothing) -> Left
        "One of 'path' or 'abstract' must be specified for the\
        \ 'unix' transport."
      (Just _, Just _) -> Left
        "Only one of 'path' or 'abstract' may be specified for the\
        \ 'unix' transport."

    failWith :: String -> IO a
    failWith msg = throwIO (transportError msg)
      { transportErrorAddress = Just transportAddr
      }

    connectTo p = catchIOException (Just transportAddr) $ bracketOnError
      (socket AF_UNIX Stream defaultProtocol)
      close
      (\sock -> do
        connect sock (SockAddrUnix p)
        return (SocketTransport (Just transportAddr) sock))
-- | Pick the host name to resolve for a TCP address: an explicit @host@
-- wins; otherwise the loopback address of the requested family, or
-- @"localhost"@ when the family is unspecified or invalid.
tcpHostname :: Maybe String -> Either a Network.Socket.Family -> String
tcpHostname explicit family_ = maybe loopback id explicit where
  loopback = case family_ of
    Right AF_INET -> "127.0.0.1"
    Right AF_INET6 -> "::1"
    _ -> "localhost"
-- | Connect to a @tcp:@ address.
--
-- Requires a valid @port@; accepts an optional @family@ (@ipv4@/@ipv6@)
-- and @host@ (see 'tcpHostname' for the default). The port is validated
-- before the family. Each resolved address is tried in order; the error
-- from the last failed attempt is reported as a 'TransportError'.
openTcp :: Address -> IO SocketTransport
openTcp transportAddr = either failWith connectTo ((,) <$> getPort <*> getFamily)
  where
    param key = Map.lookup key (addressParameters transportAddr)

    failWith :: String -> IO a
    failWith msg = throwIO (transportError msg)
      { transportErrorAddress = Just transportAddr
      }

    getFamily = case param "family" of
      Nothing -> Right AF_UNSPEC
      Just "ipv4" -> Right AF_INET
      Just "ipv6" -> Right AF_INET6
      Just other -> Left ("Unknown socket family for TCP transport: " ++ show other)

    getPort = case param "port" of
      Nothing -> Left "TCP transport requires the `port' parameter."
      Just raw -> maybe
        (Left ("Invalid socket port for TCP transport: " ++ show raw))
        Right
        (readPortNumber raw)

    hostname = tcpHostname (param "host") getFamily

    resolve family_ = getAddrInfo (Just hints) (Just hostname) Nothing where
      hints = defaultHints
        { addrFlags = [AI_ADDRCONFIG]
        , addrFamily = family_
        , addrSocketType = Stream
        }

    -- Open a socket matching the candidate's family and connect it,
    -- closing the socket if the connection fails.
    connectOne info = bracketOnError
      (socket (addrFamily info) (addrSocketType info) (addrProtocol info))
      close
      (\sock -> connect sock (addrAddress info) >> return sock)

    connectFirst [] = failWith "openTcp: no addresses"
    connectFirst (info:rest) = do
      attempt <- try (connectOne info)
      case (attempt, rest) of
        (Right sock, _) -> return sock
        (Left err, []) -> failWith (show (err :: IOException))
        (Left _, _) -> connectFirst rest

    connectTo (port, family_) = catchIOException (Just transportAddr) $ do
      infos <- resolve family_
      sock <- connectFirst (map (setPort port) infos)
      return (SocketTransport (Just transportAddr) sock)
-- | Bind and listen on a Unix-domain socket described by a @unix:@ address.
--
-- Exactly one of @abstract@, @path@ or @tmpdir@ must be given. With
-- @tmpdir@, a name @haskell-dbus-<uuid>@ is generated under that
-- directory; on Linux it is used in the abstract namespace, elsewhere as
-- a filesystem path. Returns the address to advertise (including the
-- @guid@ parameter) together with the listening socket.
listenUnix :: UUID -> Address -> TransportOptions SocketTransport -> IO (Address, Socket)
listenUnix uuid origAddr opts = either failWith bindAndListen target
  where
    param key = Map.lookup key (addressParameters origAddr)
    guid = ("guid", formatUUID uuid)

    target = case (param "abstract", param "path", param "tmpdir") of
      (Just name, Nothing, Nothing) ->
        Right (address_ "unix" [("abstract", name), guid], '\x00' : name)
      (Nothing, Just file, Nothing) ->
        Right (address_ "unix" [("path", file), guid], file)
      (Nothing, Nothing, Just dir) -> Right (tmpdirTarget dir)
      (Nothing, Nothing, Nothing) -> Left
        "One of 'abstract', 'path', or 'tmpdir' must be specified\
        \ for the 'unix' transport."
      _ -> Left
        "Only one of 'abstract', 'path', or 'tmpdir' may be\
        \ specified for the 'unix' transport."

    tmpdirTarget dir
      | System.Info.os == "linux" = (mkAddr "abstract", '\x00' : fileName)
      | otherwise = (mkAddr "path", fileName)
      where
        fileName = dir ++ "/haskell-dbus-" ++ formatUUID uuid
        mkAddr key = address_ "unix" [(key, fileName), guid]

    failWith :: String -> IO a
    failWith msg = throwIO (transportError msg)
      { transportErrorAddress = Just origAddr
      }

    bindAndListen (addr, p) = catchIOException (Just origAddr) $ bracketOnError
      (socket AF_UNIX Stream defaultProtocol)
      close
      (\sock -> do
        bind sock (SockAddrUnix p)
        Network.Socket.listen sock (socketTransportOptionBacklog opts)
        return (addr, sock))
-- | Bind and listen on a TCP socket described by a @tcp:@ address.
--
-- Parameters:
--
--   * @port@ — optional; when absent or @0@ the OS chooses a free port.
--   * @family@ — @ipv4@ or @ipv6@; unspecified means either.
--   * @bind@ — host to bind; @*@ means all interfaces. When absent, the
--     host from 'tcpHostname' is used.
--   * @host@ — copied into the advertised address.
--
-- Each resolved address is tried in turn with a socket of that address's
-- own family. On success, returns the advertised address (actual port,
-- @guid@, and any @host@/@family@ parameters) with the listening socket.
-- If every candidate fails, the last error is thrown as a 'TransportError'.
listenTcp :: UUID -> Address -> TransportOptions SocketTransport -> IO (Address, Socket)
listenTcp uuid origAddr opts = go where
  params = addressParameters origAddr
  param key = Map.lookup key params

  failWith :: String -> IO a
  failWith msg = throwIO (transportError msg)
    { transportErrorAddress = Just origAddr
    }

  unknownFamily x = "Unknown socket family for TCP transport: " ++ show x
  getFamily = case param "family" of
    Just "ipv4" -> Right AF_INET
    Just "ipv6" -> Right AF_INET6
    Nothing -> Right AF_UNSPEC
    Just x -> Left (unknownFamily x)

  badPort x = "Invalid socket port for TCP transport: " ++ show x
  getPort = case param "port" of
    Nothing -> Right 0
    -- The D-Bus spec lets a server ask for port 0 so the OS picks one;
    -- readPortNumber rejects 0 because it is meaningless for connecting.
    Just "0" -> Right 0
    Just x -> case readPortNumber x of
      Just port -> Right port
      Nothing -> Left (badPort x)

  paramBind = case param "bind" of
    Just "*" -> Nothing
    Just x -> Just x
    Nothing -> Just (tcpHostname (param "host") getFamily)

  getAddresses family_ = getAddrInfo (Just (defaultHints
    { addrFlags = [AI_ADDRCONFIG, AI_PASSIVE]
    , addrFamily = family_
    , addrSocketType = Stream
    })) paramBind Nothing

  -- A socket must be created with the family of the address it binds.
  -- Using the hinted family fails outright when it is AF_UNSPEC (no
  -- 'family' parameter), so each candidate gets its own socket.
  listenOn addr = bracketOnError
    (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
    close
    (\sock -> do
      setSocketOption sock ReuseAddr 1
      bind sock (addrAddress addr)
      Network.Socket.listen sock (socketTransportOptionBacklog opts)
      return sock)

  listenAddrs [] = failWith "listenTcp: no addresses"
  listenAddrs (addr:addrs) = do
    tried <- Control.Exception.try (listenOn addr)
    case tried of
      Right sock -> return sock
      Left err -> case addrs of
        [] -> failWith (show (err :: IOException))
        _ -> listenAddrs addrs

  sockAddr port = address_ "tcp" p where
    p = baseParams ++ hostParam ++ familyParam
    baseParams =
      [ ("port", show port)
      , ("guid", formatUUID uuid)
      ]
    hostParam = case param "host" of
      Just x -> [("host", x)]
      Nothing -> []
    familyParam = case param "family" of
      Just x -> [("family", x)]
      Nothing -> []

  go = case getPort of
    Left err -> failWith err
    Right port -> case getFamily of
      Left err -> failWith err
      Right family_ -> catchIOException (Just origAddr) $ do
        sockAddrs <- getAddresses family_
        -- Close the listening socket if reading back its port fails.
        bracketOnError
          (listenAddrs (map (setPort port) sockAddrs))
          close
          (\sock -> do
            sockPort <- socketPort sock
            return (sockAddr sockPort, sock))
-- | Run an action, rethrowing any 'IOException' it raises as a
-- 'TransportError' whose message is the exception's 'show' text and
-- whose address is the given one. Other exception types pass through.
catchIOException :: Maybe Address -> IO a -> IO a
catchIOException addr io = try io >>= either rethrow return where
  rethrow :: IOException -> IO b
  rethrow err = throwIO (transportError (show err))
    { transportErrorAddress = addr
    }
-- | Build an 'Address' from a method name and a list of parameters.
--
-- If 'address' rejects the input, this calls 'error' with a message
-- naming the rejected method and parameters. The previous code failed
-- with an anonymous irrefutable-pattern error instead.
address_ :: String -> [(String, String)] -> Address
address_ method params = case address method (Map.fromList params) of
  Just addr -> addr
  Nothing -> error ("DBus.Transport.address_: invalid address "
    ++ show (method, params))
-- | Replace the port of an 'AddrInfo' holding an IPv4 or IPv6 address;
-- any other kind of socket address is left as it is.
setPort :: PortNumber -> AddrInfo -> AddrInfo
setPort port info = info { addrAddress = withPort (addrAddress info) } where
  withPort (SockAddrInet _ host) = SockAddrInet port host
  withPort (SockAddrInet6 _ flow host scope) = SockAddrInet6 port flow host scope
  withPort other = other
-- | Parse a decimal TCP port number.
--
-- Returns 'Nothing' unless the string is a non-empty run of ASCII digits
-- whose value lies in @1..65535@. Port 0 is rejected here; callers that
-- want to accept it must check for it themselves.
readPortNumber :: String -> Maybe PortNumber
readPortNumber s
  -- The emptiness check matters: "" vacuously passes the digit test and
  -- would otherwise reach 'read', which throws on empty input.
  | null s || not (all isAsciiDigit s) = Nothing
  | value > 0 && value <= 65535 = Just (fromInteger value)
  | otherwise = Nothing
  where
    isAsciiDigit c = c >= '0' && c <= '9'
    -- Total at this point: s is a non-empty string of decimal digits.
    value = read s :: Integer