module Happstack.Server.Session (Session(..), SessionConfig(..), mkSessionConfig, SessionHandler, startSession, getSession, setSession, updateSession, deleteSession) where
import Prelude hiding (lookup)
import Data.Word
import Data.Maybe
import Data.Text (Text)
import Control.Monad
import Data.Either
import Control.Monad.IO.Class
import Data.Time.Clock.POSIX
import Control.Exception
import Crypto.Cipher.Types
import Crypto.Cipher.AES
import Crypto.Error
import Crypto.Data.Padding
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C
import Data.ByteString.Base16 (encode, decode)
import Happstack.Server hiding (Session, host)
-- | Parse a value with its 'Read' instance, giving 'Nothing' instead of
-- throwing when no parse exists.
--
-- Only a prefix of the input has to parse: the first successful parse is
-- returned and any unconsumed remainder is ignored.
maybeRead :: (Read a) => String -> Maybe a
maybeRead input = case reads input of
  (value, _) : _ -> Just value
  []             -> Nothing
-- | A stored session: its identifier, its expiry time and the data it carries.
data Session a b = Session
  { sessionId     :: a      -- ^ identifier; encrypted into the @SID@ cookie
  , sessionExpire :: Word64 -- ^ expiry time in whole POSIX seconds
  , sessionData   :: b      -- ^ application payload
  } deriving Show
-- | How a session id is turned into a cookie value and back.
data SessionConfig a = SessionConfig
  { sessionAuthEncrypt :: a -> String
    -- ^ serialise an id into the cookie value
  , sessionAuthDecrypt :: String -> Maybe a
    -- ^ recover an id from a cookie value; 'Nothing' when it is not valid
  }
-- | Build a 'SessionConfig' that protects session ids with AES-128 in CBC mode.
--
-- Encryption: the id is serialised with 'show', PKCS#7-padded to the 16-byte
-- block size, CBC-encrypted and hex-encoded so it can live in a cookie.
--
-- Decryption reverses those steps and yields 'Nothing' (never an exception)
-- when the input:
--
--   * is not entirely valid hex,
--   * is not a whole number of 16-byte blocks,
--   * has invalid padding, or
--   * does not parse back with 'Read'.
--
-- An invalid key or IV is reported with 'error'. Because @cipher@ and @iv'@
-- are lazy bindings, that error surfaces when the returned functions are first
-- used, not when this function is called.
--
-- NOTE(review): the IV is fixed for the lifetime of the config and no MAC is
-- applied. Equal ids therefore produce equal cookies, and ciphertexts are
-- malleable. This gives confidentiality only, not authenticity, despite the
-- @Auth@ in the field names.
mkSessionConfig :: (Read a, Show a)
                => BS.ByteString -- ^ raw (not hex-encoded) AES-128 key, 16 bytes
                -> BS.ByteString -- ^ hex-encoded IV, decoding to 16 bytes
                -> SessionConfig a
mkSessionConfig key iv = SessionConfig
  { sessionAuthEncrypt = encrypt
  , sessionAuthDecrypt = decrypt
  }
  where
    cipher :: AES128
    cipher = either (\m -> error $ show m ++ "\nInvalid key passed to mkSessionConfig") id
           $ eitherCryptoError $ cipherInit key

    iv' :: IV AES128
    iv' = fromMaybe (error "Invalid IV passed to mkSessionConfig") $ makeIV $ fst $ decode iv

    encrypt m = C.unpack $ encode $ cbcEncrypt cipher iv' $ pad (PKCS7 16) $ C.pack $ show m

    decrypt c = do
      let (bytes, leftover) = decode (C.pack c)
      -- A non-empty remainder means the cookie contained non-hex characters
      -- (or an odd number of digits); treat it as invalid rather than
      -- silently decrypting the valid prefix.
      guard (BS.null leftover)
      -- CBC only operates on whole blocks.
      guard (BS.length bytes `mod` 16 == 0)
      plain <- unpad (PKCS7 16) $ cbcDecrypt cipher iv' bytes
      -- Total parse: garbage plaintext becomes Nothing instead of a lazy
      -- 'read' exception that would escape into the request handler.
      maybeRead (C.unpack plain)
-- | The cookie encryption config together with the storage callbacks used by
-- 'getSession', 'setSession', 'updateSession' and 'deleteSession'.
data SessionHandler a b = SessionHandler
  (SessionConfig a)                    -- ^ cookie encryption / decryption
  (a -> IO (Maybe (Session a b)))      -- ^ look up a session by id
  (b -> Word64 -> IO (Session a b))    -- ^ create a session from data and an absolute expiry time
  (a -> b -> IO (Maybe (Session a b))) -- ^ replace the data of a session by id
  (a -> IO ())                         -- ^ delete a session by id
