{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeFamilies #-}
module DBus.Socket
(
Socket
, send
, receive
, SocketError
, socketError
, socketErrorMessage
, socketErrorFatal
, socketErrorAddress
, SocketOptions
, socketAuthenticator
, socketTransportOptions
, defaultSocketOptions
, open
, openWith
, close
, SocketListener
, listen
, listenWith
, accept
, closeListener
, socketListenerAddress
, Authenticator
, authenticator
, authenticatorClient
, authenticatorServer
) where
import Prelude hiding (getLine)
import Control.Concurrent
import Control.Exception
import Control.Monad (mplus)
import qualified Data.ByteString
import qualified Data.ByteString.Char8 as Char8
import Data.Char (ord)
import Data.Char (toUpper)
import Data.IORef
import Data.List (isPrefixOf)
import Data.Typeable (Typeable)
import qualified System.Posix.User
import Text.Printf (printf)
import DBus
import DBus.Transport
import DBus.Internal.Wire (unmarshalMessageM)
-- | The exception thrown by socket operations in this module ('openWith',
-- 'accept', 'send', 'receive', ...). Transport and IO errors raised inside
-- those operations are converted into this type by 'toSocketError'.
data SocketError = SocketError
  { socketErrorMessage :: String
    -- ^ Human-readable description of the failure.
  , socketErrorFatal :: Bool
    -- ^ 'True' when built by 'socketError'; 'send' sets it to 'False' when a
    -- message fails to marshal. Presumably signals whether the socket is
    -- still usable (NOTE(review): confirm intended semantics).
  , socketErrorAddress :: Maybe Address
    -- ^ Address of the socket involved, when one is known.
  }
  deriving (Eq, Show, Typeable)

instance Exception SocketError
-- | Build a fatal 'SocketError' carrying the given message and no address.
socketError :: String -> SocketError
socketError msg = SocketError
  { socketErrorMessage = msg
  , socketErrorFatal = True
  , socketErrorAddress = Nothing
  }
-- | Existential wrapper that lets a 'Socket' hold any 'Transport' instance.
-- Every operation is forwarded to the wrapped transport.
data SomeTransport = forall t. (Transport t) => SomeTransport t

instance Transport SomeTransport where
  data TransportOptions SomeTransport = SomeTransportOptions
  transportDefaultOptions = SomeTransportOptions
  transportPut (SomeTransport inner) bytes = transportPut inner bytes
  transportGet (SomeTransport inner) count = transportGet inner count
  transportClose (SomeTransport inner) = transportClose inner
-- | An authenticated connection: the underlying transport plus the
-- per-connection state used by 'send' and 'receive'.
data Socket = Socket
  { socketTransport :: SomeTransport
    -- ^ The transport, hidden behind an existential wrapper.
  , socketAddress :: Maybe Address
    -- ^ 'Just' the address for sockets made by 'openWith'; 'Nothing' for
    -- sockets made by 'accept'.
  , socketSerial :: IORef Serial
    -- ^ Next serial to hand out; advanced by 'nextSocketSerial'.
  , socketReadLock :: MVar ()
    -- ^ Held by 'receive' while a message is unmarshalled from the transport.
  , socketWriteLock :: MVar ()
    -- ^ Held by 'send' while bytes are written to the transport.
  }
-- | Handshake actions run on a transport before it is wrapped in a 'Socket'.
-- Start from 'authenticator' and override the fields you need.
data Authenticator t = Authenticator
  {
    -- | Client side, run by 'openWith' on the newly opened transport.
    -- Returning 'False' makes 'openWith' throw "Authentication failed".
    authenticatorClient :: t -> IO Bool
    -- | Server side, run by 'accept' with the listener's UUID.
    -- Returning 'False' makes 'accept' throw "Authentication failed".
  , authenticatorServer :: t -> UUID -> IO Bool
  }
-- | Options for 'openWith' and 'listenWith', parameterised by transport type.
data SocketOptions t = SocketOptions
  {
    -- | Authenticator run on each opened or accepted transport.
    socketAuthenticator :: Authenticator t
    -- | Options passed to 'transportOpen' / 'transportListen'.
  , socketTransportOptions :: TransportOptions t
  }
-- | Default options: 'SocketTransport' with its default options,
-- authenticated via the EXTERNAL mechanism.
defaultSocketOptions :: SocketOptions SocketTransport
-- Positional order matches the record: authenticator, then transport options.
defaultSocketOptions = SocketOptions authExternal transportDefaultOptions
-- | Open a socket to the address using 'defaultSocketOptions'.
open :: Address -> IO Socket
open addr = openWith defaultSocketOptions addr
-- | Open a socket to the address and run the client-side authenticator.
--
-- If authentication returns 'False' or throws, the transport is closed
-- (via 'bracketOnError') and a 'SocketError' tagged with the address is
-- thrown.
openWith :: TransportOpen t => SocketOptions t -> Address -> IO Socket
openWith opts addr =
  toSocketError (Just addr) (bracketOnError connect transportClose handshake)
  where
    connect = transportOpen (socketTransportOptions opts) addr
    handshake transport = do
      ok <- authenticatorClient (socketAuthenticator opts) transport
      if ok
        then do
          serialRef <- newIORef firstSerial
          readLock <- newMVar ()
          writeLock <- newMVar ()
          return (Socket (SomeTransport transport) (Just addr) serialRef readLock writeLock)
        else
          throwIO ((socketError "Authentication failed")
            { socketErrorAddress = Just addr })
-- | A listening endpoint together with the authenticator that 'accept'
-- runs on each incoming transport. The transport type is hidden
-- existentially.
data SocketListener = forall t. (TransportListen t) => SocketListener (TransportListener t) (Authenticator t)
-- | Listen on the address using 'defaultSocketOptions'.
listen :: Address -> IO SocketListener
listen addr = listenWith defaultSocketOptions addr
-- | Start listening on the address, remembering the options' authenticator
-- for later 'accept' calls. Errors are reported as 'SocketError' tagged
-- with the address.
listenWith :: TransportListen t => SocketOptions t -> Address -> IO SocketListener
listenWith opts addr =
  toSocketError (Just addr) (bracketOnError bind transportListenerClose wrap)
  where
    bind = transportListen (socketTransportOptions opts) addr
    wrap listener = return (SocketListener listener (socketAuthenticator opts))
-- | Wait for an incoming connection and run the server-side authenticator
-- with the listener's UUID.
--
-- On failure the accepted transport is closed and a 'SocketError' is
-- thrown. Sockets produced here have no 'socketAddress'.
accept :: SocketListener -> IO Socket
accept (SocketListener l auth) = toSocketError Nothing $
  bracketOnError (transportAccept l) transportClose $ \transport -> do
    ok <- authenticatorServer auth transport (transportListenerUUID l)
    if ok
      then do
        serialRef <- newIORef firstSerial
        readLock <- newMVar ()
        writeLock <- newMVar ()
        return (Socket (SomeTransport transport) Nothing serialRef readLock writeLock)
      else throwIO (socketError "Authentication failed")
-- | Close the socket's underlying transport.
close :: Socket -> IO ()
close sock = transportClose (socketTransport sock)
-- | Stop listening and close the underlying transport listener.
closeListener :: SocketListener -> IO ()
closeListener listener = case listener of
  SocketListener l _ -> transportListenerClose l
-- | The address the listener is bound to, as reported by the transport.
socketListenerAddress :: SocketListener -> Address
socketListenerAddress listener = case listener of
  SocketListener l _ -> transportListenerAddress l
-- | Assign the next serial to a message, marshal it little-endian and write
-- it to the socket.
--
-- The callback receives the serial and runs *before* the bytes are written,
-- so callers can act on the serial before any reply can arrive. Its result
-- is returned. Writes are serialised by the socket's write lock. A message
-- that fails to marshal raises a non-fatal 'SocketError'.
send :: Message msg => Socket -> msg -> (Serial -> IO a) -> IO a
send sock msg io = toSocketError (socketAddress sock) $ do
  serial <- nextSocketSerial sock
  case marshal LittleEndian serial msg of
    Left err ->
      throwIO ((socketError ("Message cannot be sent: " ++ show err))
        { socketErrorFatal = False })
    Right bytes -> do
      result <- io serial
      withMVar (socketWriteLock sock) (const (transportPut (socketTransport sock) bytes))
      return result
-- | Atomically return the socket's current serial and store its successor
-- ('nextSerial').
--
-- Uses the strict @atomicModifyIORef'@, so the stored successor is evaluated
-- on each call instead of being left in the 'IORef' as an unevaluated
-- 'nextSerial' thunk.
nextSocketSerial :: Socket -> IO Serial
nextSocketSerial sock =
  atomicModifyIORef' (socketSerial sock) (\current -> (nextSerial current, current))
-- | Read and unmarshal one message from the socket while holding its read
-- lock. Unmarshalling failures are raised as a 'SocketError'.
receive :: Socket -> IO ReceivedMessage
receive sock = toSocketError (socketAddress sock) $ do
  let transport = socketTransport sock
      -- Zero-length reads never touch the transport.
      readBytes n
        | n == 0 = return Data.ByteString.empty
        | otherwise = transportGet transport n
  result <- withMVar (socketReadLock sock) (const (unmarshalMessageM readBytes))
  case result of
    Right msg -> return msg
    Left err ->
      throwIO (socketError ("Error reading message from socket: " ++ show err))
-- | Run an action, rethrowing its failures as 'SocketError'.
--
-- 'TransportError' and 'IOException' are converted to a fresh (fatal)
-- 'SocketError' tagged with the given address. An existing 'SocketError'
-- keeps its own address and only takes the given one if it had none.
toSocketError :: Maybe Address -> IO a -> IO a
toSocketError addr io =
  io `catches`
    [ Handler fromTransportError
    , Handler fromSocketError
    , Handler fromIOException
    ]
  where
    tagged err = err { socketErrorAddress = addr }
    fromTransportError err =
      throwIO (tagged (socketError (transportErrorMessage err)))
    fromSocketError err =
      throwIO err { socketErrorAddress = socketErrorAddress err `mplus` addr }
    fromIOException exc =
      throwIO (tagged (socketError (show (exc :: IOException))))
-- | An authenticator that rejects every connection on both sides. Override
-- 'authenticatorClient' and/or 'authenticatorServer' to build a real one.
authenticator :: Authenticator t
authenticator = Authenticator
  { authenticatorClient = const (return False)
  , authenticatorServer = \_ _ -> return False
  }
-- | EXTERNAL-mechanism authenticator for 'SocketTransport'. Both sides are
-- supplied, so nothing from 'authenticator' is retained.
authExternal :: Authenticator SocketTransport
authExternal = Authenticator clientAuthExternal serverAuthExternal
-- | Client side of the EXTERNAL mechanism.
--
-- Sends a NUL byte, then @AUTH EXTERNAL@ with the uppercase hex encoding of
-- the 'show'n real user ID. If the server's reply line begins with @OK @,
-- sends @BEGIN@ and reports success; any other reply yields 'False'.
clientAuthExternal :: SocketTransport -> IO Bool
clientAuthExternal t = do
  transportPut t (Data.ByteString.pack [0])
  uid <- System.Posix.User.getRealUserID
  transportPutLine t ("AUTH EXTERNAL " ++ concatMap (printf "%02X" . ord) (show uid))
  reply <- transportGetLine t
  maybe
    (return False)
    (\_ -> transportPutLine t "BEGIN" >> return True)
    (splitPrefix "OK " reply)
-- | Server side of the EXTERNAL mechanism.
--
-- Expects a NUL byte, then either @AUTH EXTERNAL \<hex\>@ or a bare
-- @AUTH EXTERNAL@ followed by @DATA \<hex\>@. The hex token must encode the
-- peer's user ID as reported by 'socketTransportCredentials'. It is compared
-- case-insensitively, because hex digits may be sent in either case. On a
-- match, replies @OK \<uuid\>@, waits for @BEGIN@ and returns 'True'.
-- Otherwise it returns 'False' without replying.
serverAuthExternal :: SocketTransport -> UUID -> IO Bool
serverAuthExternal t uuid = do
  let waitForBegin = do
        resp <- transportGetLine t
        if resp == "BEGIN"
          then return ()
          else waitForBegin
  let checkToken token = do
        (_, uid, _) <- socketTransportCredentials t
        case uid of
          -- No peer credentials means nothing to verify the token against.
          -- Reject: the old placeholder fallback ("XXX") let any client that
          -- sent that placeholder's hex encoding authenticate.
          Nothing -> return False
          Just realUid -> do
            let wantToken = concatMap (printf "%02X" . ord) (show realUid)
            -- wantToken is uppercase; normalise the client's token so that
            -- lowercase hex (as produced by some clients) is accepted.
            if map toUpper token == wantToken
              then do
                transportPutLine t ("OK " ++ formatUUID uuid)
                waitForBegin
                return True
              else return False
  c <- transportGet t 1
  if c /= Char8.pack "\x00"
    then return False
    else do
      line <- transportGetLine t
      case splitPrefix "AUTH EXTERNAL " line of
        Just token -> checkToken token
        Nothing -> if line == "AUTH EXTERNAL"
          then do
            -- NOTE(review): the D-Bus spec appears to expect the server to send
            -- an empty "DATA" challenge before the client replies here;
            -- confirm before relying on this path with third-party clients.
            dataLine <- transportGetLine t
            case splitPrefix "DATA " dataLine of
              Just token -> checkToken token
              Nothing -> return False
          else return False
-- | Write a line to the transport, terminated by CR LF. Characters are
-- packed with 'Char8.pack', so only the low 8 bits of each are sent.
transportPutLine :: Transport t => t -> String -> IO ()
transportPutLine t line = transportPut t encoded
  where
    encoded = Char8.append (Char8.pack line) crlf
    crlf = Char8.pack "\r\n"
-- | Read one CR LF terminated line from the transport, one byte at a time,
-- and return it without the trailing CR LF.
--
-- Throws a 'SocketError' if the transport hands back an empty chunk before
-- the terminator is seen, instead of crashing on a partial 'Char8.head'.
transportGetLine :: Transport t => t -> IO String
transportGetLine t = do
  let getchr = do
        chunk <- transportGet t 1
        case Char8.uncons chunk of
          Just (byte, _) -> return byte
          -- An empty chunk used to hit 'Char8.head' and escape as an
          -- 'ErrorCall' that 'toSocketError' does not catch.
          Nothing -> throwIO (socketError "Unexpected end of input while reading a line")
  raw <- readUntil "\r\n" getchr
  return (dropEnd 2 raw)
-- | Drop the last @n@ elements of a list. A non-positive @n@ leaves the list
-- unchanged; an @n@ at least as long as the list gives @[]@.
dropEnd :: Int -> [a] -> [a]
dropEnd n = reverse . drop n . reverse
-- | If the second string starts with the first, return what follows the
-- prefix; otherwise 'Nothing'.
splitPrefix :: String -> String -> Maybe String
splitPrefix [] rest = Just rest
splitPrefix (p:ps) (c:cs)
  | p == c = splitPrefix ps cs
splitPrefix _ _ = Nothing
-- | Repeatedly run the action, collecting results, until the collected
-- sequence ends with the terminator. Returns everything read, terminator
-- included. At least one item is always read, even for an empty terminator.
readUntil :: (Monad m, Eq a) => [a] -> m a -> m [a]
readUntil terminator getx = loop []
  where
    -- Items are accumulated newest-first, so match the reversed terminator
    -- against the front of the accumulator.
    reversedTerminator = reverse terminator
    loop seen = do
      x <- getx
      let seen' = x : seen
      if reversedTerminator `isPrefixOf` seen'
        then return (reverse seen')
        else loop seen'