{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-} -- for readElection
{-# LANGUAGE UndecidableInstances #-} -- for Reifies constraints in instances
module Voting.Protocol.Election where

import Control.DeepSeq (NFData)
import Control.Monad (Monad(..), mapM, zipWithM)
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Except (ExceptT(..), runExcept, throwE, withExceptT)
import Data.Aeson (ToJSON(..),FromJSON(..),(.:),(.:?),(.=))
import Data.Bool
import Data.Either (either)
import Data.Eq (Eq(..))
import Data.Foldable (foldMap, and)
import Data.Function (($), (.), id, const)
import Data.Functor ((<$>))
import Data.Functor.Identity (Identity(..))
import Data.Maybe (Maybe(..), maybe, fromJust, fromMaybe)
import Data.Monoid (Monoid(..))
import Data.Ord (Ord(..))
import Data.Proxy (Proxy(..))
import Data.Reflection (Reifies(..), reify)
import Data.Semigroup (Semigroup(..))
import Data.String (String)
import Data.Text (Text)
import Data.Traversable (Traversable(..))
import Data.Tuple (fst, snd)
import GHC.Generics (Generic)
import GHC.Natural (minusNaturalMaybe)
import Numeric.Natural (Natural)
import Prelude (fromIntegral)
import System.IO (IO, FilePath)
import System.Random (RandomGen)
import Text.Show (Show(..))
import qualified Control.Monad.Fail as Fail
import qualified Control.Monad.Trans.State.Strict as S
import qualified Data.Aeson as JSON
import qualified Data.Aeson.Encoding as JSON
import qualified Data.Aeson.Internal as JSON
import qualified Data.Aeson.Parser.Internal as JSON
import qualified Data.Aeson.Types as JSON
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.List as List

import Voting.Protocol.Utils
import Voting.Protocol.Arithmetic
import Voting.Protocol.Version
import Voting.Protocol.Credential
import Voting.Protocol.Cryptography

-- * Type 'Question'
-- | A question put to the voters: its text, its possible choices,
-- and the bounds ('question_mini', 'question_maxi') within which
-- the sum of a voter's opinions on those choices must fall
-- (as checked by 'encryptAnswer').
--
-- @v@ is a phantom type parameter (no field mentions it);
-- the JSON instances require @'Reifies' v 'Version'@.
data Question v = Question
 { question_text    :: !Text
 , question_choices :: ![Text]
 , question_mini    :: !Natural
 , question_maxi    :: !Natural
 -- , question_blank :: Maybe Bool
 } deriving stock (Eq, Show, Generic)
   deriving anyclass (NFData)
-- | Encodes the fields under the keys @"question"@, @"answers"@,
-- @"min"@ and @"max"@ (in that order for 'toEncoding').
instance Reifies v Version => ToJSON (Question v) where
        toJSON quest =
                JSON.object
                 [ "question" .= question_text quest
                 , "answers"  .= question_choices quest
                 , "min"      .= question_mini quest
                 , "max"      .= question_maxi quest
                 ]
        toEncoding quest =
                JSON.pairs $
                        "question" .= question_text quest <>
                        "answers"  .= question_choices quest <>
                        "min"      .= question_mini quest <>
                        "max"      .= question_maxi quest
-- | Inverse of the 'ToJSON' instance: all four keys are required.
instance Reifies v Version => FromJSON (Question v) where
        parseJSON = JSON.withObject "Question" $ \obj -> do
                txt     <- obj .: "question"
                choices <- obj .: "answers"
                mini    <- obj .: "min"
                maxi    <- obj .: "max"
                return (Question txt choices mini maxi)

-- * Type 'Answer'
data Answer crypto v c = Answer
 { answer_opinions :: ![(Encryption crypto v c, DisjProof crypto v c)]
   -- ^ Encrypted 'Opinion' for each 'question_choices'
   -- with a 'DisjProof' that they belong to [0,1].
 , answer_sumProof :: !(DisjProof crypto v c)
   -- ^ Proofs that the sum of the 'Opinion's encrypted in 'answer_opinions'
   -- is an element of @[mini..maxi]@.
 -- , answer_blankProof ::
 } deriving (Generic)
deriving instance Eq (G crypto c) => Eq (Answer crypto v c)
deriving instance Show (G crypto c) => Show (Answer crypto v c)
deriving instance NFData (G crypto c) => NFData (Answer crypto v c)
-- | Splits 'answer_opinions' into two parallel arrays:
-- the ciphertexts under @"choices"@ and their proofs under
-- @"individual_proofs"@; 'answer_sumProof' goes under @"overall_proof"@.
instance
 ( Reifies v Version
 , CryptoParams crypto c
 ) => ToJSON (Answer crypto v c) where
        toJSON Answer{..} =
                let (answer_choices, answer_individual_proofs) =
                        List.unzip answer_opinions in
                JSON.object
                 [ "choices"           .= answer_choices
                 , "individual_proofs" .= answer_individual_proofs
                 , "overall_proof"     .= answer_sumProof
                 ]
        toEncoding Answer{..} =
                let (answer_choices, answer_individual_proofs) =
                        List.unzip answer_opinions in
                JSON.pairs
                 (  "choices"           .= answer_choices
                 <> "individual_proofs" .= answer_individual_proofs
                 <> "overall_proof"     .= answer_sumProof
                 )
-- | Inverse of the 'ToJSON' instance.
--
-- Fails when @"choices"@ and @"individual_proofs"@ do not have
-- the same length, instead of letting 'List.zip' silently drop
-- the surplus elements of the longer array.
instance
 ( Reifies v Version
 , CryptoParams crypto c
 ) => FromJSON (Answer crypto v c) where
        parseJSON = JSON.withObject "Answer" $ \o -> do
                answer_choices <- o .: "choices"
                answer_individual_proofs <- o .: "individual_proofs"
                let nbChoices = List.length answer_choices
                    nbProofs  = List.length answer_individual_proofs
                answer_opinions <-
                        if nbChoices == nbProofs
                        then return (List.zip answer_choices answer_individual_proofs)
                        else Fail.fail $
                                "Answer: "<>show nbChoices<>" \"choices\" but "
                                <>show nbProofs<>" \"individual_proofs\""
                answer_sumProof <- o .: "overall_proof"
                return Answer{..}

-- | @('encryptAnswer' elecPubKey zkp quest opinionByChoice)@
-- encrypts under @elecPubKey@ one opinion per 'question_choices' of @quest@
-- ('one' where @opinionByChoice@ is 'True', 'zero' otherwise).
-- It proves that each ciphertext encrypts one of the 'booleanDisjunctions',
-- and that the sum of the ciphertexts encrypts one of
-- @('intervalDisjunctions' 'question_mini' 'question_maxi')@.
-- The resulting 'Answer' is meant to be checked by 'verifyAnswer'.
--
-- Raises, checked in this order:
--
-- * 'ErrorAnswer_WrongSumOfOpinions' when the sum of the opinions
--   is outside @['question_mini'..'question_maxi']@;
-- * 'ErrorAnswer_WrongNumberOfOpinions' when @opinionByChoice@
--   does not have exactly one entry per choice.
encryptAnswer ::
 Reifies v Version =>
 CryptoParams crypto c =>
 Monad m => RandomGen r =>
 PublicKey crypto c -> ZKP ->
 Question v -> [Bool] ->
 S.StateT r (ExceptT ErrorAnswer m) (Answer crypto v c)
encryptAnswer (elecPubKey::PublicKey crypto c) zkp Question{..} opinionByChoice
 | opinionsSum < question_mini || question_maxi < opinionsSum =
        lift $ throwE $
                ErrorAnswer_WrongSumOfOpinions opinionsSum question_mini question_maxi
 | nbOpinions /= nbChoices =
        lift $ throwE $
                ErrorAnswer_WrongNumberOfOpinions
                 (fromIntegral nbOpinions)
                 (fromIntegral nbChoices)
 | otherwise = do
        encryptions <- mapM (encrypt elecPubKey) opinions
        -- Each encryption is an (encNonce, Encryption) pair: the nonces are
        -- only needed to prove the sum, the ciphertexts go into the 'Answer'.
        let (nonces, ciphertexts) = List.unzip encryptions
        -- The disjunct actually encrypted is left out of the lists given to
        -- 'proveEncryption': the last one for 'True', the first one for 'False'.
        individualProofs <- zipWithM
         (\opinion -> proveEncryption elecPubKey zkp $
                bool
                 ([], List.tail booleanDisjunctions)
                 (List.init booleanDisjunctions, [])
                 opinion)
         opinionByChoice encryptions
        -- Same for the sum: the disjunct at index @sumIndex@ is left out.
        sumProof <- proveEncryption elecPubKey zkp
         (List.tail <$> List.genericSplitAt sumIndex
                 (intervalDisjunctions question_mini question_maxi))
         (sum nonces, sum ciphertexts)
        return Answer
         { answer_opinions = List.zip ciphertexts individualProofs
         , answer_sumProof = sumProof
         }
 where
        opinions = bool zero one <$> opinionByChoice
        opinionsSum = sum (nat <$> opinions)
        nbOpinions = List.length opinions
        nbChoices = List.length question_choices
        -- Only forced in the 'otherwise' branch, where the first guard
        -- guarantees @question_mini <= opinionsSum@, so 'fromJust' cannot fail.
        sumIndex = fromJust (opinionsSum `minusNaturalMaybe` question_mini)

-- | @('verifyAnswer' elecPubKey zkp quest answer)@ returns 'True' iff:
--
-- * @answer@ has one encrypted opinion per 'question_choices' of @quest@;
-- * every individual proof verifies against 'booleanDisjunctions';
-- * 'answer_sumProof' verifies the sum of the ciphertexts against
--   @('intervalDisjunctions' 'question_mini' 'question_maxi')@.
--
-- Any error raised by 'verifyEncryption' is treated as 'False'.
verifyAnswer ::
 Reifies v Version =>
 CryptoParams crypto c =>
 PublicKey crypto c -> ZKP ->
 Question v -> Answer crypto v c -> Bool
verifyAnswer (elecPubKey::PublicKey crypto c) zkp Question{..} Answer{..} =
        List.length question_choices == List.length answer_opinions &&
        case runExcept verifications of
         Left _err -> False
         Right isValid -> isValid
 where
        verifications = do
                opinionsValidity <-
                        traverse (verifyEncryption elecPubKey zkp booleanDisjunctions)
                         answer_opinions
                sumValidity <-
                        verifyEncryption elecPubKey zkp
                         (intervalDisjunctions question_mini question_maxi)
                         (sum (fst <$> answer_opinions), answer_sumProof)
                return (and opinionsValidity && sumValidity)

-- ** Type 'ErrorAnswer'
-- | Error raised by 'encryptAnswer' when the given opinions
-- cannot form a valid 'Answer' to a 'Question'.
data ErrorAnswer
 =   ErrorAnswer_WrongNumberOfOpinions Natural Natural
     -- ^ When the number of opinions is different than
     -- the number of choices ('question_choices').
     -- Fields: the number of opinions given, then the number of choices.
 |   ErrorAnswer_WrongSumOfOpinions Natural Natural Natural
     -- ^ When the sum of opinions is not within the bounds
     -- of 'question_mini' and 'question_maxi'.
     -- Fields: the sum, then 'question_mini', then 'question_maxi'.
 deriving stock (Eq, Show, Generic)
 deriving anyclass (NFData)

-- ** Type 'Opinion'
-- | Index of a 'Disjunction' within a list of them.
-- It is encrypted as a 'GroupExponent' by 'encrypt'.
--
-- In 'encryptAnswer', the value encrypted for each choice is
-- 'one' when the choice is selected and 'zero' otherwise.
type Opinion = E

-- * Type 'Election'
-- | An election: its metadata, its 'Question's, and the crypto parameters
-- and public key under which the voters' 'Answer's are encrypted.
data Election crypto v c = Election
 { election_name        :: !Text
 , election_description :: !Text
 , election_questions   :: ![Question v]
 , election_uuid        :: !UUID
 , election_hash        :: Base64SHA256
   -- ^ NOTE(review): the only non-strict field. 'readElection' sets it
   -- to the 'base64SHA256' of the raw file. It is not part of the JSON
   -- encoding below, so 'hashElection' does not depend on it.
 , election_crypto      :: !crypto
 , election_version     :: !(Maybe Version)
 , election_public_key  :: !(PublicKey crypto c)
 } deriving stock (Generic)
deriving instance (Eq crypto, Eq (G crypto c)) => Eq (Election crypto v c)
deriving instance (Show crypto, Show (G crypto c)) => Show (Election crypto v c)
deriving instance (NFData crypto, NFData (G crypto c)) => NFData (Election crypto v c)
-- | Nests 'election_crypto' and 'election_public_key' under
-- @"public_key"@ (as @"group"@ and @"y"@).
-- Omits @"version"@ when 'election_version' is 'Nothing'.
-- Never encodes 'election_hash'.
instance
 ( Reifies v Version
 , CryptoParams crypto c
 , ToJSON crypto
 ) => ToJSON (Election crypto v c) where
        toJSON Election{..} =
                JSON.object (requiredPairs <> versionPairs)
                where
                requiredPairs =
                 [ "name"        .= election_name
                 , "description" .= election_description
                 , "public_key"  .= JSON.object
                         [ "group" .= election_crypto
                         , "y"     .= election_public_key
                         ]
                 , "questions"   .= election_questions
                 , "uuid"        .= election_uuid
                 ]
                versionPairs =
                 foldMap (\version -> [ "version" .= version ]) election_version
        -- NOTE: 'JSON.encode' goes through 'toEncoding', so the key order
        -- below fixes the bytes hashed by 'hashElection'.
        toEncoding Election{..} =
                JSON.pairs $
                        "name"        .= election_name <>
                        "description" .= election_description <>
                        JSON.pair "public_key" publicKeyEncoding <>
                        "questions"   .= election_questions <>
                        "uuid"        .= election_uuid <>
                        maybe mempty ("version" .=) election_version
                where
                publicKeyEncoding =
                        JSON.pairs $
                                "group" .= election_crypto <>
                                "y"     .= election_public_key

-- | @('hashElection' elec)@ returns the 'base64SHA256' digest
-- of the strict bytes produced by 'JSON.encode' on @elec@.
-- Per the 'ToJSON' instance, those bytes do not include 'election_hash'.
hashElection ::
 Reifies v Version =>
 CryptoParams crypto c =>
 ToJSON crypto =>
 Election crypto v c -> Base64SHA256
hashElection election =
        base64SHA256 (BSL.toStrict (JSON.encode election))

-- | @('readElection' filePath k)@ reads the JSON file at @filePath@,
-- parses it as an 'Election', and returns the result of the
-- continuation @k@ applied to that 'Election'.
-- @k@ must work for any reified 'Version' @v@ and crypto parameters @c@,
-- because both are only known once the file has been parsed.
--
-- * 'election_hash' is set to the 'base64SHA256' of the raw file bytes
--   (it is not recomputed with 'hashElection').
-- * A missing @"version"@ field is reified as 'stableVersion',
--   while 'election_version' itself stays 'Nothing'.
-- * JSON errors are returned in the 'ExceptT' as a 'String'
--   (formatted by 'jsonEitherFormatError').
-- * I/O exceptions from 'BS.readFile' are not caught here.
readElection ::
 forall crypto r.
 FromJSON crypto =>
 ReifyCrypto crypto =>
 FilePath ->
 (forall v c.
        Reifies v Version =>
        CryptoParams crypto c =>
        Election crypto v c -> r) ->
 ExceptT String IO r
readElection filePath k = do
        fileData <- lift $ BS.readFile filePath
        -- The whole file must be a single JSON value
        -- (see aeson's 'JSON.jsonEOF').
        ExceptT $ return $
                jsonEitherFormatError $
                        JSON.eitherDecodeStrictWith JSON.jsonEOF
                         (JSON.iparse (parseElection fileData))
                         fileData
        where
        -- fileData is only needed here to compute 'election_hash'.
        parseElection fileData = JSON.withObject "Election" $ \o -> do
                election_version <- o .:? "version"
                -- Reify the version first: the remaining fields
                -- are parsed at type @v@ (e.g. @[Question v]@).
                reify (fromMaybe stableVersion election_version) $ \(_v::Proxy v) -> do
                        -- Only the group is decoded now. "y" is kept as a raw
                        -- 'JSON.Value' because its type (@PublicKey crypto c@)
                        -- depends on @c@, which is only fixed by 'reifyCrypto' below.
                        (election_crypto, elecPubKey) <-
                                JSON.explicitParseField
                                 (JSON.withObject "public_key" $ \obj -> do
                                                crypto <- obj .: "group"
                                                pubKey :: JSON.Value <- obj .: "y"
                                                return (crypto, pubKey)
                                 ) o "public_key"
                        reifyCrypto election_crypto $ \(_c::Proxy c) -> do
                                election_name <- o .: "name"
                                election_description <- o .: "description"
                                election_questions <- o .: "questions" :: JSON.Parser [Question v]
                                election_uuid <- o .: "uuid"
                                election_public_key :: PublicKey crypto c <- parseJSON elecPubKey
                                -- The remaining fields (election_version, election_crypto,
                                -- election_name, ...) are filled by the wildcard.
                                return $ k $ Election
                                 { election_questions  = election_questions
                                 , election_public_key = election_public_key
                                 , election_hash       = base64SHA256 fileData
                                 , ..
                                 }

-- * Type 'Ballot'
-- | A voter's ballot: one 'Answer' per question of the 'Election'
-- identified by 'ballot_election_uuid' and 'ballot_election_hash',
-- optionally signed (see 'encryptBallot' and 'verifyBallot').
data Ballot crypto v c = Ballot
 { ballot_answers       :: ![Answer crypto v c]
 , ballot_signature     :: !(Maybe (Signature crypto v c))
 , ballot_election_uuid :: !UUID
 , ballot_election_hash :: !Base64SHA256
 } deriving stock (Generic)
deriving instance (NFData (G crypto c), NFData crypto) => NFData (Ballot crypto v c)
-- | Encodes @"answers"@, @"election_uuid"@ and @"election_hash"@,
-- followed by @"signature"@ only when 'ballot_signature' is a 'Just'.
instance
 ( Reifies v Version
 , CryptoParams crypto c
 , ToJSON (G crypto c)
 ) => ToJSON (Ballot crypto v c) where
        toJSON ballot =
                JSON.object $
                 [ "answers"       .= ballot_answers ballot
                 , "election_uuid" .= ballot_election_uuid ballot
                 , "election_hash" .= ballot_election_hash ballot
                 ] <>
                 foldMap (\sig -> [ "signature" .= sig ]) (ballot_signature ballot)
        toEncoding ballot =
                JSON.pairs $
                        "answers"       .= ballot_answers ballot <>
                        "election_uuid" .= ballot_election_uuid ballot <>
                        "election_hash" .= ballot_election_hash ballot <>
                        maybe mempty ("signature" .=) (ballot_signature ballot)
-- | Inverse of the 'ToJSON' instance; @"signature"@ is optional.
instance
 ( Reifies v Version
 , CryptoParams crypto c
 ) => FromJSON (Ballot crypto v c) where
        parseJSON = JSON.withObject "Ballot" $ \obj -> do
                answers  <- obj .:  "answers"
                sigMay   <- obj .:? "signature"
                elecUUID <- obj .:  "election_uuid"
                elecHash <- obj .:  "election_hash"
                return Ballot
                 { ballot_answers       = answers
                 , ballot_signature     = sigMay
                 , ballot_election_uuid = elecUUID
                 , ballot_election_hash = elecHash
                 }

-- | @('encryptBallot' elec ballotSecKeyMay opinionsByQuest)@
-- returns a 'Ballot' for @elec@.
-- @opinionsByQuest@ gives, for each of the 'election_questions',
-- one 'Bool' per 'question_choices'; 'encryptAnswer' turns each list
-- into an 'Answer'.
-- When @ballotSecKeyMay@ is @('Just' ballotSecKey)@ (the voter's secret key),
-- the ballot is signed with it. When it is 'Nothing',
-- 'ballot_signature' is 'Nothing'.
--
-- Raises 'ErrorBallot_WrongNumberOfAnswers' when @opinionsByQuest@
-- does not have one entry per question. Any 'ErrorAnswer' from
-- 'encryptAnswer' is wrapped in 'ErrorBallot_Answer'.
encryptBallot ::
 Reifies v Version =>
 CryptoParams crypto c => Key crypto =>
 Monad m => RandomGen r =>
 Election crypto v c ->
 Maybe (SecretKey crypto c) -> [[Bool]] ->
 S.StateT r (ExceptT ErrorBallot m) (Ballot crypto v c)
encryptBallot (Election{..}::Election crypto v c) ballotSecKeyMay opinionsByQuest
 | List.length election_questions /= List.length opinionsByQuest =
        lift $ throwE $
                ErrorBallot_WrongNumberOfAnswers
                 (fromIntegral $ List.length opinionsByQuest)
                 (fromIntegral $ List.length election_questions)
 | otherwise = do
        -- With a secret key, its public key is kept for the signature,
        -- and its 'bytesNat' is the ZKP passed to 'encryptAnswer' and
        -- 'ballotCommitments'. Without one, the empty ZKP is used,
        -- as 'verifyBallot' does for unsigned ballots.
        let (voterKeys, voterZKP) =
                case ballotSecKeyMay of
                 Nothing -> (Nothing, ZKP "")
                 Just ballotSecKey ->
                        ( Just (ballotSecKey, ballotPubKey)
                        , ZKP (bytesNat ballotPubKey) )
                        where ballotPubKey = publicKey ballotSecKey
        ballot_answers <-
                S.mapStateT (withExceptT ErrorBallot_Answer) $
                        zipWithM (encryptAnswer election_public_key voterZKP)
                         election_questions opinionsByQuest
        -- NOTE: the signature is proven after the answers are encrypted,
        -- because its statement covers their ciphertexts ('ballotStatement').
        -- Both steps run in the same 'S.StateT' over @r@ (presumably
        -- both draw randomness from it), so this order determines the
        -- produced ballot.
        ballot_signature <- case voterKeys of
         Nothing -> return Nothing
         Just (ballotSecKey, signature_publicKey) -> do
                signature_proof <-
                        proveQuicker ballotSecKey (Identity groupGen) $
                         \(Identity commitment) ->
                                hash @crypto
                                 -- NOTE: the order is unusual, the commitments are first
                                 -- then comes the statement. Best guess is that
                                 -- this is easier to code due to their respective types.
                                 (ballotCommitments @crypto voterZKP commitment)
                                 (ballotStatement @crypto ballot_answers)
                -- 'Signature{..}' is filled from the local
                -- signature_publicKey and signature_proof.
                return $ Just Signature{..}
        return Ballot
         { ballot_answers
         , ballot_election_hash = election_hash
         , ballot_election_uuid = election_uuid
         , ballot_signature
         }

-- | @('verifyBallot' elec ballot)@ returns 'True' iff @ballot@:
--
-- * refers to @elec@ (same 'election_uuid' and 'election_hash');
-- * has exactly one 'Answer' per question;
-- * if signed, carries a signature whose proof challenge matches
--   the 'hash' of its 'ballotCommitments' and 'ballotStatement';
-- * has every 'Answer' passing 'verifyAnswer', using as ZKP the
--   signer's public key bytes, or the empty ZKP when unsigned.
verifyBallot ::
 Reifies v Version =>
 CryptoParams crypto c =>
 Election crypto v c ->
 Ballot crypto v c -> Bool
verifyBallot (Election{..}::Election crypto v c) Ballot{..} =
        let (isValidSign, zkpSign) =
                case ballot_signature of
                 Nothing -> (True, ZKP "")
                 Just Signature{..} ->
                        let zkp = ZKP (bytesNat signature_publicKey) in
                        ( proof_challenge signature_proof == hash
                                 (ballotCommitments @crypto zkp
                                         (commitQuicker signature_proof groupGen signature_publicKey))
                                 (ballotStatement @crypto ballot_answers)
                        , zkp )
        in
        and
         [ ballot_election_uuid == election_uuid
         , ballot_election_hash == election_hash
         , List.length election_questions == List.length ballot_answers
         , isValidSign
         , and $ List.zipWith (verifyAnswer election_public_key zkpSign)
                election_questions ballot_answers
         ]


-- ** Type 'ErrorBallot'
-- | Error raised by 'encryptBallot'.
data ErrorBallot
 =   ErrorBallot_WrongNumberOfAnswers Natural Natural
     -- ^ When the number of answers
     -- is different than the number of questions.
     -- Fields: the number of answers given, then the number of questions.
 |   ErrorBallot_Answer ErrorAnswer
     -- ^ When 'encryptAnswer' raised an 'ErrorAnswer'.
 |   ErrorBallot_Wrong
     -- ^ TODO: to be more precise.
 deriving stock (Eq, Show, Generic)
 deriving anyclass (NFData)

-- ** Hashing

-- | @('ballotStatement' answers)@
-- returns the encrypted material to be signed. For every opinion of
-- every 'Answer', in order, it gives the 'encryption_nonce' followed
-- by the 'encryption_vault'. The individual proofs are not included.
ballotStatement :: CryptoParams crypto c => [Answer crypto v c] -> [G crypto c]
ballotStatement answers =
        [ element
        | Answer{..} <- answers
        , (Encryption{..}, _proof) <- answer_opinions
        , element <- [encryption_nonce, encryption_vault]
        ]

-- | @('ballotCommitments' voterZKP commitment)@
-- returns the concatenation of @\"sig|\"@, the bytes of @voterZKP@, @\"|\"@,
-- the 'bytesNat' of @commitment@, and a final @\"|\"@.
-- 'encryptBallot' and 'verifyBallot' pass it, together with the
-- 'ballotStatement', to 'hash' to obtain the signature's challenge.
ballotCommitments ::
 CryptoParams crypto c =>
 ToNatural (G crypto c) =>
 ZKP -> Commitment crypto c -> BS.ByteString
ballotCommitments (ZKP voterZKP) commitment =
        mconcat
         [ "sig|", voterZKP, "|" -- NOTE: this is actually part of the statement
         , bytesNat commitment, "|"
         ]