-- | Build a 'SessionHandler' from a config and an action producing the
-- storage callbacks.
--
-- The callbacks come in this order: look up by id, create from data and
-- expiry time, update data by id, delete by id.
startSession :: SessionConfig a
             -> IO ( a -> IO (Maybe (Session a b))
                   , b -> Word64 -> IO (Session a b)
                   , a -> b -> IO (Maybe (Session a b))
                   , a -> IO ()
                   )
             -> IO (SessionHandler a b)
startSession sessionConfig sessionHandler =
  sessionHandler >>= \(lookupFn, createFn, updateFn, removeFn) ->
    return (SessionHandler sessionConfig lookupFn createFn updateFn removeFn)
-- | Look up the data of the current request's session.
--
-- Reads the @SID@ cookie, decrypts it with the handler's 'SessionConfig', and
-- fetches the session through the lookup callback. A session whose expiry time
-- is earlier than the current POSIX time is deleted and reported as absent.
-- The result is 'Nothing' when there is no cookie, it does not decrypt, or no
-- session is found.
getSession :: (MonadPlus m, MonadIO m, FilterMonad Response m, HasRqData m, Read a)
           => (SessionHandler a b)
           -> m (Maybe b)
getSession sessionHandler@(SessionHandler sessionConfig fetchSession _ _ _) =
  msum
    [ do
        cookie <- lookCookieValue "SID"
        found <- case sessionAuthDecrypt sessionConfig cookie of
          Nothing  -> return Nothing
          Just sid -> liftIO (fetchSession sid)
        now <- liftIO (floor <$> getPOSIXTime)
        let hasExpired s = sessionExpire s < now
        if maybe False hasExpired found
          then deleteSession sessionHandler >> return Nothing
          else return (sessionData <$> found)
      -- No SID cookie: lookCookieValue fails and we fall through to here.
    , return Nothing
    ]
-- | Create a new session holding the given data and send its id to the client.
--
-- The session is stored through the handler's create callback with an expiry
-- time of now plus @expiry@ seconds. Its id is encrypted into an @SID@ cookie
-- whose max-age is the same @expiry@.
--
-- NOTE(review): @fromIntegral expiry@ narrows 'Word64' to 'Int' for 'MaxAge';
-- very large values would wrap. Confirm callers keep @expiry@ small.
setSession :: (MonadIO m, FilterMonad Response m, Show a)
           => (SessionHandler a b)
           -> b      -- ^ data to store in the new session
           -> Word64 -- ^ lifetime in seconds
           -> m ()
setSession (SessionHandler sessionConfig _ createSession _ _) dat expiry = do
  now <- liftIO (floor <$> getPOSIXTime)
  created <- liftIO (createSession dat (now + expiry))
  let cookieValue = sessionAuthEncrypt sessionConfig (sessionId created)
      lifetime    = MaxAge (fromIntegral expiry)
  addCookie lifetime (mkCookie "SID" cookieValue)
-- | Replace the data of the current request's session.
--
-- The session id is recovered from the encrypted @SID@ cookie and passed,
-- together with the new data, to the handler's update callback. If the session
-- returned by the callback has an expiry time in the past, it is deleted; this
-- is the same expiry rule 'getSession' applies. A missing or undecryptable
-- cookie is ignored.
updateSession :: (MonadPlus m, MonadIO m, FilterMonad Response m, HasRqData m, Read a)
              => (SessionHandler a b)
              -> b -- ^ new session data
              -> m ()
updateSession sessionHandler@(SessionHandler sessionConfig _ _ updateSession' _) dat = msum
  [ do
      sid' <- lookCookieValue "SID"
      let sid = sessionAuthDecrypt sessionConfig sid'
      session <- maybe (return Nothing) (liftIO . (`updateSession'` dat)) sid
      timeNow <- liftIO $ fmap floor getPOSIXTime
      -- Only already-expired sessions are removed. Comparing with '>' here
      -- would delete every still-valid session right after updating it.
      when (maybe False ((< timeNow) . sessionExpire) session) $
        deleteSession sessionHandler
    -- No SID cookie: nothing to update.
  , return ()
  ]
-- | Delete the current request's session and expire its @SID@ cookie.
--
-- When the cookie decrypts to a session id, the handler's delete callback is
-- run for that id. The cookie is expired whenever it is present, even if it
-- did not decrypt. Without an @SID@ cookie this does nothing.
deleteSession :: (MonadPlus m, MonadIO m, FilterMonad Response m, HasRqData m, Read a)
              => (SessionHandler a b)
              -> m ()
deleteSession (SessionHandler sessionConfig _ _ _ removeSession) =
  msum
    [ do
        cookie <- lookCookieValue "SID"
        case sessionAuthDecrypt sessionConfig cookie of
          Just sid -> liftIO (removeSession sid)
          Nothing  -> return ()
        expireCookie "SID"
    , return ()
    ]