-- | Utilities for handling @multipart/form-data@ request bodies: entry points
-- for processing uploads, the policy types that constrain them, and the
-- exception types raised when something goes wrong.
module Snap.Util.FileUploads
  (
    -- * Functions
    handleFileUploads
  , handleMultipart
    -- * Uploaded parts
  , PartInfo(..)
    -- * Policy
    -- ** General upload policy
  , UploadPolicy
  , defaultUploadPolicy
  , doProcessFormInputs
  , setProcessFormInputs
  , getMaximumFormInputSize
  , setMaximumFormInputSize
  , getMaximumNumberOfFormInputs
  , setMaximumNumberOfFormInputs
  , getMinimumUploadRate
  , setMinimumUploadRate
  , getMinimumUploadSeconds
  , setMinimumUploadSeconds
  , getUploadTimeout
  , setUploadTimeout
    -- ** Per-file upload policy
  , PartUploadPolicy
  , disallow
  , allowWithMaximumSize
    -- * Exceptions
  , FileUploadException
  , fileUploadExceptionReason
  , BadPartException
  , badPartExceptionReason
  , PolicyViolationException
  , policyViolationExceptionReason
  ) where
import Control.Arrow
import Control.Applicative
import Control.Concurrent.MVar
import Control.Exception (SomeException(..))
import Control.Monad
import Control.Monad.CatchIO
import Control.Monad.Trans
import qualified Data.Attoparsec.Char8 as Atto
import Data.Attoparsec.Char8
import Data.Attoparsec.Enumerator
import qualified Data.ByteString.Char8 as S
import Data.ByteString.Char8 (ByteString)
import Data.ByteString.Internal (c2w)
import qualified Data.CaseInsensitive as CI
import Data.Enumerator.Binary (iterHandle)
import Data.IORef
import Data.Int
import Data.List hiding (takeWhile)
import qualified Data.Map as Map
import Data.Maybe
import qualified Data.Text as T
import Data.Text (Text)
import qualified Data.Text.Encoding as TE
import Data.Typeable
#if MIN_VERSION_base(4,6,0)
import Prelude hiding (getLine, takeWhile)
#else
import Prelude hiding (catch, getLine, takeWhile)
#endif
import System.Directory
import System.IO hiding (isEOF)
import Snap.Core
import Snap.Iteratee hiding (map)
import qualified Snap.Iteratee as I
import Snap.Internal.Debug
import Snap.Internal.Iteratee.Debug
import Snap.Internal.Iteratee.BoyerMooreHorspool
import Snap.Internal.Parsing
import qualified Snap.Types.Headers as H
#ifdef USE_UNIX
import System.FilePath ((</>))
import System.Posix.Temp (mkstemp)
#endif
-- | Streams the file parts of a multipart request into temporary files under
-- the given directory and hands the results to a user handler.
--
-- For every part, the per-part policy function decides whether the upload is
-- allowed and, if so, the maximum number of bytes to accept. The handler
-- receives one entry per part: @Right path@ for a file that was written, or
-- @Left violation@ when the part was disallowed or exceeded its size limit.
--
-- 'cleanupUploadedFiles' runs when the handler finishes (normally or by
-- exception), so temporary files do not outlive the handler.
handleFileUploads ::
       (MonadSnap m) =>
       FilePath                       -- ^ temporary directory
    -> UploadPolicy                   -- ^ general upload policy
    -> (PartInfo -> PartUploadPolicy) -- ^ per-part upload policy
    -> ([(PartInfo, Either PolicyViolationException FilePath)] -> m a)
                                      -- ^ user handler
    -> m a
