module Web.Scotty.Login.Session ( initializeCookieDb
, addSession
, removeSession
, authCheck
, authCheckWithSession
, SessionConfig(..)
, Session(..)
, defaultSessionConfig
)
where
import Control.Concurrent (forkIO, threadDelay)
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe
import Crypto.Random (getRandomBytes)
import qualified Data.ByteString as B
import qualified Data.Text as TS
import qualified Data.Text.Lazy as T
import Data.Time.Clock
import Database.Persist as D
import Database.Persist.Sqlite
import Network.HTTP.Types.Status (forbidden403)
import Numeric (showHex)
import Web.Scotty.Cookie as SC
import Web.Scotty.Trans as S
import Web.Scotty.Login.Internal.Cookies as C
import Web.Scotty.Login.Internal.Model
import Control.Monad.Logger
import Control.Monad.Trans.Control
import Control.Monad.Trans.Resource (ResourceT, runResourceT)
import qualified Data.HashMap.Strict as H
import Data.IORef
import System.IO.Unsafe (unsafePerformIO)
-- | Settings for the session store.
data SessionConfig =
SessionConfig { dbPath :: String -- ^ Passed (as 'TS.Text') to 'withSqliteConn' as the connection string.
, syncInterval :: NominalDiffTime -- ^ Pause, in seconds, between runs of the background sync/cleanup loop.
, expirationInterval :: NominalDiffTime -- ^ Added to the current time to obtain the expiration of a newly added session.
, debugMode :: Bool -- ^ When 'True', debug messages are printed and database log output is let through.
}
-- | In-memory collection of sessions, keyed by session id.
type SessionVault = H.HashMap T.Text Session
-- | Mutable reference holding the current 'SessionVault'.
type SessionStore = IORef SessionVault
-- | Default configuration: database file @sessions.sqlite3@, a sync interval
-- of 1200 seconds, an expiration interval of 120 seconds, and debug output
-- switched off.
defaultSessionConfig :: SessionConfig
defaultSessionConfig =
  SessionConfig
    { dbPath = "sessions.sqlite3"
    , syncInterval = 1200
    , expirationInterval = 120
    , debugMode = False
    }
