{-# LANGUAGE OverloadedStrings #-}
module Snap.Util.CORS
(
applyCORS
, CORSOptions(..)
, defaultOptions
, OriginList(..)
, OriginSet, mkOriginSet, origins
, HashableURI(..), HashableMethod (..)
) where
import Control.Applicative
import Control.Monad (join, when)
import Data.Maybe (fromMaybe)

import qualified Data.Attoparsec.ByteString.Char8 as Attoparsec
import qualified Data.ByteString.Char8 as S
import Data.CaseInsensitive (CI)
import qualified Data.CaseInsensitive as CI
import qualified Data.HashSet as HashSet
import Data.Hashable (Hashable(..))
import qualified Data.Text as Text
import Data.Text.Encoding (decodeUtf8, decodeUtf8', encodeUtf8)
import Network.URI (URI (..), URIAuth (..), parseURI)
import qualified Snap.Core as Snap
import Snap.Internal.Parsing (pTokens)
-- | A set of allowed origins. Values built with 'mkOriginSet' have had each
-- URI reduced by 'simplifyURI' (user info, path, query and fragment
-- removed), so that they can be compared with the reduced @Origin@ of a
-- request.
newtype OriginSet = OriginSet
  { origins :: HashSet.HashSet HashableURI
    -- ^ The underlying set of (reduced) origin URIs.
  }
-- | Which request origins 'applyCORS' accepts.
data OriginList
  = Everywhere
    -- ^ Every origin is accepted.
  | Nowhere
    -- ^ No origin is accepted; requests are passed to the wrapped handler
    -- without any CORS headers.
  | Origins OriginSet
    -- ^ Only origins contained in the given set are accepted.
-- | Settings used by 'applyCORS'. Most fields are actions in @m@, so each
-- setting can be decided per request.
data CORSOptions m = CORSOptions
  { corsAllowOrigin :: m OriginList
    -- ^ Origins that may make cross-origin requests.
  , corsAllowCredentials :: m Bool
    -- ^ When 'True', @Access-Control-Allow-Credentials: true@ is added.
  , corsExposeHeaders :: m (HashSet.HashSet (CI S.ByteString))
    -- ^ Header names listed in @Access-Control-Expose-Headers@ on
    -- non-preflight responses; the header is omitted when the set is empty.
  , corsAllowedMethods :: m (HashSet.HashSet HashableMethod)
    -- ^ Methods a preflight may request via
    -- @Access-Control-Request-Method@; also listed in
    -- @Access-Control-Allow-Methods@.
  , corsAllowedHeaders
      :: HashSet.HashSet S.ByteString -> m (HashSet.HashSet S.ByteString)
    -- ^ Given the header names parsed from @Access-Control-Request-Headers@,
    -- returns the allowed ones. A preflight succeeds only when every
    -- requested header is in the result.
  }
-- | Permissive settings, every one chosen unconditionally:
--
-- * all origins are accepted ('Everywhere');
-- * credentials are allowed;
-- * no additional headers are exposed;
-- * the allowed methods are @GET@, @POST@, @PUT@, @DELETE@ and @HEAD@;
-- * every requested header is allowed.
defaultOptions :: Monad m => CORSOptions m
defaultOptions =
  CORSOptions
    { corsAllowedHeaders   = return
    , corsAllowedMethods   =
        defaultAllowedMethods `seq` return defaultAllowedMethods
    , corsExposeHeaders    = return HashSet.empty
    , corsAllowCredentials = return True
    , corsAllowOrigin      = return Everywhere
    }
-- | The method set used by 'defaultOptions'.
defaultAllowedMethods :: HashSet.HashSet HashableMethod
defaultAllowedMethods =
  HashSet.fromList
    [ HashableMethod verb
    | verb <- [Snap.GET, Snap.POST, Snap.PUT, Snap.DELETE, Snap.HEAD]
    ]
-- | Wrap a handler so that it answers CORS requests.
--
-- * If the request has no @Origin@ header, or the header does not parse as a
--   URI (including bytes that are not valid UTF-8), the wrapped handler runs
--   unchanged.
-- * If the origin is not accepted by 'corsAllowOrigin', the wrapped handler
--   runs unchanged.
-- * An @OPTIONS@ request from an accepted origin is treated as a preflight.
--   The requested method must be in 'corsAllowedMethods', and every header
--   in @Access-Control-Request-Headers@ must be accepted by
--   'corsAllowedHeaders'. In that case these headers are added and the
--   wrapped handler is /not/ run:
--
--     * @Access-Control-Allow-Origin@
--     * @Access-Control-Allow-Credentials@, if enabled
--     * @Access-Control-Allow-Headers@
--     * @Access-Control-Allow-Methods@
--
--   Otherwise the wrapped handler runs.
-- * Any other request from an accepted origin gets
--   @Access-Control-Allow-Origin@, plus @Access-Control-Allow-Credentials@
--   and @Access-Control-Expose-Headers@ when configured. It then runs the
--   wrapped handler.
--
-- @Access-Control-Allow-Origin@ always echoes the request's origin, reduced by
-- 'simplifyURI', even when the option is 'Everywhere'.
--
-- Because preflights arrive as @OPTIONS@, apply this outside any
-- 'Snap.method' filter:
--
-- > applyCORS defaultOptions $ method POST $ myHandler
applyCORS :: Snap.MonadSnap m => CORSOptions m -> m () -> m ()
applyCORS options m =
  (join . fmap decodeOrigin <$> getHeader "Origin") >>= maybe m corsRequestFrom
  where
    corsRequestFrom origin = do
      originList <- corsAllowOrigin options
      if origin `inOriginList` originList
        -- A non-OPTIONS request, or a preflight branch that fails, falls
        -- through '<|>' to ordinary-request handling.
        then Snap.method Snap.OPTIONS (preflightRequestFrom origin)
               <|> handleRequestFrom origin
        else m

    preflightRequestFrom origin = do
      maybeMethod <- fmap (parseMethod . S.unpack) <$>
                       getHeader "Access-Control-Request-Method"
      case maybeMethod of
        Nothing -> m
        Just method -> do
          allowedMethods <- corsAllowedMethods options
          if method `HashSet.member` allowedMethods
            then do
              -- A missing Access-Control-Request-Headers means "no headers
              -- requested"; a header that fails to parse gives Nothing.
              maybeHeaders <-
                maybe (Just HashSet.empty) splitHeaders
                  <$> getHeader "Access-Control-Request-Headers"
              case maybeHeaders of
                Nothing -> m
                Just headers -> do
                  allowedHeaders <- corsAllowedHeaders options headers
                  if not $ HashSet.null $
                        headers `HashSet.difference` allowedHeaders
                    then m
                    else do
                      addAccessControlAllowOrigin origin
                      addAccessControlAllowCredentials
                      commaSepHeader
                        "Access-Control-Allow-Headers"
                        id (HashSet.toList allowedHeaders)
                      commaSepHeader
                        "Access-Control-Allow-Methods"
                        renderMethod (HashSet.toList allowedMethods)
            else m

    handleRequestFrom origin = do
      addAccessControlAllowOrigin origin
      addAccessControlAllowCredentials
      exposeHeaders <- corsExposeHeaders options
      when (not $ HashSet.null exposeHeaders) $
        commaSepHeader
          "Access-Control-Expose-Headers"
          CI.original (HashSet.toList exposeHeaders)
      m

    addAccessControlAllowOrigin origin =
      addHeader "Access-Control-Allow-Origin"
                (encodeUtf8 $ Text.pack $ show origin)

    addAccessControlAllowCredentials = do
      allowCredentials <- corsAllowCredentials options
      when allowCredentials $
        addHeader "Access-Control-Allow-Credentials" "true"

    -- The Origin header is client-controlled, so decode it totally. Bytes
    -- that are not valid UTF-8 are treated like any other unparseable
    -- origin (the request passes through untouched). The partial
    -- 'decodeUtf8' would throw instead.
    decodeOrigin :: S.ByteString -> Maybe URI
    decodeOrigin raw =
      either (const Nothing)
             (fmap simplifyURI . parseURI . Text.unpack)
             (decodeUtf8' raw)

    -- Render a method as the token sent on the wire. For a custom
    -- 'Snap.Method', 'show' (snap-core's derived Show) yields Haskell
    -- constructor syntax rather than the method name, so the raw name is
    -- emitted directly instead.
    renderMethod :: HashableMethod -> S.ByteString
    renderMethod (HashableMethod (Snap.Method name)) = name
    renderMethod other = S.pack (show other)

    addHeader k v = Snap.modifyResponse (Snap.addHeader k v)

    -- Emit a comma-separated header, or nothing at all for an empty list.
    commaSepHeader k f vs =
      case vs of
        [] -> return ()
        _  -> addHeader k $ S.intercalate ", " (map f vs)

    getHeader = Snap.getsRequest . Snap.getHeader

    splitHeaders = either (const Nothing) (Just . HashSet.fromList) .
      Attoparsec.parseOnly pTokens
-- | Build an 'OriginSet' from a list of URIs. Each URI is first reduced with
-- 'simplifyURI', the same reduction applied to a request's @Origin@ before
-- it is looked up.
mkOriginSet :: [URI] -> OriginSet
mkOriginSet uris =
  OriginSet (HashSet.fromList [HashableURI (simplifyURI u) | u <- uris])
-- | Reduce a URI to its scheme and authority. The path, query and fragment
-- are cleared, and any user info inside the authority is removed.
simplifyURI :: URI -> URI
simplifyURI uri =
  uri { uriAuthority = stripUserInfo <$> uriAuthority uri
      , uriPath      = ""
      , uriQuery     = ""
      , uriFragment  = ""
      }
  where
    stripUserInfo authority = authority { uriUserInfo = "" }
-- | Convert a method name into a 'HashableMethod'. The nine standard names
-- map to their dedicated 'Snap.Method' constructors; matching is exact and
-- case-sensitive. Any other name becomes a custom 'Snap.Method'.
parseMethod :: String -> HashableMethod
parseMethod name = HashableMethod $ case name of
  "GET"     -> Snap.GET
  "POST"    -> Snap.POST
  "HEAD"    -> Snap.HEAD
  "PUT"     -> Snap.PUT
  "DELETE"  -> Snap.DELETE
  "TRACE"   -> Snap.TRACE
  "OPTIONS" -> Snap.OPTIONS
  "CONNECT" -> Snap.CONNECT
  "PATCH"   -> Snap.PATCH
  custom    -> Snap.Method (S.pack custom)
-- | A 'URI' wrapper that can be stored in a 'HashSet.HashSet'. Equality is
-- the wrapped 'URI''s equality.
newtype HashableURI = HashableURI URI
  deriving Eq
-- | Shows the wrapped 'URI' exactly as its own 'show' does.
instance Show HashableURI where
  show wrapped = case wrapped of
    HashableURI uri -> show uri
-- | Hashes the URI's components in order: scheme, authority, path, query,
-- fragment. A present authority is first collapsed to an 'Int' hashed from
-- the same outer salt (user info, host name, port).
instance Hashable HashableURI where
  hashWithSalt salt (HashableURI uri) =
      hashWithSalt afterQuery (uriFragment uri)
    where
      afterScheme    = hashWithSalt salt (uriScheme uri)
      afterAuthority = hashWithSalt afterScheme
                                    (authorityHash <$> uriAuthority uri)
      afterPath      = hashWithSalt afterAuthority (uriPath uri)
      afterQuery     = hashWithSalt afterPath (uriQuery uri)

      authorityHash auth =
        salt `hashWithSalt` uriUserInfo auth
             `hashWithSalt` uriRegName auth
             `hashWithSalt` uriPort auth
-- | Whether an origin is accepted by an 'OriginList'. For 'Origins', the URI
-- is looked up as-is, so callers pass an already-reduced origin.
inOriginList :: URI -> OriginList -> Bool
inOriginList origin allowed = case allowed of
  Everywhere                  -> True
  Nowhere                     -> False
  Origins (OriginSet members) -> HashSet.member (HashableURI origin) members
-- | A 'Snap.Method' wrapper that can be stored in a 'HashSet.HashSet'.
-- Equality is the wrapped 'Snap.Method''s equality.
newtype HashableMethod = HashableMethod Snap.Method
  deriving Eq
-- | Each standard method hashes a fixed tag (0-8). A custom method hashes
-- tag 9 followed by its raw name.
--
-- A custom method whose name is one of the standard names, such as
-- @Snap.Method "GET"@, is hashed exactly like the matching standard
-- constructor. snap-core's 'Eq' on 'Snap.Method' treats such pairs as equal.
-- The 'Hashable' contract requires equal values to hash equally; otherwise
-- 'HashSet.member' can miss them. Under a purely structural 'Eq' this only
-- adds a harmless collision.
instance Hashable HashableMethod where
  hashWithSalt salt (HashableMethod method) =
      case method of
        Snap.GET     -> tagged 0
        Snap.HEAD    -> tagged 1
        Snap.POST    -> tagged 2
        Snap.PUT     -> tagged 3
        Snap.DELETE  -> tagged 4
        Snap.TRACE   -> tagged 5
        Snap.OPTIONS -> tagged 6
        Snap.CONNECT -> tagged 7
        Snap.PATCH   -> tagged 8
        Snap.Method name ->
          -- Normalise standard names first; parseMethod only ever returns a
          -- custom Method for non-standard names, so this recursion is one
          -- step deep at most.
          case parseMethod (S.unpack name) of
            HashableMethod (Snap.Method _) -> tagged 9 `hashWithSalt` name
            standard                       -> hashWithSalt salt standard
    where
      tagged :: Int -> Int
      tagged tag = salt `hashWithSalt` tag
-- | Shows the wrapped 'Snap.Method' using that type's own 'show'.
instance Show HashableMethod where
  show = \(HashableMethod wrapped) -> show wrapped