handleFileUploads tmpdir uploadPolicy partPolicy handler = do
    uploadedFiles <- newUploadedFiles

    (do
        xs <- handleMultipart uploadPolicy (iter uploadedFiles)
        handler xs
      ) `finally` (cleanupUploadedFiles uploadedFiles)

  where
    -- File names and content types come straight from the client and need
    -- not be valid UTF-8. Decode leniently (invalid bytes become U+FFFD) so
    -- that building a policy-violation message can never throw.
    decodeLenient = TE.decodeUtf8With (\_ _ -> Just '\xfffd')

    iter uploadedFiles partInfo = maybe disallowed takeIt mbFs
      where
        ctText = partContentType partInfo
        fnText = fromMaybe "" $ partFileName partInfo

        ct = decodeLenient ctText
        fn = decodeLenient fnText

        (PartUploadPolicy mbFs) = partPolicy partInfo

        retVal (_,x) = (partInfo, Right x)

        takeIt maxSize = do
            debug "handleFileUploads/takeIt: begin"
            let it = fmap retVal $
                     joinI' $
                     iterateeDebugWrapper "takeNoMoreThan" $
                     takeNoMoreThan maxSize $$
                     fileReader uploadedFiles tmpdir partInfo

            it `catches` [
                   Handler $ \(_ :: TooManyBytesReadException) -> do
                       debug $ "handleFileUploads/iter: " ++
                               "caught TooManyBytesReadException"
                       -- Release the partially written temp file. If the
                       -- aborted reader left it registered as the active
                       -- file, the next part's openFileForUpload would fail
                       -- on the "pre-existing open handle" check. This is a
                       -- no-op when no file is active; the file itself is
                       -- removed by the final cleanup.
                       closeActiveFile uploadedFiles
                       skipToEof
                       tooMany maxSize
                 , Handler $ \(e :: SomeException) -> do
                       debug $ "handleFileUploads/iter: caught " ++ show e
                       debug "handleFileUploads/iter: rethrowing"
                       throw e
                 ]

        tooMany maxSize =
            return ( partInfo
                   , Left $ PolicyViolationException
                          $ T.concat [ "File \""
                                     , fn
                                     , "\" exceeded maximum allowable size "
                                     , T.pack $ show maxSize ] )

        disallowed =
            return ( partInfo
                   , Left $ PolicyViolationException
                          $ T.concat [ "Policy disallowed upload of file \""
                                     , fn
                                     , "\" with content-type \""
                                     , ct
                                     , "\"" ] )
-- | Runs a part handler over every part of a @multipart/form-data@ request
-- body and returns the handler results in order.
--
-- * If the request's content type is not @multipart/form-data@ the handler
--   calls 'pass'.
--
-- * If no boundary parameter is present a 'BadPartException' is thrown.
--
-- * When the policy enables form-input processing, parts without a file name
--   are captured as form inputs (see 'captureVariableOrReadFile') and added
--   to the request's POST parameters, which are finally merged into the
--   request parameters. A 'PolicyViolationException' is thrown once the
--   number of POST parameter keys reaches the policy maximum.
--
-- * The upload is run under 'killIfTooSlow'; a 'RateTooSlowException'
--   terminates the connection.
handleMultipart ::
       (MonadSnap m) =>
       UploadPolicy                            -- ^ global upload policy
    -> (PartInfo -> Iteratee ByteString IO a)  -- ^ part processor
    -> m [a]
