module FPE.FF1 (encrypt, decrypt, BlockCipher, Crypter, Tweak) where

import Control.Arrow
import Control.Monad
import Data.Bits
import Data.List (foldl')
import Data.Tuple (swap)

import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.Vector.Generic (Vector)
import qualified Data.Vector.Generic as V
import Math.NumberTheory.Logarithms

-- | A block cipher, applied by this module to single 16-byte blocks
-- (presumably with the key already fixed by the caller).
type BlockCipher = S.ByteString -> S.ByteString

-- | The FF1 tweak, as raw bytes.
type Tweak = S.ByteString

-- | Common shape of 'encrypt' and 'decrypt': cipher, radix, tweak and a
-- vector of digits, producing a vector of digits of the same length.
type Crypter vec digit = BlockCipher -> Int -> Tweak -> vec digit -> vec digit


--  Number of bytes to store a message of given length and radix.
--  Defined in FF1 step 3 using (redundant) double ceiling.
bytesFor :: Int -> Int -> Int
bytesFor :: Int -> Int -> Int
bytesFor Int
radix Int
len =
   Integer -> Int
integerLog2' ((forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
radix forall a b. (Num a, Integral b) => a -> b -> a
^ Int
len) forall a. Num a => a -> a -> a
- Integer
1) forall a. Integral a => a -> a -> a
`div` Int
8 forall a. Num a => a -> a -> a
+ Int
1

-- | Bytewise XOR of two byte strings.  The result is as long as the shorter
-- input; excess bytes of the longer one are ignored.
xorBytes :: S.ByteString -> S.ByteString -> S.ByteString
xorBytes a b = S.pack (S.zipWith xor a b)

--  Conversion functions.

