module Network.HaskellNet.SSL.Internal
  ( connectSSL
  , connectPlain
  ) where

import Network.Connection
import Network.HaskellNet.SSL
import Network.HaskellNet.BSStream

import qualified Data.ByteString.Char8 as B
import Data.Default

import Control.Monad ((>=>))
import System.IO.Error (eofErrorType, mkIOError)

-- | A deferred action that, when run, switches an already-open plain
-- connection over to TLS. 'connectPlain' returns one of these, built from
-- 'connectionSetSecure'; the caller decides when to run it.
type STARTTLS = IO ()

-- | Read exactly @n@ bytes from the connection, blocking until all of them
-- have arrived.
--
-- A single 'connectionGet' may return fewer bytes than requested, so this
-- keeps reading until the requested count is satisfied. Chunks are collected
-- in a list and concatenated once at the end, instead of being appended one
-- by one, which would cost time quadratic in the number of chunks.
--
-- An empty chunk means the peer closed the stream. It is reported right away
-- as an EOF 'IOError' rather than issuing another read on a finished
-- connection.
connectionGetBytes :: Connection -> Int -> IO B.ByteString
connectionGetBytes c = go []
  where
    -- Chunks are kept newest-first; reversed once when the count is met.
    go acc 0 = return $! B.concat (reverse acc)
    go acc remaining = do
      chunk <- connectionGet c remaining
      if B.null chunk
        then ioError $
          mkIOError eofErrorType "connectionGetBytes" Nothing Nothing
        else go (chunk : acc) (remaining - B.length chunk)
b

-- | Wrap an established 'Connection' in the 'BSStream' record that HaskellNet
-- protocol code consumes.
--
-- Reads go through 'connectionGetBytes' / 'connectionGetLine'; lines are
-- capped at 'sslMaxLineLength'. When 'sslLogToConsole' is enabled, every
-- chunk sent or received is passed through 'logToConsole'; otherwise it is
-- returned untouched. Flushing is a no-op and the stream always reports
-- itself as open.
connectionToStream :: Connection -> Settings -> BSStream
connectionToStream conn settings =
  BSStream
    { bsGet          = connectionGetBytes conn >=> traceAs "RECV"
    , bsPut          = traceAs "SEND" >=> connectionPut conn
    , bsFlush        = return ()
    , bsClose        = connectionClose conn
    , bsIsOpen       = return True
    , bsGetLine      = connectionGetLine lineLimit conn >>= traceAs "RECV"
    , bsWaitForInput = connectionWaitForInput conn
    }
  where
    lineLimit = sslMaxLineLength settings

    -- Optional traffic tracing: identity in IO unless console logging is on.
    traceAs :: String -> B.ByteString -> IO B.ByteString
    traceAs direction bytes
      | sslLogToConsole settings = logToConsole direction bytes
      | otherwise                = return bytes

-- | Print one traffic line, @HaskellNet-SSL \<dir\>: \<bytes\>@, to stdout and
-- return the bytes unchanged, so the call can sit in the middle of a
-- read or write pipeline.
logToConsole :: String -> B.ByteString -> IO B.ByteString
logToConsole dir s =
  putStrLn (concat ["HaskellNet-SSL ", dir, ": ", show s]) >> return s

-- | Connect to @hostname@ on 'sslPort' and negotiate TLS immediately.
-- Certificate validation is disabled when 'sslDisableCertificateValidation'
-- is set; all other TLS options come from the library default.
connectSSL :: String -> Settings -> IO BSStream
connectSSL hostname cfg = do
    ctx  <- initConnectionContext
    conn <- connectTo ctx tlsParams
    return (connectionToStream conn cfg)
  where
    tlsParams :: ConnectionParams
    tlsParams = ConnectionParams
      { connectionHostname  = hostname
      , connectionPort      = sslPort cfg
      , connectionUseSecure = Just tlsSettings
      , connectionUseSocks  = Nothing
      }

    tlsSettings :: TLSSettings
    tlsSettings = def
      { settingDisableCertificateValidation = sslDisableCertificateValidation cfg }

-- | Connect to @hostname@ on 'sslPort' without encryption.
--
-- Returns the stream together with a 'STARTTLS' action. Running that action
-- calls 'connectionSetSecure' to upgrade this same connection to TLS, using
-- the same certificate-validation setting as 'connectSSL'.
connectPlain :: String -> Settings -> IO (BSStream, STARTTLS)
connectPlain hostname cfg = do
    ctx  <- initConnectionContext
    conn <- connectTo ctx plainParams
    let upgrade = connectionSetSecure ctx conn tlsSettings
    return (connectionToStream conn cfg, upgrade)
  where
    plainParams :: ConnectionParams
    plainParams = ConnectionParams
      { connectionHostname  = hostname
      , connectionPort      = sslPort cfg
      , connectionUseSecure = Nothing
      , connectionUseSocks  = Nothing
      }

    tlsSettings :: TLSSettings
    tlsSettings = def
      { settingDisableCertificateValidation = sslDisableCertificateValidation cfg }