module OpenSSL.EVP.Base64
(
encodeBase64
, encodeBase64BS
, encodeBase64LBS
, decodeBase64
, decodeBase64BS
, decodeBase64LBS
)
where
import Control.Exception (assert)
import Data.ByteString.Internal (createAndTrim)
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import qualified Data.ByteString.Lazy.Internal as L8Internal
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy.Char8 as L8
import Data.List
#if MIN_VERSION_base(4,5,0)
import Foreign.C.Types (CChar(..), CInt(..))
#else
import Foreign.C.Types (CChar, CInt)
#endif
import Foreign.Ptr (Ptr, castPtr)
import System.IO.Unsafe (unsafePerformIO)
-- | Move whole chunks from the front of a lazy 'L8.ByteString' onto the end
-- of an accumulated chunk list until the accumulated chunks hold at least
-- @minLen@ bytes in total, or the lazy string is exhausted.
--
-- Returns the gathered chunks (in input order) and the unconsumed rest of
-- the lazy string.  Chunks are never split here; callers do that.
nextBlock :: Int -> ([B8.ByteString], L8.ByteString) -> ([B8.ByteString], L8.ByteString)
nextBlock minLen (xs, src)
    | gathered >= minLen = (xs, src)
    | otherwise
    = case src of
        L8Internal.Chunk y ys -> nextBlock minLen (xs ++ [y], ys)
        L8Internal.Empty      -> (xs, src)
    where
      -- total number of bytes collected so far
      gathered = sum [B8.length x | x <- xs]
foreign import ccall unsafe "EVP_EncodeBlock"
        _EncodeBlock :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt

-- | Base64-encode a strict 'B8.ByteString' with a single call to OpenSSL's
-- @EVP_EncodeBlock@.
--
-- The output buffer is over-allocated and 'createAndTrim' cuts the result
-- down to the length reported by the C function.
encodeBlock :: B8.ByteString -> B8.ByteString
encodeBlock inBS = unsafePerformIO (unsafeUseAsCStringLen inBS withInput)
    where
      withInput (inBuf, inLen) = createAndTrim bufSize (runEncode inBuf inLen)

      runEncode inBuf inLen outBuf = do
        written <- _EncodeBlock (castPtr outBuf) inBuf (fromIntegral inLen)
        return (fromIntegral written)

      -- At least 4 output characters per started 3-byte input group, plus
      -- one extra byte (presumably for the NUL terminator that
      -- EVP_EncodeBlock writes, per the OpenSSL docs).
      bufSize = (B8.length inBS `div` 3 + 1) * 4 + 1