-- | Interpret a vector of digits as a big-endian number in the given radix
-- (FF1's @NUM_radix@).  Digits are expected to lie in @[0, radix)@; this is
-- not checked.  The empty vector yields 0.
vecToNum :: (Vector v a, Integral a) => Int -> v a -> Integer
vecToNum radix = V.foldl' step 0
  where
    radix' = fromIntegral radix :: Integer
    -- Strict fold: the accumulator is forced at each digit rather than
    -- building a chain of thunks as long as the message.
    step acc digit = acc * radix' + fromIntegral digit

-- | Render a number as exactly @len@ big-endian digits in the given radix
-- (FF1's @STR^len_radix@).  Short numbers are zero-padded on the left and
-- digits beyond @len@ are dropped, so the result encodes
-- @num `mod` radix ^ len@.
numToVec :: (Vector v a, Integral a) => Int -> Int -> Integer -> v a
numToVec radix len num = V.reverse (V.fromListN len lowFirst)
  where
    radix' = fromIntegral radix :: Integer
    -- Digits least-significant first.  The list is infinite; fromListN keeps
    -- only the first len of them.
    lowFirst = map (fromIntegral . (`mod` radix')) (iterate (`div` radix') num)

--  Same as above, but for ByteStrings, treated as big-endian digit strings
--  in the fixed radix 256.  Possibly we could use Vector Word8 instead?

-- | Interpret a byte string as a big-endian unsigned integer (FF1's @NUM@).
-- The empty string yields 0.  Uses a strict fold so the accumulator does
-- not build up thunks over long inputs.
bytesToNum :: Integral a => S.ByteString -> a
bytesToNum = S.foldl' (\acc byte -> acc * 256 + fromIntegral byte) 0
{-# SPECIALIZE bytesToNum :: S.ByteString -> Integer #-}

-- | Encode a number as exactly @len@ big-endian bytes (FF1's @[x]^len@).
-- Short numbers are zero-padded; higher-order bytes that do not fit are
-- dropped (the 'Word8' conversion wraps), so the result encodes
-- @num `mod` 256 ^ len@.
numToBytes :: Integral a => Int -> a -> S.ByteString
numToBytes len num =
  S.reverse . S.pack . map fromIntegral . take len $ iterate (`div` 256) num

--  Cipherish functions.

-- | FF1's PRF: CBC-MAC with an all-zero 16-byte IV.  The input is consumed
-- in 16-byte blocks, each XORed into the chaining value and enciphered, and
-- the final chaining value is returned.
--
-- The input is expected to be non-empty with a length that is a multiple of
-- 16 (crypt pads P || Q to guarantee this); this is not checked.
prf :: BlockCipher -> S.ByteString -> S.ByteString
prf cipher = go (S.replicate 16 0)
  where
    go chain src
      | S.null rest = chain'
      | otherwise = go chain' rest
      where
        (block, rest) = S.splitAt 16 src
        chain' = cipher (block `xorBytes` chain)

-- | Extend (or shorten) a block to @len@ bytes, as in FF1 step 6.iii:
-- @S = R || CIPH(R xor [1]^16) || CIPH(R xor [2]^16) || ...@, cut to its
-- first @len@ bytes.  @blk@ is expected to be a single 16-byte block.
-- The list of cipher outputs is infinite and lazy; 'L.take' stops once
-- @len@ bytes are available.
extend :: BlockCipher -> Int -> S.ByteString -> S.ByteString
extend cipher len blk =
  L.toStrict . L.take (fromIntegral len) . L.fromChunks $
    blk : map stretch [1 ..]
  where
    -- The counter is encoded as a full 16-byte big-endian block.  The
    -- signature fixes its type without needing TypeApplications.
    stretch :: Int -> S.ByteString
    stretch i = cipher (blk `xorBytes` numToBytes 16 i)


--  Encrypt and decrypt.

-- | Shared implementation of FF1 encryption and decryption
-- (NIST SP 800-38G, Algorithms 7 and 8).  The 'Bool' selects the
-- direction: 'True' for encryption, 'False' for decryption.
--
-- The @n@-digit message is split into @A@ (the first @u = n `div` 2@
-- digits) and @B@ (the remaining @v = n - u@).  These are run through ten
-- Feistel rounds and rendered back at their original widths.
--
-- NOTE(review): FF1's input constraints (digits in @[0, radix)@,
-- @2 <= radix <= 2^16@, @n >= 2@, bounded tweak length) are not validated
-- here; presumably callers are responsible for them.
crypt :: (Vector v a, Integral a) => Bool -> Crypter v a
crypt isEncrypt cipher radix tweak msg =
  numToVec radix u finalA V.++ numToVec radix v finalB
  where
    t = S.length tweak
    n = V.length msg
    u = n `div` 2
    v = n - u

    -- Step 3: bytes for a half-message.  Step 4: d = 4 * ceil(b / 4) + 4,
    -- written without a ceiling (valid for b >= 1).
    b = bytesFor radix v
    d = 4 * ((b - 1) `div` 4) + 8

    -- radix ^ m: the modulus for a half of width m.
    modulus :: Int -> Integer
    modulus m = fromIntegral radix ^ m

    -- Step 5: the 16-byte block P.  Converting u to a byte keeps u mod 256.
    bP = S.concat
      [ S.pack [1, 2, 1]
      , numToBytes 3 radix
      , S.pack [10, fromIntegral u]
      , numToBytes 4 n
      , numToBytes 4 t
      ]

    -- Round-independent part of Q: the tweak, zero-padded so that
    -- P || Q is a whole number of 16-byte blocks once [i] and the b-byte
    -- half are appended.
    pfxQ = tweak <> S.pack (replicate ((-t - b - 1) `mod` 16) 0)

    -- Step 2: numeric values of both halves; join (***) applies the
    -- conversion to each component of the split pair.
    (numA0, numB0) = join (***) (vecToNum radix) (V.splitAt u msg)

    -- One Feistel round with index i (steps 6.i-6.ix).  The second
    -- component feeds Q and becomes the new first component.  The first is
    -- combined with the round value y, reduced modulo radix ^ m, and
    -- becomes the new second component.
    feistel (numA, numB) i = (numB, numC)
      where
        y = bytesToNum . extend cipher d . prf cipher $
              S.concat [bP, pfxQ, S.singleton i, numToBytes b numB]
        op = if isEncrypt then (+ y) else subtract y
        numC = op numA `mod` modulus (if even i then u else v)

    -- Decryption (Algorithm 8) is the same round applied to the pair in
    -- (B, A) order, with the round indices reversed and subtraction
    -- instead of addition.  Swapping on the way in and out expresses that.
    wrap = if isEncrypt then id else swap
    rounds = if isEncrypt then [0 .. 9] else [9, 8 .. 0]
    (finalA, finalB) = wrap (foldl' feistel (wrap (numA0, numB0)) rounds)

-- | FF1 encryption and decryption.  Arguments, in order:
--
-- * the block cipher, applied to 16-byte blocks;
-- * the radix;
-- * the tweak;
-- * the message, as a vector of digits.
--
-- The result has the same length as the message.  For valid inputs,
-- @decrypt c r t . encrypt c r t@ is intended to be the identity.
encrypt, decrypt :: (Vector v a, Integral a) => Crypter v a
encrypt = crypt True
decrypt = crypt False