{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.JOSE.AESKW
(
aesKeyWrap
, aesKeyUnwrap
) where
import Control.Monad (join)
import Control.Monad.State (StateT, execStateT, get, lift, put)
import Crypto.Cipher.Types
import Data.Bits (xor)
import Data.ByteArray as BA hiding (replicate, xor)
import Data.Memory.Endian (BE(..), toBE)
import Data.Memory.PtrMethods (memCopy)
import Data.Word (Word64)
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (peek, peekElemOff, poke, pokeElemOff)
import System.IO.Unsafe (unsafePerformIO)
-- | Default initial value from RFC 3394 (section 2.2.3.1).
--
-- Every byte is @0xa6@, so the word has the same in-memory representation on
-- little- and big-endian hosts. That matters because it is stored with
-- 'poke' and compared against words read back with 'peek'.
iv :: Word64
iv = 0xa6a6a6a6a6a6a6a6
-- | One iteration of the RFC 3394 wrapping loop.
--
-- The state is the integrity register @A@, and word @idx@ of @reg@ is @R[idx]@.
-- The 16-byte block @A || R[idx]@ is AES-encrypted:
--
-- * the high word, XORed with the step counter in big-endian byte order,
--   becomes the new @A@;
-- * the low word is written back to @R[idx]@ in place.
aesKeyWrapStep
  :: BlockCipher128 cipher
  => cipher
  -> Ptr Word64   -- ^ register buffer; the word at the given offset is updated
  -> (Int, Int)   -- ^ step counter @t@ and word offset @i@
  -> StateT Word64 IO ()
aesKeyWrapStep cipher reg (step, idx) = do
  register <- get
  (hi, lo) <- lift $ do
    word <- peekElemOff reg idx
    input <- alloc 16 $ \buf -> do
      poke buf register
      pokeElemOff buf 1 word
    -- ScrubbedBytes so the plaintext block is wiped when released
    let output = ecbEncrypt cipher (input :: ScrubbedBytes)
    withByteArray output $ \out -> (,) <$> peek out <*> peekElemOff out 1
  put $ hi `xor` unBE (toBE (fromIntegral step))
  lift $ pokeElemOff reg idx lo
-- | Performs AES Key Wrap (RFC 3394) of the key data @m@ under @cipher@.
--
-- The result is 8 bytes longer than the input:
--
-- * word 0 holds the final integrity register @A@;
-- * words @1 .. n/8@ hold the wrapped key-data blocks.
--
-- The key data must consist of whole 64-bit blocks. A length that is not a
-- multiple of 8 is rejected with 'error'. The wrapping loop only visits whole
-- blocks, so the bytes of a trailing partial block would otherwise be copied
-- into the output unencrypted.
aesKeyWrap
  :: (ByteArrayAccess m, ByteArray c, BlockCipher128 cipher)
  => cipher
  -> m
  -> c
aesKeyWrap cipher m
  | n `mod` 8 /= 0 =
      error $ "Crypto.JOSE.AESKW.aesKeyWrap: key data length "
        ++ show n ++ " is not a multiple of 8 bytes"
  -- unsafePerformIO: the IO below only allocates a fresh output buffer and
  -- mutates it before it is returned, so the result depends only on the
  -- arguments.
  | otherwise = unsafePerformIO $ do
      -- Copy the key data into words 1 .. n/8; word 0 is filled in at the end.
      c <- withByteArray m $ \p ->
        alloc (n + 8) $ \p' -> memCopy (p' `plusPtr` 8) p n
      withByteArray c $ \p -> do
        -- Steps t = 1 .. 6*(n/8), visiting blocks 1 .. n/8 six times in order.
        let coords = zip [1 ..] (join (replicate 6 [1 .. n `div` 8]))
        a <- execStateT (mapM_ (aesKeyWrapStep cipher p) coords) iv
        poke p a
      return c
  where
    n = BA.length m
-- | One iteration of the RFC 3394 unwrapping loop (inverse of the wrap step).
--
-- The state is the integrity register @A@, and word @idx@ of @reg@ is @R[idx]@.
-- The step counter is first XORed into @A@ in big-endian byte order. The block
-- @(A xor t) || R[idx]@ is then AES-decrypted:
--
-- * the high word becomes the new @A@;
-- * the low word is written back to @R[idx]@ in place.
aesKeyUnwrapStep
  :: BlockCipher128 cipher
  => cipher
  -> Ptr Word64   -- ^ register buffer; the word at the given offset is updated
  -> (Int, Int)   -- ^ step counter @t@ and word offset @i@
  -> StateT Word64 IO ()
aesKeyUnwrapStep cipher reg (step, idx) = do
  register <- get
  word <- lift $ peekElemOff reg idx
  let masked = register `xor` unBE (toBE (fromIntegral step))
  (hi, lo) <- lift $ do
    input <- alloc 16 $ \buf -> poke buf masked >> pokeElemOff buf 1 word
    withByteArray (ecbDecrypt cipher (input :: ScrubbedBytes)) $ \out -> do
      top <- peek out
      bottom <- peekElemOff out 1
      return (top, bottom)
  put hi
  lift $ pokeElemOff reg idx lo
-- | Performs AES Key Unwrap (RFC 3394) of the wrapped key @c@ under @cipher@.
--
-- The input is expected in the layout produced by 'aesKeyWrap':
--
-- * an 8-byte integrity register @A@;
-- * followed by the wrapped 64-bit blocks.
--
-- Returns 'Nothing' in two cases:
--
-- * the length is not a multiple of 8, or is shorter than 16 bytes (register
--   plus at least one block). Such input cannot be a wrapped key. Without
--   this check a short input would be read out of bounds, and trailing
--   partial-block bytes would be returned unauthenticated.
-- * the register obtained after unwrapping does not equal 'iv', meaning the
--   integrity check failed.
aesKeyUnwrap
  :: (ByteArrayAccess c, ByteArray m, BlockCipher128 cipher)
  => cipher
  -> c
  -> Maybe m
aesKeyUnwrap cipher c
  | len < 16 || len `mod` 8 /= 0 = Nothing
  -- unsafePerformIO: the IO below only reads @c@ and allocates/mutates a fresh
  -- buffer before it is returned, so the result depends only on the arguments.
  | otherwise = unsafePerformIO $ do
      -- Key-data blocks (everything after the 8-byte register) go into a new
      -- buffer that is unwrapped in place.
      m <- withByteArray c $ \p' ->
        alloc n $ \p -> memCopy p (p' `plusPtr` 8) n
      a <- withByteArray c $ \p' -> peek p'
      a' <- withByteArray m $ \p -> do
        let n' = n `div` 8
            tMax = n' * 6
            -- Steps t = 6n' .. 1 in reverse; the block offset cycles
            -- n'-1 .. 0 (zero-based, since the register is not in this buffer).
            coords = zip [tMax, tMax - 1 .. 1] (cycle [n' - 1, n' - 2 .. 0])
        execStateT (mapM_ (aesKeyUnwrapStep cipher p) coords) a
      return $ if a' == iv then Just m else Nothing
  where
    len = BA.length c
    n = len - 8