{-# LANGUAGE BlockArguments    #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}

-- |
-- Module: Network.Wai.Session.Redis
-- Copyright: (c) 2021, t4ccer
-- License: BSD3
-- Stability: experimental
-- Portability: portable
--
-- Simple Redis-backed session store for wai-session. This module lets you
-- keep wai-session session data in a Redis database.
module Network.Wai.Session.Redis
  ( dbStore
  , clearSession
  , SessionSettings(..)
  ) where

import           Control.Exception      (bracket)
import           Control.Monad
import           Control.Monad.IO.Class
import           Data.ByteString        (ByteString)
import           Data.Default
import           Data.Either
import           Data.Serialize         (Serialize, decode, encode)
import           Database.Redis         hiding (decode)
import           Network.Wai.Session

-- | Settings to control the Redis-backed session store.
data SessionSettings = SessionSettings
  { redisConnectionInfo :: ConnectInfo
    -- ^ How to reach the Redis server that holds the session data.
  , expiratinTime       :: Integer
    -- ^ Session expiration time in seconds. Creating, reading from, or
    -- writing to a session resets its Redis time-to-live to this value.
    -- (The field name is misspelled but kept because it is exported.)
  }

-- | Connect using hedis' 'defaultConnectInfo' and expire sessions after one
-- week without access.
instance Default SessionSettings where
  def = SessionSettings
    { redisConnectionInfo = defaultConnectInfo
    , expiratinTime       = 60 * 60 * 24 * 7 -- one week, in seconds
    }

-- | Keep the 'Right' payload and forget any 'Left' value.
eitherToMaybe :: Either a b -> Maybe b
eitherToMaybe = either (const Nothing) Just

-- | Open a connection with the given settings, run one Redis action on it,
-- and close the connection again.
--
-- 'bracket' ensures 'disconnect' runs even if 'runRedis' throws (for example
-- when the server drops mid-command), so the connection is not leaked.
--
-- NOTE(review): a fresh connection is opened for every call; sharing one
-- 'Connection' across calls would be cheaper but needs an API change.
connectAndRunRedis :: MonadIO m => ConnectInfo -> Redis b -> m b
connectAndRunRedis ci cmd =
  liftIO $ bracket (connect ci) disconnect (`runRedis` cmd)

-- | Generate a fresh session id and register it in Redis.
--
-- The session is stored as a Redis hash. An empty placeholder field is
-- written because Redis does not keep empty hashes. Without it the key would
-- not exist, 'expire' would have nothing to act on, and 'isSesIdValid' would
-- reject the id. The key's TTL is then set to 'expiratinTime'.
--
-- Creation is best-effort: Redis error replies are discarded and the new id
-- is returned regardless.
createSession :: MonadIO m => SessionSettings -> m ByteString
createSession SessionSettings{..} = liftIO do
  sesId <- genSessionId
  void $ connectAndRunRedis redisConnectionInfo $ do
    _ <- hset sesId "" ""
    expire sesId expiratinTime
  return sesId

-- | Check whether a key named by the session id currently exists in Redis.
-- A Redis error reply counts as "not valid".
--
-- NOTE(review): only key existence is checked. If the id comes from the
-- client (presumably a cookie), an unrelated key in the same database could
-- be adopted as a session. Consider also checking that the key is a hash.
isSesIdValid :: MonadIO m => SessionSettings -> ByteString -> m Bool
isSesIdValid settings sesId =
  fromRight False
    <$> connectAndRunRedis (redisConnectionInfo settings) (exists sesId)

-- | Store a key/value pair in a session's Redis hash and reset the session's
-- TTL to 'expiratinTime'. Redis error replies are discarded.
insertIntoSession :: MonadIO m => SessionSettings
  -> ByteString -- ^ Session id
  -> ByteString -- ^ Key
  -> ByteString -- ^ Value
  -> m ()
insertIntoSession SessionSettings{..} sesId key value =
  void $ connectAndRunRedis redisConnectionInfo $ do
    _ <- hset sesId key value
    -- Sliding expiration: each write pushes the session's deadline forward.
    expire sesId expiratinTime

-- | Read the value stored under a key in a session's Redis hash, resetting
-- the session's TTL as a side effect. Returns 'Nothing' when the field is
-- absent or Redis replies with an error.
lookupFromSession :: MonadIO m => SessionSettings
  -> ByteString -- ^ Session id
  -> ByteString -- ^ Key
  -> m (Maybe ByteString)
lookupFromSession settings sesId key = do
  reply <- connectAndRunRedis (redisConnectionInfo settings) $ do
    found <- hget sesId key
    _ <- expire sesId (expiratinTime settings)
    pure found
  pure (either (const Nothing) id reply)

-- | Invalidate a session id by deleting its Redis key, together with all data
-- stored under it. Redis error replies are discarded.
clearSession :: MonadIO m => SessionSettings
  -> ByteString -- ^ Session id
  -> m ()
clearSession settings sessionId =
  void (connectAndRunRedis (redisConnectionInfo settings) (del [sessionId]))

-- | Create a new Redis-backed wai-session store.
--
-- Building the store performs no effects. It only captures the settings, and
-- a Redis connection is opened per operation when the store is used.
dbStore :: (MonadIO m1, MonadIO m2, Eq k, Serialize k, Serialize v) => SessionSettings -> m2 (SessionStore m1 k v)
dbStore settings = pure (dbStore' settings)

-- | Resolve the session id presented by the client, if any, to a session.
-- Returns the session's lookup/insert pair together with an action yielding
-- the id to hand back to the client.
--
-- A missing id, or one whose key no longer exists in Redis, is replaced by a
-- freshly created session.
dbStore' :: (MonadIO m1, MonadIO m2, Eq k, Serialize k, Serialize v) => SessionSettings -> Maybe ByteString -> m2 (Session m1 k v, m2 ByteString)
dbStore' s mSesId = do
  sesId <- case mSesId of
    Just existing -> do
      valid <- isSesIdValid s existing
      if valid then pure existing else createSession s
    Nothing -> createSession s
  pure (mkSessionFromSesId s sesId, pure sesId)

-- | Build the wai-session lookup/insert pair for one session id.
--
-- Keys and values are serialised with cereal's 'encode'. A stored value that
-- fails to 'decode' is reported as missing ('Nothing').
mkSessionFromSesId :: (MonadIO m1, Eq k, Serialize k, Serialize v) => SessionSettings -> ByteString -> Session m1 k v
mkSessionFromSesId s sesId = (sessionLookup, sessionInsert)
  where
    sessionLookup k = liftIO $ do
      raw <- lookupFromSession s sesId (encode k)
      pure (raw >>= eitherToMaybe . decode)
    sessionInsert k v =
      liftIO (insertIntoSession s sesId (encode k) (encode v))