{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DoAndIfThenElse #-}
module Data.ByteString.Base64.Internal
( encodeWith
, decodeWithTable
, decodeLenientWithTable
, mkEncodeTable
, done
, peek8, poke8, peek8_32
, reChunkIn
, Padding(..)
, withBS
, mkBS
) where
import Data.Bits ((.|.), (.&.), shiftL, shiftR)
import qualified Data.ByteString as B
import Data.ByteString.Internal (ByteString(..), mallocByteString)
import Data.Word (Word8, Word16, Word32)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr, castPtr, minusPtr, plusPtr)
import Foreign.Storable (peek, peekElemOff, poke)
import System.IO.Unsafe (unsafePerformIO)
-- | Read the single byte stored at the given address.
peek8 :: Ptr Word8 -> IO Word8
peek8 addr = peek addr
-- | Store a single byte at the given address.
poke8 :: Ptr Word8 -> Word8 -> IO ()
poke8 addr byte = poke addr byte
-- | Read the byte at the given address and widen it to a 'Word32'.
peek8_32 :: Ptr Word8 -> IO Word32
peek8_32 addr = do
  byte <- peek8 addr
  return (fromIntegral byte)
-- | How @=@ padding is treated by the encoder and the strict decoder.
data Padding
  = Padded
    -- ^ Encoder: fill the last group with @=@.  Decoder: the input length
    -- must be a multiple of four.
  | Don'tCare
    -- ^ Decoder: accept input with or without trailing padding.  The
    -- encoder emits no padding for this value.
  | Unpadded
    -- ^ Decoder: reject input whose last byte is @=@.  The encoder emits no
    -- padding for this value.
  deriving Eq
