{-# LANGUAGE StrictData #-}
{-# LANGUAGE NoFieldSelectors #-}
module Wai.CryptoCookie.Encryption
( Encryption (..)
, autoKeyFileBase16
, readKeyFileBase16
, readKeyFile
, writeKeyFile
) where
import Control.Exception qualified as Ex
import Control.Monad
import Control.Monad.IO.Class
import Crypto.Random qualified as C
import Data.Aeson qualified as Ae
import Data.Bits
import Data.ByteArray qualified as BA
import Data.ByteArray.Encoding qualified as BA
import Data.ByteArray.Sized qualified as BAS
import Data.ByteString.Lazy qualified as BL
import Data.Char qualified as Char
import Data.Kind (Type)
import Data.Text.Encoding qualified as T
import Data.Word
import GHC.TypeNats
import System.IO qualified as IO
import System.IO.Error qualified as IO
-- | An encryption scheme, identified by the (possibly poly-kinded) tag @e@.
--
-- The superclasses require the key length to be a type-level natural that is
-- known at runtime, and keys to be comparable with 'Eq'.
class (KnownNat (KeyLength e), Eq (Key e)) => Encryption (e :: k) where
   -- | Size of the raw key material: 'keyToBytes' returns a sized byte array
   -- of this length.
   type KeyLength e :: Natural

   -- | A key for this scheme.
   data Key e :: Type

   -- | State consumed by 'encrypt'; produced by 'initial', stepped by
   -- 'advance'.
   data Encrypt e :: Type

   -- | State consumed by 'decrypt'; produced by 'initial'.
   data Decrypt e :: Type

   -- | Generate a fresh key from a 'C.MonadRandom' source.
   genKey :: C.MonadRandom m => m (Key e)

   -- | Parse a key from raw bytes, or explain why they are not one.
   keyFromBytes :: BA.ByteArrayAccess raw => raw -> Either String (Key e)

   -- | Export the raw key bytes as a byte array sized by 'KeyLength'.
   keyToBytes :: BAS.ByteArrayN (KeyLength e) raw => Key e -> raw

   -- | Derive the encryption and decryption states for a key. Runs in a
   -- 'C.MonadRandom' context, so it may draw randomness.
   initial :: C.MonadRandom m => Key e -> m (Encrypt e, Decrypt e)

   -- | Step the encryption state.
   -- NOTE(review): presumably advances per-message state (e.g. a nonce or
   -- counter) — confirm against the instances.
   advance :: Encrypt e -> Encrypt e

   -- | Encrypt a lazy 'BL.ByteString' with the given state.
   encrypt :: Encrypt e -> BL.ByteString -> BL.ByteString

   -- | Decrypt a lazy 'BL.ByteString', returning 'Left' with a message on
   -- failure.
   decrypt :: Decrypt e -> BL.ByteString -> Either String BL.ByteString
-- | Load a Base16-encoded key from @path@, creating the file first if it does
-- not exist.
--
-- The file is read with 'readKeyFileBase16'. Only if that fails because the
-- file does not exist ('IO.isDoesNotExistError') does this function:
--
--   1. generate a fresh key with 'genKey',
--   2. write it to @path@ in Base16 using 'writeKeyFile', and
--   3. read it back with 'readKeyFileBase16'.
--
-- If the key read back differs from the generated one, it fails with an
-- 'IOError' (@"autoKeyFile: no roundtrip"@). Any other error while reading
-- is rethrown unchanged, e.g. an existing but malformed file.
autoKeyFileBase16
   :: forall e m
    . (Encryption e, MonadIO m)
   => FilePath
   -> m (Key e)
autoKeyFileBase16 path = liftIO do
   -- Use 'tryJust' rather than 'catchJust'. Code inside a catch-family
   -- handler runs under an implicit 'Ex.mask'. The fallback below does real
   -- work (key generation plus file writes and reads), so it is run after
   -- inspecting the result, outside any handler.
   existing <-
      Ex.tryJust (guard . IO.isDoesNotExistError) (readKeyFileBase16 path)
   case existing of
      Right key -> pure key
      Left () -> do
         k0 <- genKey
         writeKeyFile (BA.convertToBase BA.Base16) path k0
         k1 <- readKeyFileBase16 path
         -- Ensure what we persisted decodes to the key we generated.
         when (k0 /= k1) $ fail "autoKeyFile: no roundtrip"
         pure k1
-- | Read a key from a file holding the Base16 (hex) encoding of its raw bytes.
--
-- The hex digits may be followed by any run of @'\\r'@ / @'\\n'@ bytes, such
-- as a trailing newline. Any other byte after the first line break makes
-- the contents invalid (@"invalid format"@). Errors are reported through
-- 'readKeyFile', which prefixes them with @"readKeyFile: "@.
readKeyFileBase16
   :: forall e m
    . (Encryption e, MonadIO m)
   => FilePath
   -> m (Key e)
readKeyFileBase16 = readKeyFile decodeHexLine
  where
    -- Split off the leading hex text and check that only line breaks follow.
    decodeHexLine :: BA.ScrubbedBytes -> Either String BA.ScrubbedBytes
    decodeHexLine raw
       | BA.all isLineBreak trailer = BA.convertFromBase BA.Base16 hex
       | otherwise = Left "invalid format"
      where
        (hex, trailer) = BA.span (not . isLineBreak) raw

    -- A byte counts as a line break if it is CR or LF.
    isLineBreak :: Word8 -> Bool
    isLineBreak w = w `elem` lineBreaks

    lineBreaks :: [Word8]
    lineBreaks = map (fromIntegral . Char.ord) "\r\n"
-- | Read a key from @path@.
--
-- The whole file is read into a 'BA.ScrubbedBytes' buffer. The buffer is
-- transformed with @g@ (e.g. to undo a textual encoding) and then parsed
-- with 'keyFromBytes'. Every failure is raised via 'fail' in 'IO', i.e. as
-- an 'IOError':
--
--   * @"readKeyFile: invalid key file size"@ if the file is empty or its
--     size does not fit in an 'Int';
--   * @"readKeyFile: could not read key file"@ if fewer bytes than the
--     reported size could be read;
--   * @"readKeyFile: "@ followed by the message if @g@ or 'keyFromBytes'
--     returns 'Left'.
readKeyFile
   :: forall e m
    . (Encryption e, MonadIO m)
   => (BA.ScrubbedBytes -> Either String BA.ScrubbedBytes)
   -> FilePath
   -> m (Key e)
readKeyFile g path = liftIO $ IO.withFile path IO.ReadMode \h -> do
   reported <- IO.hFileSize h
   -- Only positive sizes that are representable as an 'Int' are accepted.
   flen <- maybe (fail "readKeyFile: invalid key file size") pure
      (mfilter (> 0) (toIntegralSized reported) :: Maybe Int)
   (rlen, fraw) <- BA.allocRet flen \p -> IO.hGetBuf h p flen
   unless (rlen == flen) $ fail "readKeyFile: could not read key file"
   either (fail . ("readKeyFile: " <>)) pure (g fraw >>= keyFromBytes)
-- | Write @key@ to @path@.
--
-- The raw key bytes from 'keyToBytes' are encoded with @g@, e.g.
-- @'BA.convertToBase' 'BA.Base16'@. The result is written as-is, replacing
-- any existing contents ('IO.WriteMode'). The encoded bytes are forced with
-- 'Ex.evaluate' before the file is opened. So an exception raised while
-- computing them happens before an existing file is truncated.
--
-- NOTE(review): nothing here sets file permissions, so the key file gets the
-- process defaults. Confirm that is acceptable for secret material.
writeKeyFile
   :: forall e m
    . (Encryption e, MonadIO m)
   => (BAS.SizedByteArray (KeyLength e) BA.ScrubbedBytes -> BA.ScrubbedBytes)
   -> FilePath
   -> Key e
   -> m ()
writeKeyFile g path key = liftIO do
   encoded <- Ex.evaluate (g (keyToBytes key))
   IO.withFile path IO.WriteMode (putAll encoded)
  where
    -- Write every byte of the buffer straight to the handle.
    putAll :: BA.ScrubbedBytes -> IO.Handle -> IO ()
    putAll bytes h =
       BA.withByteArray bytes \ptr -> IO.hPutBuf h ptr (BA.length bytes)
-- | Parse a key from a JSON string holding the Base16 encoding of its raw
-- bytes.
--
-- Fails with @"Invalid key"@ if the text is not valid Base16. Otherwise it
-- fails with whatever message 'keyFromBytes' reports.
instance (Encryption e) => Ae.FromJSON (Key e) where
   parseJSON = Ae.withText "Key" \txt -> do
      -- Decode straight into 'BA.ScrubbedBytes' so the raw key bytes are
      -- held in a scrubbed buffer.
      let decoded :: Either String BA.ScrubbedBytes
          decoded = BA.convertFromBase BA.Base16 (T.encodeUtf8 txt)
      case decoded of
         Left _ -> fail "Invalid key"
         Right raw -> either fail pure (keyFromBytes raw)