-- | Process-wide in-memory session store, keyed by session id.
--
-- This is a top-level 'IORef' created with 'unsafePerformIO'. The @NOINLINE@
-- pragma is essential: without it GHC is free to inline the definition at
-- its use sites, which could allocate more than one 'IORef' and silently
-- split the store.
vault :: SessionStore
vault = unsafePerformIO $ newIORef H.empty
{-# NOINLINE vault #-}
-- | Read the current contents of the global 'vault'.
readVault :: IO SessionVault
readVault = readIORef vault
-- | Atomically replace the contents of 'vault' with the result of applying
-- the given function to the current contents.
modifyVault :: (SessionVault -> SessionVault) -> IO ()
modifyVault f = atomicModifyIORef' vault $ \current -> (f current, ())
-- | Prepare the session store: run the schema migration, load every stored
-- session whose expiration is not earlier than the current time into the
-- in-memory 'vault', and fork the background sync/cleanup thread.
initializeCookieDb :: SessionConfig -> IO ()
initializeCookieDb c = do
  now <- getCurrentTime
  stored <- runDB c $
    runMigration migrateAll >> selectList [SessionExpiration >=. now] []
  let live = map entityVal stored :: [Session]
      byId = H.fromList [(sessionSid s, s) | s <- live]
  modifyVault (const byId)
  void $ forkIO (dbSyncAndCleanupLoop c)
-- | Background loop that periodically persists the in-memory 'vault' to the
-- database and evicts expired sessions. It never returns.
--
-- Each iteration sleeps for 'syncInterval' seconds and then:
--
--   1. replaces the stored session rows with the unexpired sessions that
--      were in the vault when the iteration started, and
--   2. drops expired sessions from the vault itself.
dbSyncAndCleanupLoop :: SessionConfig -> IO ()
dbSyncAndCleanupLoop c = do
  -- 'threadDelay' takes microseconds; 'syncInterval' is in seconds.
  threadDelay $ floor (syncInterval c) * 1000000
  t <- getCurrentTime
  vaultContents <- readVault
  let isLive s = sessionExpiration s >= t
      -- Only unexpired sessions are written back; expired ones must not be
      -- re-inserted right after the table has been wiped.
      liveSessions = H.filter isLive vaultContents
  -- A single 'runDB' call: the wipe and the re-insert share one connection
  -- and one 'runSqlConn' transaction instead of one connection per row.
  runDB c $ do
    -- Between them these two filters match every stored row.
    deleteWhere [SessionExpiration >=. t]
    deleteWhere [SessionExpiration <=. t]
    mapM_ insert liveSessions
  modifyVault $ H.filter isLive
  dbSyncAndCleanupLoop c
-- | Create a new session, hand its id to the client in the @SessionId@
-- cookie, and register it in the in-memory vault.
--
-- The session id is 128 random bytes rendered as lowercase hexadecimal, two
-- digits per byte, so every id has the same length and distinct byte
-- strings always yield distinct ids. The session (and the cookie) expire
-- 'expirationInterval' after the current time.
--
-- Returns the newly created 'Session'.
addSession :: SessionConfig -> ActionT T.Text IO Session
addSession c = do
  bh <- liftIO (getRandomBytes 128 :: IO B.ByteString)
  t <- liftIO getCurrentTime
  let val = TS.pack $ concatMap hexByte $ B.unpack bh
      t' = addUTCTime (expirationInterval c) t
      session = Session (T.fromStrict val) t'
  when (debugMode c) $ liftIO $ putStrLn $ "adding session " ++ show session
  C.setSimpleCookieExpr "SessionId" val t'
  liftIO $ insertSession (T.fromStrict val) t'
  return session
  where
    -- 'showHex' does not pad, so a byte below 0x10 comes out as a single
    -- digit; left-pad it with a zero to keep the encoding unambiguous.
    hexByte w = case showHex w "" of
      [d] -> ['0', d]
      ds -> ds
-- | Guard an action behind a valid session. The first argument is run when
-- the check fails, the second when it succeeds. This is
-- 'authCheckWithSession' for callers that have no use for the 'Session'
-- value itself.
authCheck :: (MonadIO m, ScottyError e)
  => ActionT e m ()
  -> ActionT e m ()
  -> ActionT e m ()
authCheck denied allowed = authCheckWithSession denied (\_ -> allowed)
-- | Guard an action behind a valid session, passing the session to it.
--
-- The check succeeds when the request carries a @SessionId@ cookie, that id
-- is present in the in-memory vault, and the session's expiration lies
-- strictly after the current time. On success the second argument is run
-- with the session; otherwise the first argument is run and the response
-- status is then set to 403.
authCheckWithSession
  :: (MonadIO m, ScottyError e)
  => ActionT e m ()
  -> (Session -> ActionT e m ())
  -> ActionT e m ()
authCheckWithSession d a = do
  sessions <- liftIO readVault
  outcome <- runMaybeT $ do
    cookie <- MaybeT (SC.getCookie "SessionId")
    session <- liftMaybe $ H.lookup (T.fromStrict cookie) sessions
    now <- liftIO getCurrentTime
    let remaining = diffUTCTime (sessionExpiration session) now
    guard (remaining > 0)
    lift (a session)
  case outcome of
    Nothing -> d >> status forbidden403
    Just done -> return done
-- | Remove the session named by the request's @SessionId@ cookie from the
-- in-memory vault. Does nothing when the request carries no such cookie.
--
-- In debug mode the removed session id is printed to stdout.
removeSession :: SessionConfig -> ActionT T.Text IO ()
removeSession c = do
  sid <- SC.getCookie "SessionId"
  case sid of
    Nothing -> return ()
    Just sid' -> do
      liftIO $ modifyVault $ H.delete (T.fromStrict sid')
      -- Log the unwrapped id rather than the 'Maybe' it arrived in.
      when (debugMode c) $ liftIO $ putStrLn $ "removed session " ++ show sid'
-- | Store a session with the given id and expiration time in the in-memory
-- vault, under that id.
insertSession :: T.Text
  -> UTCTime
  -> IO ()
insertSession sid t = modifyVault (H.insert sid session)
  where
    session = Session sid t
-- | Run a database action against the SQLite database named by the
-- configuration's 'dbPath'.
runDB
  :: (MonadIO m, MonadBaseControl IO m)
  => SessionConfig -> SqlPersistT (LoggingT (ResourceT m)) a -> m a
runDB c = runSqlite' c connStr
  where
    connStr = TS.pack (dbPath c)
-- | Run a database action on a SQLite connection opened with the given
-- connection string. Log messages go to stderr, and are let through only
-- when the configuration's 'debugMode' is 'True'.
runSqlite'
  :: (MonadIO m, MonadBaseControl IO m)
  => SessionConfig -> TS.Text -> SqlPersistT (LoggingT (ResourceT m)) a -> m a
runSqlite' conf connstr action =
  runResourceT $
    runStderrLoggingT $
      filterLogger (\_ _ -> debugMode conf) $
        withSqliteConn connstr $
          runSqlConn action
-- | Lift a pure 'Maybe' into 'MaybeT': 'Nothing' becomes failure, @Just x@
-- becomes a successful result of @x@.
liftMaybe :: (Monad m) => Maybe a -> MaybeT m a
liftMaybe mx = MaybeT (return mx)