-- | Base64-encode a strict 'ByteString' using the supplied lookup tables.
--
-- Every whole 3-byte group is turned into a 24-bit word whose two 12-bit
-- halves index the pair table, so a group costs two 16-bit loads and two
-- 16-bit stores.  The trailing 1 or 2 bytes are translated symbol by symbol
-- through the alphabet.  With 'Padded' the last group is filled up with @=@;
-- for any other 'Padding' value the output ends after the last symbol.
--
-- Calls 'error' when the input is longer than @maxBound \`div\` 4@ bytes.
encodeWith :: Padding -> EncodeTable -> ByteString -> ByteString
encodeWith !padding !(ET alfaFP encodeTable) !bs = withBS bs encodeAll
  where
    encodeAll !srcStart !srcLen
      | srcLen > maxBound `div` 4 =
          error "Data.ByteString.Base64.encode: input too long"
      | otherwise = do
          -- Sized for the padded form; an unpadded result just uses less.
          outFP <- mallocByteString (((srcLen + 2) `div` 3) * 4)
          withForeignPtr alfaFP $ \alphabet ->
            withForeignPtr encodeTable $ \pairs -> do
              let srcEnd = srcStart `plusPtr` srcLen
                  wantPad = padding == Padded
                  padByte = 0x3d :: Word8
                  {-# INLINE padByte #-}

                  -- One output symbol for a 6-bit value.
                  symbolAt off = peek8 (alphabet `plusPtr` off)
                  -- Two output symbols for a 12-bit value.
                  pairAt = peekElemOff pairs . fromIntegral

                  emit !count = return $ mkBS outFP count

                  -- Whole 3-byte groups: 3 bytes in, 4 bytes out.
                  mainLoop !outp !inp !count
                    | inp `plusPtr` 2 >= srcEnd =
                        tailGroup (castPtr outp) inp count
                    | otherwise = {-# SCC "encode/fill" #-} do
                        b0 <- peek8_32 inp
                        b1 <- peek8_32 (inp `plusPtr` 1)
                        b2 <- peek8_32 (inp `plusPtr` 2)
                        let triple = (b0 `shiftL` 16) .|. (b1 `shiftL` 8) .|. b2
                        hiPair <- pairAt (triple `shiftR` 12)
                        poke outp hiPair
                        loPair <- pairAt (triple .&. 0xfff)
                        poke (outp `plusPtr` 2) loPair
                        mainLoop (outp `plusPtr` 4) (inp `plusPtr` 3) (count + 4)

                  -- The 0, 1 or 2 bytes left after the last whole group.
                  tailGroup outp inp count
                    | inp == srcEnd = emit count
                    | otherwise = {-# SCC "encode/complete" #-} do
                        let readAt off f = do
                              byte <- peek8 (inp `plusPtr` off)
                              return (f (fromIntegral byte))
                            hasTwo = inp `plusPtr` 2 == srcEnd
                            -- @used@ symbols are already written for this
                            -- group; pad the rest of it or stop right there.
                            finishGroup used
                              | wantPad = do
                                  mapM_ (\off -> poke8 (outp `plusPtr` off) padByte)
                                        [used .. 3]
                                  emit (count + 4)
                              | otherwise = emit (count + used)
                        !s0 <- readAt 0 ((`shiftR` 2) . (.&. 0xfc))
                        !s1 <- readAt 0 ((`shiftL` 4) . (.&. 0x03))
                        poke8 outp =<< symbolAt s0
                        if hasTwo
                          then do
                            !s1' <- readAt 1 ((.|. s1) . (`shiftR` 4) . (.&. 0xf0))
                            !sym2 <- symbolAt =<< readAt 1 ((`shiftL` 2) . (.&. 0x0f))
                            poke8 (outp `plusPtr` 1) =<< symbolAt s1'
                            poke8 (outp `plusPtr` 2) sym2
                            finishGroup 3
                          else do
                            poke8 (outp `plusPtr` 1) =<< symbolAt s1
                            finishGroup 2
              withForeignPtr outFP $! \outStart ->
                mainLoop (castPtr outStart) srcStart 0
-- | The two lookup tables the encoder works from (see 'mkEncodeTable').
data EncodeTable
  = ET
      !(ForeignPtr Word8)
      -- ^ the alphabet itself, indexed by a 6-bit value
      !(ForeignPtr Word16)
      -- ^ two output symbols per entry, indexed by a 12-bit value
-- | Build the encoder's lookup tables from an alphabet.
--
-- The alphabet is indexed at positions 0..63, so it must hold at least 64
-- bytes.  The pair table has one 2-byte entry for every @(j, k)@ with
-- @j, k <- [0..63]@, in that order, containing @alphabet[j]@ followed by
-- @alphabet[k]@.
--
-- NOTE(review): the pre-0.11 branch discards the alphabet's offset, so the
-- alphabet is assumed to start at offset 0 of its buffer - confirm callers
-- never pass a slice.
mkEncodeTable :: ByteString -> EncodeTable
#if MIN_VERSION_bytestring(0,11,0)
mkEncodeTable alphabet@(BS alphaFP _) =
    case pairTable of
      BS pairFP _ -> ET alphaFP (castForeignPtr pairFP)
#else
mkEncodeTable alphabet@(PS alphaFP _ _) =
    case pairTable of
      PS pairFP _ _ -> ET alphaFP (castForeignPtr pairFP)
#endif
  where
    symbol = B.index alphabet
    pairTable =
      B.pack
        [ s
        | hi <- [0 .. 63]
        , lo <- [0 .. 63]
        , s <- [symbol hi, symbol lo]
        ]
-- | Strictly decode Base64 input using the given 256-entry decoding table.
--
-- The 'Padding' argument selects how the input length and trailing @=@ are
-- treated:
--
-- * 'Padded': the length must be a multiple of four.
--
-- * 'Don'tCare': lengths of 2 or 3 modulo four are completed with @=@
--   before decoding; a length of 3 modulo four that already ends in @=@ is
--   rejected.
--
-- * 'Unpadded': as 'Don'tCare', but any input ending in @=@ is rejected.
--
-- A length of 1 modulo four is always an error.  Empty input decodes to the
-- empty string.
decodeWithTable :: Padding -> ForeignPtr Word8 -> ByteString -> Either String ByteString
decodeWithTable padding !decodeFP bs
  | B.null bs = Right B.empty
  | otherwise = case padding of
      Padded
        | r == 0 -> withBS bs go
        | r == 1 -> Left "Base64-encoded bytestring has invalid size"
        | otherwise -> Left "Base64-encoded bytestring is unpadded or has invalid padding"
      Don'tCare
        | r == 0 -> withBS bs go
        | r == 2 -> withBS (B.append bs (B.replicate 2 0x3d)) go
        | r == 3 -> validateLastPad bs invalidPad $ withBS (B.append bs (B.replicate 1 0x3d)) go
        | otherwise -> Left "Base64-encoded bytestring has invalid size"
      Unpadded
        | r == 0 -> validateLastPad bs noPad $ withBS bs go
        | r == 2 -> validateLastPad bs noPad $ withBS (B.append bs (B.replicate 2 0x3d)) go
        | r == 3 -> validateLastPad bs noPad $ withBS (B.append bs (B.replicate 1 0x3d)) go
        | otherwise -> Left "Base64-encoded bytestring has invalid size"
  where
    !r = B.length bs `mod` 4
    noPad = "Base64-encoded bytestring required to be unpadded"
    invalidPad = "Base64-encoded bytestring has invalid padding"

    go !sptr !slen = do
      -- Size the output from the length the decoder is actually handed, not
      -- from the caller's input: when padding has been appended above there
      -- is one more 4-byte group than the original length suggests, and
      -- sizing from the original length lets the final group be written past
      -- the end of the buffer.  Every group yields at most three bytes.
      dfp <- mallocByteString ((slen `div` 4) * 3)
      withForeignPtr decodeFP $! \ !decptr ->
        withForeignPtr dfp $! \dptr ->
          decodeLoop decptr sptr dptr (sptr `plusPtr` slen) dfp
-- | The strict decoder's inner loop.
--
-- Consumes the source four bytes at a time and writes three bytes per group
-- to the destination.  Only the last group may contain padding (table value
-- @0x63@); a table value of @0xff@ anywhere is an invalid character.  In
-- the last group the bits that do not contribute to the output must be zero,
-- otherwise the input is rejected as non-canonical.  Error offsets are
-- relative to the source start.
decodeLoop
  :: Ptr Word8
  -- ^ decoding table, indexed by input byte
  -> Ptr Word8
  -- ^ start of the source
  -> Ptr Word8
  -- ^ start of the destination
  -> Ptr Word8
  -- ^ one past the end of the source
  -> ForeignPtr Word8
  -- ^ destination buffer, used to build the result
  -> IO (Either String ByteString)
decodeLoop !dtable !sptr !dptr !end !dfp = step dptr sptr
  where
    -- Error results carrying the offset of the offending byte.
    failAt msg p = return (Left (msg ++ show (p `minusPtr` sptr)))
    err p = failAt "invalid character at offset: " p
    padErr p = failAt "invalid padding at offset: " p
    canonErr p = failAt "non-canonical encoding detected at offset: " p

    -- Table value for the input byte at the given address.
    sextet :: Ptr Word8 -> IO Word32
    sextet !p = do
      !byte <- peek p
      !entry <- peekElemOff dtable (fromIntegral byte)
      return (fromIntegral entry)

    -- Four 6-bit values packed into the low 24 bits.
    pack24 :: Word32 -> Word32 -> Word32 -> Word32 -> Word32
    pack24 a b c d = shiftL a 18 .|. shiftL b 12 .|. shiftL c 6 .|. d

    step !dst !src = do
      !a <- sextet src
      !b <- sextet (src `plusPtr` 1)
      !c <- sextet (src `plusPtr` 2)
      !d <- sextet (src `plusPtr` 3)
      if src `plusPtr` 4 >= end
        then finalChunk dst src a b c d
        else decodeChunk dst src a b c d

    -- A group that is not the last one: no padding allowed at all.
    decodeChunk !dst !src !a !b !c !d
      | a == 0x63 = padErr src
      | b == 0x63 = padErr (src `plusPtr` 1)
      | c == 0x63 = padErr (src `plusPtr` 2)
      | d == 0x63 = padErr (src `plusPtr` 3)
      | a == 0xff = err src
      | b == 0xff = err (src `plusPtr` 1)
      | c == 0xff = err (src `plusPtr` 2)
      | d == 0xff = err (src `plusPtr` 3)
      | otherwise = do
          let !w = pack24 a b c d
          poke8 dst (fromIntegral (w `shiftR` 16))
          poke8 (dst `plusPtr` 1) (fromIntegral (w `shiftR` 8))
          poke8 (dst `plusPtr` 2) (fromIntegral w)
          step (dst `plusPtr` 3) (src `plusPtr` 4)

    -- The last group: may end in one or two padding bytes.
    finalChunk !dst !src a b c d
      | a == 0x63 = padErr src
      | b == 0x63 = padErr (src `plusPtr` 1)
      | c == 0x63 && d /= 0x63 = err (src `plusPtr` 3)
      | a == 0xff = err src
      | b == 0xff = err (src `plusPtr` 1)
      | c == 0xff = err (src `plusPtr` 2)
      | d == 0xff = err (src `plusPtr` 3)
      | otherwise = do
          let !w = pack24 a b c d
              already = dst `minusPtr` dptr
              result extra = return (Right (mkBS dfp (extra + already)))
          poke8 dst (fromIntegral (w `shiftR` 16))
          case (c == 0x63, d == 0x63) of
            -- "xx==": one output byte, low four bits of b must be zero.
            (True, True)
              | sanityCheckPos b mask_4bits -> result 1
              | otherwise -> canonErr (src `plusPtr` 1)
            -- "xxx=": two output bytes, low two bits of c must be zero.
            (_, True)
              | sanityCheckPos c mask_2bits -> do
                  poke8 (dst `plusPtr` 1) (fromIntegral (w `shiftR` 8))
                  result 2
              | otherwise -> canonErr (src `plusPtr` 2)
            -- No padding: a full three-byte group.
            _ -> do
              poke8 (dst `plusPtr` 1) (fromIntegral (w `shiftR` 8))
              poke8 (dst `plusPtr` 2) (fromIntegral w)
              result 3
-- | Leniently decode Base64 input using the given decoding table.
--
-- Bytes the table marks as invalid ('x') are skipped wherever they occur.
-- Padding ('done') is skipped while looking for the first two symbols of a
-- group; as the third or fourth symbol it ends the output.  Running out of
-- input is treated like padding.  This function never fails: whatever could
-- be decoded is returned, possibly the empty string.
decodeLenientWithTable :: ForeignPtr Word8 -> ByteString -> ByteString
decodeLenientWithTable !decodeFP !bs = withBS bs decodeAll
  where
    decodeAll !srcStart !srcLen
      | outCap <= 0 = return B.empty
      | otherwise = do
          outFP <- mallocByteString outCap
          withForeignPtr decodeFP $ \ !table -> do
            let srcEnd = srcStart `plusPtr` srcLen

                emit count
                  | count > 0 = return $ mkBS outFP count
                  | otherwise = return B.empty

                -- Find the next usable symbol at or after @start@ and hand
                -- the address following it, plus its table value, to @k@.
                -- At the end of the input @k@ gets 'done' and the address of
                -- the last input byte.
                scan :: Bool -> Ptr Word8
                     -> (Ptr Word8 -> Word32 -> IO ByteString)
                     -> IO ByteString
                {-# INLINE scan #-}
                scan skipPad start k = advance start
                  where
                    advance p
                      | p >= srcEnd = k (srcEnd `plusPtr` (-1)) done
                      | otherwise = {-# SCC "decodeLenient/look" #-} do
                          byte <- peek8 p
                          v <- peek8 (table `plusPtr` fromIntegral byte)
                          if v == x || (v == done && skipPad)
                            then advance (p `plusPtr` 1)
                            else k (p `plusPtr` 1) (fromIntegral v)

                loop !outp !inp !count
                  | inp >= srcEnd = emit count
                  | otherwise = {-# SCC "decodeLenientWithTable/fill" #-}
                      scan True inp $ \ !afterA !a ->
                      scan True afterA $ \ !afterB !b ->
                        if a == done || b == done
                          then emit count
                          else
                            scan False afterB $ \ !afterC !c ->
                            scan False afterC $ \ !afterD !d -> do
                              let w = (a `shiftL` 18) .|. (b `shiftL` 12)
                                        .|. (c `shiftL` 6) .|. d
                                  second = fromIntegral (w `shiftR` 8)
                              poke8 outp $ fromIntegral (w `shiftR` 16)
                              case () of
                                _ | c == done -> emit (count + 1)
                                  | d == done -> do
                                      poke8 (outp `plusPtr` 1) second
                                      emit (count + 2)
                                  | otherwise -> do
                                      poke8 (outp `plusPtr` 1) second
                                      poke8 (outp `plusPtr` 2) $ fromIntegral w
                                      loop (outp `plusPtr` 3) afterD (count + 3)
            withForeignPtr outFP $ \outStart -> loop outStart srcStart 0
      where
        -- Three output bytes for every started group of four input bytes.
        !outCap = ((srcLen + 3) `div` 4) * 3
-- | Decoding-table value that the decoders treat as an invalid input byte.
x :: Integral a => a
x = 0xff
{-# INLINE x #-}
-- | Decoding-table value that the decoders treat as padding / end of data.
done :: Integral a => a
done = 0x63
{-# INLINE done #-}
-- | Re-split a list of chunks so that every chunk except possibly the last
-- has a length that is a multiple of @n@.
--
-- A chunk whose length is already a multiple of @n@ is passed through
-- untouched.  Otherwise its aligned prefix is emitted (this may be an empty
-- chunk) and the remainder is topped up from the following chunks until it
-- is exactly @n@ bytes long.
reChunkIn :: Int -> [ByteString] -> [ByteString]
reChunkIn !n = aligned
  where
    -- No bytes are pending from an earlier chunk.
    aligned [] = []
    aligned (chunk : rest)
      | leftover == 0 = chunk : aligned rest
      | otherwise =
          let (whole, partial) = B.splitAt (groups * n) chunk
          in whole : carry partial rest
      where
        (groups, leftover) = B.length chunk `divMod` n

    -- @pending@ holds fewer than @n@ bytes waiting to be completed.
    carry pending [] = [pending]
    carry pending (chunk : rest)
      | B.length pending' /= n = carry pending' rest
      | B.null spill = pending' : aligned rest
      | otherwise = pending' : aligned (spill : rest)
      where
        (taken, spill) = B.splitAt (n - B.length pending) chunk
        pending' = pending `B.append` taken
-- | Reject input that ends in a padding byte (@=@), otherwise pass the
-- given decode result through.
--
-- The decode result is deliberately left lazy: when the input is rejected
-- the decoder is never run.  Empty input has no last byte and is passed
-- through rather than hitting the partial 'B.last'.
validateLastPad
  :: ByteString
  -- ^ the input being decoded
  -> String
  -- ^ error message to use when the input ends in @=@
  -> Either String ByteString
  -- ^ result of decoding the input
  -> Either String ByteString
validateLastPad bs err io
  | not (B.null bs) && B.last bs == 0x3d = Left err
  | otherwise = io
{-# INLINE validateLastPad #-}
-- | 'True' when none of the bits selected by the mask are set in the low
-- byte of the given value.
sanityCheckPos :: Word32 -> Word8 -> Bool
sanityCheckPos pos mask = pos .&. fromIntegral mask == 0
{-# INLINE sanityCheckPos #-}
-- | Mask selecting the two lowest bits.
mask_2bits :: Word8
mask_2bits = 0x03
{-# NOINLINE mask_2bits #-}
-- | Mask selecting the four lowest bits.
mask_4bits :: Word8
mask_4bits = 0x0f
{-# NOINLINE mask_4bits #-}
-- | Wrap the first @len@ bytes of a buffer as a 'ByteString', using
-- whichever constructor the installed bytestring version provides.
mkBS :: ForeignPtr Word8 -> Int -> ByteString
#if MIN_VERSION_bytestring(0,11,0)
mkBS buf len = BS buf len
#else
mkBS buf len = PS buf 0 len
#endif
{-# INLINE mkBS #-}
-- | Run an action on the address of a 'ByteString''s first byte and its
-- length, and return the result as a pure value.
--
-- The action is run through 'unsafePerformIO'.
-- NOTE(review): this is only sound if the action behaves like a pure
-- function of the bytes it is given - confirm for every caller.
withBS :: ByteString -> (Ptr Word8 -> Int -> IO a) -> a
#if MIN_VERSION_bytestring(0,11,0)
withBS (BS !fp !len) action =
  unsafePerformIO (withForeignPtr fp (\base -> action base len))
#else
withBS (PS !fp !off !len) action =
  unsafePerformIO (withForeignPtr fp (\base -> action (base `plusPtr` off) len))
#endif
{-# INLINE withBS #-}