handleMultipart uploadPolicy origPartHandler = do
    hdrs <- liftM headers getRequest
    let (ct, mbBoundary) = getContentType hdrs

    tickleTimeout <- liftM (. max) getTimeoutModifier
    let bumpTimeout = tickleTimeout $ uploadTimeout uploadPolicy

    let partHandler = if doProcessFormInputs uploadPolicy
                        then captureVariableOrReadFile
                                 (getMaximumFormInputSize uploadPolicy)
                                 origPartHandler
                        else (\p -> fmap File (origPartHandler p))

    -- Not a multipart form: let another handler deal with the request.
    when (ct /= "multipart/form-data") $ do
        debug $ "handleMultipart called with content-type=" ++ S.unpack ct
                  ++ ", passing"
        pass

    -- Total replacement for the isNothing/fromJust pair: either we have a
    -- boundary or we throw, with no partial function involved.
    boundary <- maybe (throw $ BadPartException $
                           "got multipart/form-data without boundary")
                      return
                      mbBoundary

    captures <- runRequestBody (iter bumpTimeout boundary partHandler)

    xs <- procCaptures [] captures

    -- Make the captured form inputs visible through the request parameters.
    modifyRequest $ \req ->
        let pp = rqPostParams req
        in rqModifyParams (\p -> Map.unionWith (++) p pp) req

    return xs

  where
    -- Enforce the minimum upload rate; a too-slow client gets its
    -- connection terminated, every other error is rethrown untouched.
    rateLimit bump m =
        killIfTooSlow bump
                      (minimumUploadRate uploadPolicy)
                      (minimumUploadSeconds uploadPolicy)
                      m
          `catchError` \e -> do
              debug $ "rateLimit: caught " ++ show e
              let (me::Maybe RateTooSlowException) = fromException e
              maybe (throwError e)
                    terminateConnection
                    me

    iter bump boundary ph = iterateeDebugWrapper "killIfTooSlow" $
                            rateLimit bump $
                            internalHandleMultipart boundary ph

    -- Append a value to the list stored under a key.
    ins k v = Map.insertWith' (flip (++)) k [v]

    maxFormVars = maximumNumberOfFormInputs uploadPolicy

    modifyParams f r = r { rqPostParams = f $ rqPostParams r }

    -- Split the captures into file results (accumulated in reverse, then
    -- flipped back) and form inputs (stored in the request's POST params).
    procCaptures l [] = return $! reverse l
    procCaptures l ((File x):xs) = procCaptures (x:l) xs
    procCaptures l ((Capture k v):xs) = do
        rq <- getRequest
        let n = Map.size $ rqPostParams rq
        when (n >= maxFormVars) $
          throw $ PolicyViolationException $
          T.concat [ "number of form inputs exceeded maximum of "
                   , T.pack $ show maxFormVars ]
        modifyRequest $ modifyParams (ins k v)
        procCaptures l xs
-- | Metadata describing one part of a multipart body, as extracted from the
-- part's headers.
data PartInfo = PartInfo
    { partFieldName   :: !ByteString
      -- ^ @name@ parameter of the part's content-disposition
    , partFileName    :: !(Maybe ByteString)
      -- ^ @filename@ parameter of the content-disposition, if any
    , partContentType :: !ByteString
      -- ^ content type of the part
    }
  deriving (Show)
-- | Root of the file-upload exception hierarchy. It is either a generic
-- failure carrying only a reason, or a wrapper around a more specific
-- exception together with that exception's reason text.
data FileUploadException =
    GenericFileUploadException {
      _genericFileUploadExceptionReason :: Text
    }
  | forall e . (Exception e, Show e) =>
    WrappedFileUploadException {
      _wrappedFileUploadException       :: e
    , _wrappedFileUploadExceptionReason :: Text
    }
  deriving (Typeable)


-- | A generic exception shows a fixed prefix plus its reason; a wrapped
-- exception is shown exactly like the exception it wraps.
instance Show FileUploadException where
    show ex = case ex of
        GenericFileUploadException reason ->
            "File upload exception: " ++ T.unpack reason
        WrappedFileUploadException inner _ ->
            show inner


instance Exception FileUploadException


-- | Extracts the human-readable reason from either kind of
-- 'FileUploadException'.
fileUploadExceptionReason :: FileUploadException -> Text
fileUploadExceptionReason ex = case ex of
    GenericFileUploadException reason   -> reason
    WrappedFileUploadException _ reason -> reason
-- | Wraps a specific upload exception and its reason text in a
-- 'WrappedFileUploadException' and packages that as a 'SomeException'.
uploadExceptionToException :: Exception e => e -> Text -> SomeException
uploadExceptionToException e = SomeException . WrappedFileUploadException e
-- | Inverse of 'uploadExceptionToException': succeeds only when the given
-- exception is a 'WrappedFileUploadException' whose payload has the
-- requested type.
uploadExceptionFromException :: Exception e => SomeException -> Maybe e
uploadExceptionFromException se =
    case fromException se of
      Just (WrappedFileUploadException inner _) -> cast inner
      _                                         -> Nothing
-- | Thrown when a part of the multipart body is malformed (for example a
-- missing boundary or oversized headers).
data BadPartException = BadPartException
    { badPartExceptionReason :: Text
    }
  deriving (Typeable)


-- | Routed through 'FileUploadException', so it can be caught either as
-- itself or as the wrapping 'FileUploadException'.
instance Exception BadPartException where
    toException ex@(BadPartException reason) =
        uploadExceptionToException ex reason
    fromException = uploadExceptionFromException


instance Show BadPartException where
    show (BadPartException reason) = "Bad part: " ++ T.unpack reason
-- | Thrown (or returned as a 'Left' value) when an upload breaks one of the
-- configured policies.
data PolicyViolationException = PolicyViolationException
    { policyViolationExceptionReason :: Text
    }
  deriving (Typeable)


-- | Routed through 'FileUploadException', so it can be caught either as
-- itself or as the wrapping 'FileUploadException'.
instance Exception PolicyViolationException where
    toException ex@(PolicyViolationException reason) =
        uploadExceptionToException ex reason
    fromException = uploadExceptionFromException


instance Show PolicyViolationException where
    show (PolicyViolationException reason) =
        "File upload policy violation: " ++ T.unpack reason
-- | Global settings applied to a whole multipart upload.
data UploadPolicy = UploadPolicy
    { processFormInputs         :: Bool
      -- ^ capture parts without a file name as form inputs?
    , maximumFormInputSize      :: Int64
      -- ^ size limit handed to 'takeNoMoreThan' for each form input
    , maximumNumberOfFormInputs :: Int
      -- ^ limit checked against the number of captured POST parameters
    , minimumUploadRate         :: Double
      -- ^ rate passed to 'killIfTooSlow'
    , minimumUploadSeconds      :: Int
      -- ^ seconds value passed to 'killIfTooSlow'
    , uploadTimeout             :: Int
      -- ^ value fed to the request's timeout modifier
    }
  deriving (Show, Eq)
-- | A reasonable starting policy: form inputs are processed, each limited to
-- 2^17 (131072) bytes, at most 10 of them; minimum upload rate 1000 with a
-- minimum-seconds value of 10; upload timeout value 20.
defaultUploadPolicy :: UploadPolicy
defaultUploadPolicy = UploadPolicy
    { processFormInputs         = True
    , maximumFormInputSize      = 2 ^ (17 :: Int)
    , maximumNumberOfFormInputs = 10
    , minimumUploadRate         = 1000
    , minimumUploadSeconds      = 10
    , uploadTimeout             = 20
    }
-- | Whether the policy asks for parts without a file name to be captured as
-- form inputs.
doProcessFormInputs :: UploadPolicy -> Bool
doProcessFormInputs policy = processFormInputs policy


-- | Returns a copy of the policy with form-input processing switched on or
-- off.
setProcessFormInputs :: Bool -> UploadPolicy -> UploadPolicy
setProcessFormInputs enabled policy = policy { processFormInputs = enabled }
-- | The maximum size accepted for a single form input.
getMaximumFormInputSize :: UploadPolicy -> Int64
getMaximumFormInputSize policy = maximumFormInputSize policy


-- | Returns a copy of the policy with a new maximum form input size.
setMaximumFormInputSize :: Int64 -> UploadPolicy -> UploadPolicy
setMaximumFormInputSize limit policy = policy { maximumFormInputSize = limit }
-- | The maximum number of form inputs the policy allows.
getMaximumNumberOfFormInputs :: UploadPolicy -> Int
getMaximumNumberOfFormInputs policy = maximumNumberOfFormInputs policy


-- | Returns a copy of the policy with a new maximum number of form inputs.
setMaximumNumberOfFormInputs :: Int -> UploadPolicy -> UploadPolicy
setMaximumNumberOfFormInputs limit policy =
    policy { maximumNumberOfFormInputs = limit }
-- | The minimum upload rate stored in the policy.
getMinimumUploadRate :: UploadPolicy -> Double
getMinimumUploadRate policy = minimumUploadRate policy


-- | Returns a copy of the policy with a new minimum upload rate.
setMinimumUploadRate :: Double -> UploadPolicy -> UploadPolicy
setMinimumUploadRate rate policy = policy { minimumUploadRate = rate }
-- | The minimum-upload-seconds value stored in the policy.
getMinimumUploadSeconds :: UploadPolicy -> Int
getMinimumUploadSeconds policy = minimumUploadSeconds policy


-- | Returns a copy of the policy with a new minimum-upload-seconds value.
setMinimumUploadSeconds :: Int -> UploadPolicy -> UploadPolicy
setMinimumUploadSeconds secs policy = policy { minimumUploadSeconds = secs }
-- | The upload timeout value stored in the policy.
getUploadTimeout :: UploadPolicy -> Int
getUploadTimeout policy = uploadTimeout policy


-- | Returns a copy of the policy with a new upload timeout value.
setUploadTimeout :: Int -> UploadPolicy -> UploadPolicy
setUploadTimeout timeout policy = policy { uploadTimeout = timeout }
-- | Per-part decision: 'Nothing' means the part must not be uploaded,
-- @'Just' n@ means it is accepted up to a size of @n@.
data PartUploadPolicy = PartUploadPolicy
    { _maximumFileSize :: Maybe Int64
    }
  deriving (Show, Eq)


-- | Policy that rejects the part.
disallow :: PartUploadPolicy
disallow = PartUploadPolicy { _maximumFileSize = Nothing }


-- | Policy that accepts the part up to the given maximum size.
allowWithMaximumSize :: Int64 -> PartUploadPolicy
allowWithMaximumSize limit = PartUploadPolicy { _maximumFileSize = Just limit }
-- | Wraps a file handler so that parts /without/ a file name are read into
-- memory as form inputs ('Capture') while parts with a file name are passed
-- to the file handler ('File').
--
-- A form input larger than the given maximum is reported as a
-- 'PolicyViolationException'; any other error is rethrown unchanged.
captureVariableOrReadFile ::
       Int64                                   -- ^ maximum size of form input
    -> (PartInfo -> Iteratee ByteString IO a)  -- ^ file reading code
    -> (PartInfo -> Iteratee ByteString IO (Capture a))
captureVariableOrReadFile maxSize fileHandler partInfo =
    case partFileName partInfo of
      Nothing -> iter
      _       -> liftM File $ fileHandler partInfo
  where
    iter = varIter `catchError` handler

    fieldName = partFieldName partInfo

    -- The field name is client-supplied and need not be valid UTF-8; decode
    -- leniently (invalid bytes become U+FFFD) so that building the error
    -- message below can never itself throw a decoding exception.
    fieldNameText = TE.decodeUtf8With (\_ _ -> Just '\xfffd') fieldName

    varIter = do
        var <- liftM S.concat $
               joinI' $
               takeNoMoreThan maxSize $$ consume
        return $! Capture fieldName var

    handler e = do
        debug $ "captureVariableOrReadFile/handler: caught " ++ show e
        let m = fromException e :: Maybe TooManyBytesReadException
        case m of
          Nothing -> do
              debug "didn't expect this error, rethrowing"
              throwError e
          Just _  -> do
              debug "rethrowing as PolicyViolationException"
              throwError $ PolicyViolationException $
                         T.concat [ "form input '"
                                  , fieldNameText
                                  , "' exceeded maximum permissible size ("
                                  , T.pack $ show maxSize
                                  , " bytes)" ]
-- | Result of processing one part: either a captured form input (field name
-- and value) or the result produced by the file handler.
data Capture a
    = Capture ByteString ByteString
    | File a
  deriving (Show)
-- | Opens a fresh temporary file for the part, streams the part's bytes
-- into it, then marks the file as closed and returns the part info together
-- with the file's path. Any exception is rethrown through the iteratee's
-- error channel.
fileReader :: UploadedFiles
           -> FilePath
           -> PartInfo
           -> Iteratee ByteString IO (PartInfo, FilePath)
fileReader uploadedFiles tmpdir partInfo = do
    debug "fileReader: begin"
    (path, hdl) <- openFileForUpload uploadedFiles tmpdir
    let wrapped = iterateeDebugWrapper "fileReader" $ writeTo path hdl
    wrapped `catch` rethrow

  where
    rethrow (e :: SomeException) = throwError e

    writeTo path hdl = do
        iterHandle hdl
        debug "fileReader: closing active file"
        closeActiveFile uploadedFiles
        return (partInfo, path)
-- | Core multipart processor. Skips ahead to the first boundary, splits the
-- remaining input on the full boundary string and runs the client handler on
-- every part. Parts of type @multipart/mixed@ are split again on their own
-- boundary, and each nested part is reported under the outer field name.
--
-- On any exception the rest of the input is skipped before the exception is
-- rethrown.
internalHandleMultipart ::
       ByteString                              -- ^ boundary value
    -> (PartInfo -> Iteratee ByteString IO a)  -- ^ part processor
    -> Iteratee ByteString IO [a]
internalHandleMultipart boundary clientHandler = go `catch` errorHandler

  where
    --------------------------------------------------------------------------
    errorHandler :: SomeException -> Iteratee ByteString IO a
    errorHandler e = skipToEof >> throwError e

    --------------------------------------------------------------------------
    go = do
        -- swallow the first boundary
        _ <- iterParser $ parseFirstBoundary boundary
        step <- iterateeDebugWrapper "boyer-moore" $
                (bmhEnumeratee (fullBoundary boundary) $$ processParts iter)
        fmap concat $ lift $ run_ $ returnI step

    --------------------------------------------------------------------------
    -- parsers for locating the first boundary line
    pBoundary b = Atto.try (string "--" *> string b)

    fullBoundary b = S.concat ["\r\n", "--", b]

    pLine = takeWhile (not . isEndOfLine . c2w) <* eol

    takeLine = pLine *> pure ()

    parseFirstBoundary b = pBoundary b <|> (takeLine *> parseFirstBoundary b)

    --------------------------------------------------------------------------
    -- reads a part's header block, bounded by mAX_HDRS_SIZE
    takeHeaders = readHdrs `catchError` onError
      where
        readHdrs = fmap toHeaders $
                   iterateeDebugWrapper "header parser" $
                   joinI' $
                   takeNoMoreThan mAX_HDRS_SIZE $$
                   iterParser pHeadersWithSeparator

        onError e = do
            debug $ "internalHandleMultipart/takeHeaders: caught " ++ show e
            case (fromException e :: Maybe TooManyBytesReadException) of
              Nothing -> throwError e
              Just _  -> throwError $ BadPartException $
                         "headers exceeded maximum size"

    --------------------------------------------------------------------------
    -- handles one top-level part
    iter = do
        hdrs <- takeHeaders
        debug $ "internalHandleMultipart/iter: got headers"

        let (contentType, mboundary) = getContentType hdrs
        let (fieldName, fileName)    = getFieldName hdrs

        if contentType /= "multipart/mixed"
          then do
              let info = PartInfo fieldName fileName contentType
              fmap (:[]) $ clientHandler info
          else case mboundary of
                 Nothing -> throwError $ BadPartException $
                            "got multipart/mixed without boundary"
                 Just mb -> processMixed fieldName mb

    --------------------------------------------------------------------------
    -- splits a multipart/mixed part on its own boundary
    processMixed fieldName mixedBoundary = do
        -- swallow the first boundary
        _ <- iterParser $ parseFirstBoundary mixedBoundary
        step <- iterateeDebugWrapper "boyer-moore" $
                (bmhEnumeratee (fullBoundary mixedBoundary) $$
                 processParts (mixedIter fieldName))
        lift $ run_ $ returnI step

    --------------------------------------------------------------------------
    -- handles one nested part of a multipart/mixed section
    mixedIter fieldName = do
        hdrs <- takeHeaders

        let (contentType, _) = getContentType hdrs
        let (_, fileName)    = getFieldName hdrs

        clientHandler (PartInfo fieldName fileName contentType)
-- | Reads the @content-type@ header and returns the bare content type
-- together with its @boundary@ parameter, if present. A missing or
-- unparseable header is treated as @text/plain@ with no parameters.
getContentType :: Headers
               -> (ByteString, Maybe ByteString)
getContentType hdrs = (mimeType, findParam "boundary" ctParams)
  where
    rawValue = fromMaybe "text/plain" (getHeader "content-type" hdrs)

    (mimeType, ctParams) =
        case fullyParse rawValue pContentTypeWithParameters of
          Left _       -> ("text/plain", [])
          Right parsed -> parsed
-- | Reads the @content-disposition@ header and returns its @name@ parameter
-- (empty when absent) and its @filename@ parameter (if any). A missing or
-- unparseable header yields an empty name and no file name.
getFieldName :: Headers -> (ByteString, Maybe ByteString)
getFieldName hdrs =
    ( fromMaybe "" (findParam "name" dispParams)
    , findParam "filename" dispParams
    )
  where
    rawDisposition = fromMaybe "" (getHeader "content-disposition" hdrs)

    dispParams =
        case fullyParse rawDisposition pValueWithParameters of
          Left _            -> []
          Right (_, params) -> params
-- | Returns the value paired with the first key equal to the one given, or
-- 'Nothing' when no key matches.
findParam :: (Eq a) => a -> [(a, b)] -> Maybe b
findParam p kvs = listToMaybe [ v | (k, v) <- kvs, k == p ]
-- | Feeds the inner iteratee with the data of a single part: every 'NoMatch'
-- chunk is passed on as data, while the first 'Match' (or the end of the
-- outer stream) is delivered to the inner iteratee as EOF. An inner iteratee
-- that is no longer in the 'Continue' state is yielded as is.
processPart :: (Monad m) => Enumeratee MatchInfo ByteString m a
processPart (Continue k0) = feed k0
  where
    feed :: (Monad m) => (Stream ByteString -> Iteratee ByteString m a)
                      -> Iteratee MatchInfo m (Step ByteString m a)
    feed !k = do
        mbInfo <- I.head
        case mbInfo of
          Nothing           -> sendEof
          Just (Match _)    -> sendEof
          Just (NoMatch !s) -> do
              !step <- lift $ runIteratee $ k $ Chunks [s]
              case step of
                (Continue k') -> feed k'
                _             -> yield step (Chunks [])
      where
        sendEof = lift $ runIteratee $ k EOF
processPart st = yield st (Chunks [])
-- | Runs the given part iteratee once per part of the boundary-split stream
-- and collects the results in order. After each boundary the parser checks
-- whether the terminating @--@ follows; if so, processing stops. Whatever a
-- part iteratee leaves unread is skipped.
processParts :: Iteratee ByteString IO a
             -> Iteratee MatchInfo IO [a]
processParts partIter = iterateeDebugWrapper "processParts" $ loop id
  where
    -- Nothing signals the closing boundary, Just carries a part's result.
    onePart = do
        isLast <- boundaryTail
        if isLast
          then return Nothing
          else do
            !x <- partIter
            skipToEof
            return $! Just x

    -- results are accumulated as a difference list
    loop !acc = do
        atEnd <- isEOF
        if atEnd
          then finish
          else do
            innerStep <- processPart $$ onePart
            result    <- lift $ run_ $ returnI innerStep
            maybe finish (\x -> loop (acc . (x:))) result
      where
        finish = return $ acc []

    boundaryTail = iterateeDebugWrapper "boundary debugger" $
                   iterParser $ pBoundaryEnd

    pBoundaryEnd = (eol *> pure False) <|> (string "--" *> pure True)
-- | Matches a line ending: either a bare line feed or a CR-LF pair.
eol :: Parser ByteString
eol = string "\n" <|> string "\r\n"
-- | Parses a block of headers followed by the CR-LF that separates it from
-- the part body.
pHeadersWithSeparator :: Parser [(ByteString,ByteString)]
pHeadersWithSeparator = do
    hdrs <- pHeaders
    _    <- crlf
    return hdrs
-- | Builds a 'Headers' value from raw key/value pairs, making the keys
-- case-insensitive.
toHeaders :: [(ByteString,ByteString)] -> Headers
toHeaders kvps = H.fromList (map (first CI.mk) kvps)
-- | Upper bound on the size of a single part's header block (32 KiB).
mAX_HDRS_SIZE :: Int64
mAX_HDRS_SIZE = 32 * 1024
-- | Bookkeeping for temporary upload files: the file currently being
-- written (path and handle), if any, and the paths of files already
-- finished.
data UploadedFilesState = UploadedFilesState
    { _currentFile      :: Maybe (FilePath, Handle)
    , _alreadyReadFiles :: [FilePath]
    }


-- | State with no file open and no finished files.
emptyUploadedFilesState :: UploadedFilesState
emptyUploadedFilesState = UploadedFilesState
    { _currentFile      = Nothing
    , _alreadyReadFiles = []
    }
-- | Mutable upload bookkeeping: a reference to the current state plus an
-- 'MVar' to which 'newUploadedFiles' attaches a cleanup finalizer.
data UploadedFiles =
    UploadedFiles (IORef UploadedFilesState) (MVar ())
-- | Creates empty upload bookkeeping and registers a finalizer on its 'MVar'
-- that runs 'cleanupUploadedFiles'.
newUploadedFiles :: MonadIO m => m UploadedFiles
newUploadedFiles = liftIO $ do
    stateRef <- newIORef emptyUploadedFilesState
    sentinel <- newMVar ()
    let uploaded = UploadedFiles stateRef sentinel
    addMVarFinalizer sentinel (cleanupUploadedFiles uploaded)
    return uploaded
-- | Closes and removes the file currently being written (if any), removes
-- every finished file, and resets the state to empty. Errors while closing
-- or removing are ignored.
cleanupUploadedFiles :: (MonadIO m) => UploadedFiles -> m ()
cleanupUploadedFiles (UploadedFiles stateRef _) = liftIO $ do
    st <- readIORef stateRef

    -- the file still being written, if there is one
    case _currentFile st of
      Nothing      -> return ()
      Just (fp, h) -> do
          eatException $ hClose h
          eatException $ removeFile fp

    -- every file that was completely read
    forM_ (_alreadyReadFiles st) $ eatException . removeFile

    writeIORef stateRef emptyUploadedFilesState
-- | Creates a new unbuffered temporary file in the given directory and
-- records it as the file currently being written. If a file is already
-- recorded as open, everything is cleaned up and a
-- 'GenericFileUploadException' is thrown.
openFileForUpload :: (MonadIO m) =>
                     UploadedFiles
                  -> FilePath
                  -> m (FilePath, Handle)
openFileForUpload ufs@(UploadedFiles stateRef _) tmpdir = liftIO $ do
    st <- readIORef stateRef

    -- it is an error to open a file while another one is still active
    case _currentFile st of
      Nothing -> return ()
      Just _  -> do
          cleanupUploadedFiles ufs
          throw $ GenericFileUploadException alreadyOpenMsg

    opened@(_, hdl) <- makeTempFile tmpdir "snap-"
    hSetBuffering hdl NoBuffering

    writeIORef stateRef $ st { _currentFile = Just opened }
    return opened

  where
    alreadyOpenMsg =
        T.concat [ "Internal error! UploadedFiles: "
                 , "opened new file with pre-existing open handle" ]
-- | Closes the handle of the file currently being written (ignoring errors)
-- and moves its path to the list of finished files. Does nothing when no
-- file is active.
closeActiveFile :: (MonadIO m) => UploadedFiles -> m ()
closeActiveFile (UploadedFiles stateRef _) = liftIO $ do
    st <- readIORef stateRef
    case _currentFile st of
      Nothing      -> return ()
      Just (fp, h) -> do
          eatException $ hClose h
          writeIORef stateRef $
              st { _currentFile      = Nothing
                 , _alreadyReadFiles = fp : _alreadyReadFiles st }
-- | Runs an action for its effects only, discarding its result and silently
-- ignoring any exception it throws. Used for best-effort cleanup.
eatException :: (MonadCatchIO m) => m a -> m ()
eatException action = catch (action >> return ()) ignore
  where
    ignore (_ :: SomeException) = return ()
-- | Creates a temporary file in the given directory whose name starts with
-- the given prefix, returning its path and an open handle. Uses 'mkstemp'
-- when built with @USE_UNIX@ and 'openBinaryTempFile' otherwise.
makeTempFile :: FilePath -> String -> IO (FilePath, Handle)
#ifdef USE_UNIX
makeTempFile dir prefix = mkstemp (dir </> (prefix ++ "XXXXXXX"))
#else
makeTempFile dir prefix = openBinaryTempFile dir prefix
#endif