-- | Base64-encode a 'String'.  Each 'Char' is truncated to a single byte
-- by 'L8.pack'; the encoded text is returned as a 'String'.
encodeBase64 :: String -> String
encodeBase64 str = L8.unpack (encodeBase64LBS (L8.pack str))
-- | Base64-encode a strict 'B8.ByteString'; the whole input is passed to
-- 'encodeBlock' in one piece.
encodeBase64BS :: B8.ByteString -> B8.ByteString
encodeBase64BS bs = encodeBlock bs
-- | Base64-encode a lazy 'L8.ByteString', producing the encoded output
-- lazily, one encoded block per step.
--
-- Each step gathers at least 3 input bytes with 'nextBlock', encodes the
-- longest prefix whose length is a multiple of 3, and pushes the remaining
-- 1-2 bytes back onto the input.  This way no @=@ padding is emitted in
-- the middle of the stream; only a final fragment shorter than 3 bytes
-- is encoded with padding.
encodeBase64LBS :: L8.ByteString -> L8.ByteString
encodeBase64LBS inLBS
    | L8.null inLBS = L8.empty
    | otherwise
    = let (blockParts', remain') = nextBlock 3 ([], inLBS)
          block'    = B8.concat blockParts'
          blockLen' = B8.length block'
          (block, leftover) =
              if blockLen' < 3 then
                  -- nextBlock only returns fewer than 3 bytes once the
                  -- input is exhausted, so this is the final block.
                  (block', B8.empty)
              else
                  -- split on a 3-byte boundary
                  B8.splitAt (blockLen' - blockLen' `mod` 3) block'
          remain = if B8.null leftover then
                       remain'
                   else
                       L8.fromChunks [leftover] `L8.append` remain'
      in
        L8.fromChunks [encodeBlock block] `L8.append` encodeBase64LBS remain
foreign import ccall unsafe "EVP_DecodeBlock"
        _DecodeBlock :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt

-- | Base64-decode a strict 'B8.ByteString' with a single call to OpenSSL's
-- @EVP_DecodeBlock@.  The input is expected to be a whole number of
-- 4-character groups; this is only checked by 'assert', which is compiled
-- out under optimisation.
--
-- @EVP_DecodeBlock@ reports 3 bytes for every 4 input characters,
-- including @=@ padding.  The number of @=@ characters in the input is
-- therefore subtracted from its result, clamped so the length never goes
-- negative.
--
-- Throws an 'IOError' when @EVP_DecodeBlock@ signals failure with a
-- negative return value.  Passing that value on would hand 'createAndTrim'
-- a negative length.
decodeBlock :: B8.ByteString -> B8.ByteString
decodeBlock inBS
    = assert (B8.length inBS `mod` 4 == 0) $
      unsafePerformIO $
      unsafeUseAsCStringLen inBS $ \ (inBuf, inLen) ->
      createAndTrim (B8.length inBS) $ \ outBuf -> do
        outLen <- _DecodeBlock (castPtr outBuf) inBuf (fromIntegral inLen)
        if outLen < 0
          then ioError (userError "decodeBlock: EVP_DecodeBlock failed")
          -- e.g. "====" decodes to 3 bytes but contains 4 '=' characters
          else return (max 0 (fromIntegral outLen - paddingLen))
    where
      paddingLen :: Int
      paddingLen = B8.count '=' inBS
-- | Base64-decode a 'String'.  Each 'Char' of the input is truncated to a
-- single byte by 'L8.pack'; the decoded bytes come back as a 'String' of
-- 8-bit 'Char's.
decodeBase64 :: String -> String
decodeBase64 str = L8.unpack (decodeBase64LBS (L8.pack str))
-- | Base64-decode a strict 'B8.ByteString'; the whole input is passed to
-- 'decodeBlock' in one piece.
decodeBase64BS :: B8.ByteString -> B8.ByteString
decodeBase64BS bs = decodeBlock bs
-- | Base64-decode a lazy 'L8.ByteString', producing the decoded output
-- lazily, one decoded block per step.
--
-- Each step gathers at least 4 input characters with 'nextBlock', decodes
-- the longest prefix whose length is a multiple of 4, and pushes the
-- remaining characters back onto the input for the next step.
--
-- A final fragment shorter than 4 characters can occur only at the end of
-- the input, e.g. with unpadded or truncated data.  It is handed to
-- 'decodeBlock' as the last block.  Splitting it on a 4-character boundary
-- would instead produce an empty block and push the whole fragment back,
-- so the recursion would never terminate.
decodeBase64LBS :: L8.ByteString -> L8.ByteString
decodeBase64LBS inLBS
    | L8.null inLBS = L8.empty
    | otherwise
    = let (blockParts', remain') = nextBlock 4 ([], inLBS)
          block'    = B8.concat blockParts'
          blockLen' = B8.length block'
          (block, leftover) =
              if blockLen' < 4 then
                  -- input exhausted: decode the short tail as the final block
                  (block', B8.empty)
              else
                  -- split on a 4-character boundary
                  B8.splitAt (blockLen' - blockLen' `mod` 4) block'
          remain = if B8.null leftover then
                       remain'
                   else
                       L8.fromChunks [leftover] `L8.append` remain'
      in
        L8.fromChunks [decodeBlock block] `L8.append` decodeBase64LBS remain