{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE DeriveDataTypeable   #-}
{-# LANGUAGE EmptyCase            #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE StandaloneDeriving   #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Finite numbers.
--
-- This module is designed to be imported as
--
-- @
-- import Data.Fin (Fin (..))
-- import qualified Data.Fin as Fin
-- @
--
module Data.Fin (
    Fin (..),
    cata,
    -- * Showing
    explicitShow,
    explicitShowsPrec,
    -- * Conversions
    toNat,
    fromNat,
    toNatural,
    toInteger,
    -- * Interesting
    inverse,
    universe,
    inlineUniverse,
    universe1,
    inlineUniverse1,
    absurd,
    boring,
    -- * Plus
    weakenLeft,
    weakenRight,
    append,
    split,
    -- * Aliases
    fin0, fin1, fin2, fin3, fin4, fin5, fin6, fin7, fin8, fin9,
    ) where

import Control.DeepSeq    (NFData (..))
import Data.Bifunctor     (bimap)
import Data.Hashable      (Hashable (..))
import Data.List.NonEmpty (NonEmpty (..))
import Data.Proxy         (Proxy (..))
import Data.Typeable      (Typeable)
import GHC.Exception      (ArithException (..), throw)
import Numeric.Natural    (Natural)
import Data.Type.Nat (Nat (..))

import qualified Data.List.NonEmpty as NE
import qualified Data.Type.Nat      as N

-- | Finite numbers: @[0..n-1]@.
--
-- The type index @n@ is an exclusive upper bound. Both constructors produce
-- a @'Fin' ('S n)@, so no constructor can build a value of @'Fin' 'Z@.
data Fin (n :: Nat) where
    -- | Zero. Available at every index of the form @'S n@.
    FZ :: Fin ('S n)
    -- | Successor: lifts a value bounded by @n@ to one bounded by @'S n@.
    FS :: Fin n -> Fin ('S n)
  deriving (Typeable)

-------------------------------------------------------------------------------
-- Instances
-------------------------------------------------------------------------------

-- | Derived structural equality on the 'FZ' / 'FS' constructors.
deriving instance Eq (Fin n)
-- | Derived ordering: 'FZ' is declared before 'FS', so 'FZ' sorts below every
-- 'FS' value, and two 'FS' values are compared by their arguments.
deriving instance Ord (Fin n)

-- | A 'Fin' is shown as the 'Natural' number it denotes.
--
-- Use 'explicitShow' or 'explicitShowsPrec' to display the underlying
-- 'FZ' / 'FS' structure instead.
instance Show (Fin n) where
    showsPrec prec n = showsPrec prec (toNatural n)

-- | Operations modulo @n@.
--
-- Arithmetic is carried out on 'Integer's and the result is brought back
-- into range by 'fromInteger', which reduces its argument modulo @n@.
--
-- >>> map fromInteger [0, 1, 2, 3, 4, -5] :: [Fin N.Nat3]
-- [0,1,2,0,1,1]
--
-- >>> fromInteger 42 :: Fin N.Nat0
-- *** Exception: divide by zero
-- ...
--
-- >>> signum (FZ :: Fin N.Nat1)
-- 0
--
-- >>> signum (3 :: Fin N.Nat4)
-- 1
--
-- >>> 2 + 3 :: Fin N.Nat4
-- 1
--
-- >>> 2 * 3 :: Fin N.Nat4
-- 2
--
instance N.SNatI n => Num (Fin n) where
    abs x = x

    -- The two non-zero shapes are matched separately: a bare @FS _@ pattern
    -- would not reveal that @n@ is at least two, which the result @FS FZ@
    -- needs in order to be well-typed.
    signum x = case x of
        FZ        -> FZ
        FS FZ     -> FS FZ
        FS (FS _) -> FS FZ

    fromInteger i = unsafeFromNum (i `mod` N.reflectToNum (Proxy :: Proxy n))

    a + b = fromInteger (toInteger a + toInteger b)
    a * b = fromInteger (toInteger a * toInteger b)
    a - b = fromInteger (toInteger a - toInteger b)

    negate a = fromInteger (negate (toInteger a))

-- | The 'Rational' with the same numeric value as the 'Fin'.
instance N.SNatI n => Real (Fin n) where
    toRational = toRational . toNatural

-- | 'quot' works only on @'Fin' n@ where @n@ is prime.
--
-- Division is defined as multiplication by the modular 'inverse' of the
-- divisor, and 'quotRem' always reports a remainder of zero.
instance N.SNatI n => Integral (Fin n) where
    toInteger = toInteger . toNatural

    quot x y = x * inverse y

    quotRem x y = (x `quot` y, 0)

-- | Multiplicative inverse modulo @n@.
--
-- Works for @'Fin' n@ where @n@ is coprime with the argument, i.e. in general when @n@ is prime.
-- No check is made that the argument is actually invertible.
--
-- >>> map inverse universe :: [Fin N.Nat5]
-- [0,1,3,2,4]
--
-- >>> zipWith (*) universe (map inverse universe) :: [Fin N.Nat5]
-- [0,1,1,1,1]
--
-- Adaptation of [pseudo-code in Wikipedia](https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers)
--
inverse :: forall n. N.SNatI n => Fin n -> Fin n
inverse x = fromInteger (go 0 modulus 1 (toInteger x)) where
    modulus :: Integer
    modulus = N.reflectToNum (Proxy :: Proxy n)

    -- Extended Euclidean loop over the pairs (t, r) and (t', r').
    -- It stops once the newer remainder r' reaches zero, and then
    -- normalises a negative coefficient t into the range [0, n).
    go :: Integer -> Integer -> Integer -> Integer -> Integer
    go t r t' r'
        | r' == 0   = if t < 0 then t + modulus else t
        | otherwise = go t' r' (t - q * t') (r - q * r')
      where
        q = r `div` r'

-- | Enumeration in numeric order.
--
-- 'fromEnum' counts the 'FS' constructors. 'toEnum' is partial: it throws
-- 'Underflow' for negative arguments and 'Overflow' for arguments that do
-- not fit into @'Fin' n@.
--
-- 'enumFrom' and 'enumFromThen' are bounded by the size of the type, so
-- @[x ..]@ and @[x, y ..]@ stop at the last (respectively first) element
-- instead of running past it and throwing, which is what the class
-- defaults (which enumerate unbounded 'Int's) would do.
instance N.SNatI n => Enum (Fin n) where
    fromEnum = go where
        go :: Fin m -> Int
        go FZ     = 0
        go (FS n) = succ (go n)

    toEnum = unsafeFromNum

    -- Stop at the largest inhabitant, @n - 1@.
    enumFrom x = map toEnum [fromEnum x .. top] where
        top :: Int
        top = N.reflectToNum (Proxy :: Proxy n) - 1

    -- Ascending (or constant) sequences stop at @n - 1@,
    -- descending ones stop at @0@.
    enumFromThen x y = map toEnum [i, j .. limit] where
        i, j, limit :: Int
        i = fromEnum x
        j = fromEnum y
        limit
            | j >= i    = N.reflectToNum (Proxy :: Proxy n) - 1
            | otherwise = 0

-- | Bounds exist only for indices of the form @'S m@.
--
-- 'minBound' is 'FZ'. 'maxBound' is built by induction on @m@: it starts
-- from 'FZ' and wraps one more 'FS' around it at every step.
instance (n ~ 'S m, N.SNatI m) => Bounded (Fin n) where
    minBound = FZ
    maxBound = getMaxBound
        (N.induction
            (MaxBound FZ)
            (\(MaxBound top) -> MaxBound (FS top)))

-- | Induction motive for 'maxBound': at index @n@ it carries a value of
-- @'Fin' ('S n)@.
newtype MaxBound n = MaxBound { getMaxBound :: Fin ('S n) }

-- | Forcing a 'Fin' walks down every 'FS' until the final 'FZ' is reached.
instance NFData (Fin n) where
    rnf x = case x of
        FZ    -> ()
        FS x' -> rnf x'

-- | A 'Fin' is hashed as the 'Integer' obtained by counting its 'FS'
-- constructors.
instance Hashable (Fin n) where
    hashWithSalt salt n = hashWithSalt salt (cata (0 :: Integer) succ n)

-------------------------------------------------------------------------------
-- Showing
-------------------------------------------------------------------------------

-- | Like 'show', but renders the constructor structure of a 'Fin'
-- rather than its numeric value.
--
-- >>> explicitShow (0 :: Fin N.Nat1)
-- "FZ"
--
-- >>> explicitShow (2 :: Fin N.Nat3)
-- "FS (FS FZ)"
--
explicitShow :: Fin n -> String
explicitShow = ($ "") . explicitShowsPrec 0

-- | Like 'showsPrec', but renders the constructor structure of a 'Fin'.
--
-- An 'FS' application is parenthesised when the surrounding precedence is
-- above 10 (function application); its argument is always rendered at
-- precedence 11.
explicitShowsPrec :: Int -> Fin n -> ShowS
explicitShowsPrec d n = case n of
    FZ    -> showString "FZ"
    FS n' -> showParen (d >= 11) (showString "FS " . explicitShowsPrec 11 n')

-------------------------------------------------------------------------------
-- Conversions
-------------------------------------------------------------------------------

-- | Fold a 'Fin': 'FZ' is replaced by the first argument and every 'FS'
-- by an application of the second.
cata :: forall a n. a -> (a -> a) -> Fin n -> a
cata zero next = loop
  where
    -- The index shrinks at every step, so the helper has to be polymorphic in it.
    loop :: forall k. Fin k -> a
    loop x = case x of
        FZ    -> zero
        FS x' -> next (loop x')

-- | Convert to a term-level 'Nat', mapping 'FZ' to 'Z' and 'FS' to 'S'.
toNat :: Fin n -> N.Nat
toNat = cata N.Z N.S

-- | Convert from 'Nat', returning 'Nothing' when the number does not fit
-- into @'Fin' n@.
--
-- >>> fromNat N.nat1 :: Maybe (Fin N.Nat2)
-- Just 1
--
-- >>> fromNat N.nat1 :: Maybe (Fin N.Nat1)
-- Nothing
--
fromNat :: N.SNatI n => N.Nat -> Maybe (Fin n)
fromNat = appNatToFin (N.induction base next) where
    -- Nothing fits into @Fin 'Z@.
    base :: NatToFin 'Z
    base = NatToFin (\_ -> Nothing)

    -- Zero always fits into a successor index; a successor fits exactly
    -- when its predecessor fits into the smaller index.
    next :: NatToFin m -> NatToFin ('S m)
    next (NatToFin prev) = NatToFin (\k -> case k of
        S k' -> FS `fmap` prev k'
        Z    -> Just FZ)

-- | Induction motive for 'fromNat': a conversion function into @'Fin' n@.
newtype NatToFin n = NatToFin { appNatToFin :: N.Nat -> Maybe (Fin n) }

-- | Convert to 'Natural' by counting the 'FS' constructors.
toNatural :: Fin n -> Natural
toNatural = cata 0 (+ 1)

-- | Convert from any 'Ord' 'Num'.
--
-- Partial: throws 'Underflow' when the argument is negative and 'Overflow'
-- when it is too large to fit into @'Fin' n@.
unsafeFromNum :: forall n i. (Num i, Ord i, N.SNatI n) => i -> Fin n
unsafeFromNum = appUnsafeFromNum (N.induction base next) where
    -- @Fin 'Z@ has no values, so reaching this case is always an error.
    base :: UnsafeFromNum i 'Z
    base = UnsafeFromNum (\k -> case compare k 0 of
        LT -> throw Underflow
        _  -> throw Overflow)

    -- Peel one unit off the number for every 'FS' that is produced.
    next :: UnsafeFromNum i m -> UnsafeFromNum i ('S m)
    next (UnsafeFromNum prev) = UnsafeFromNum (\k -> case compare k 0 of
        LT -> throw Underflow
        EQ -> FZ
        GT -> FS (prev (k - 1)))

-- | Induction motive for 'unsafeFromNum': a conversion from @i@ into @'Fin' n@.
newtype UnsafeFromNum i n = UnsafeFromNum { appUnsafeFromNum :: i -> Fin n }

-------------------------------------------------------------------------------
-- "Interesting" stuff
-------------------------------------------------------------------------------

-- | All values, in ascending order.
--
-- Unlike @[minBound .. maxBound]@, this also works for @'Fin' 'N.Nat0'@.
--
-- >>> universe :: [Fin N.Nat3]
-- [0,1,2]
universe :: N.SNatI n => [Fin n]
universe = getUniverse (N.induction (Universe []) grow) where
    -- The next universe is 'FZ' followed by the shifted smaller universe.
    grow :: Universe m -> Universe ('S m)
    grow (Universe smaller) = Universe (FZ : [FS x | x <- smaller])

-- | Like 'universe', but returns a 'NonEmpty' list; the index is therefore
-- required to be a successor.
--
-- >>> universe1 :: NonEmpty (Fin N.Nat3)
-- 0 :| [1,2]
universe1 :: N.SNatI n => NonEmpty (Fin ('S n))
universe1 = getUniverse1 (N.induction (Universe1 (FZ :| [])) grow) where
    -- Prepend 'FZ' to the shifted smaller universe.
    grow :: Universe1 m -> Universe1 ('S m)
    grow (Universe1 smaller) = Universe1 (FZ NE.<| NE.map FS smaller)

-- | Variant of 'universe' built with 'N.inlineInduction'; it will be fully
-- inlined if @n@ is known at compile time.
--
-- >>> inlineUniverse :: [Fin N.Nat3]
-- [0,1,2]
inlineUniverse :: N.InlineInduction n => [Fin n]
inlineUniverse = getUniverse (N.inlineInduction (Universe []) grow) where
    -- The next universe is 'FZ' followed by the shifted smaller universe.
    grow :: Universe m -> Universe ('S m)
    grow (Universe smaller) = Universe (FZ : [FS x | x <- smaller])

-- | Variant of 'universe1' built with 'N.inlineInduction'.
--
-- >>> inlineUniverse1 :: NonEmpty (Fin N.Nat3)
-- 0 :| [1,2]
inlineUniverse1 :: N.InlineInduction n => NonEmpty (Fin ('S n))
inlineUniverse1 = getUniverse1 (N.inlineInduction (Universe1 (FZ :| [])) grow) where
    -- Prepend 'FZ' to the shifted smaller universe.
    grow :: Universe1 m -> Universe1 ('S m)
    grow (Universe1 smaller) = Universe1 (FZ NE.<| NE.map FS smaller)

-- | Induction motive for 'universe' and 'inlineUniverse': a list of @'Fin' n@ values.
newtype Universe  n = Universe  { getUniverse  :: [Fin n] }
-- | Induction motive for 'universe1' and 'inlineUniverse1'. The index is
-- shifted by one, so index @n@ carries a non-empty list of @'Fin' ('S n)@.
newtype Universe1 n = Universe1 { getUniverse1 :: NonEmpty (Fin ('S n)) }

-- | @'Fin' 'N.Nat0'@ is not inhabited.
--
-- Given a (necessarily non-existent) value of it, any result type can be
-- produced; the empty @case@ covers all zero constructors.
absurd :: Fin N.Nat0 -> b
absurd n = case n of {}

-- | Counting to one is boring.
--
-- The value is 'FZ', which is shown as @0@.
--
-- >>> boring
-- 0
boring :: Fin N.Nat1
boring = FZ

-------------------------------------------------------------------------------
-- Append & Split
-------------------------------------------------------------------------------

-- | Embed @'Fin' n@ into the larger @'Fin' ('N.Plus' n m)@.
--
-- The constructor structure is copied unchanged, so the numeric value is
-- preserved. The 'Proxy' argument only selects @m@.
weakenLeft :: forall n m. N.InlineInduction n => Proxy m -> Fin n -> Fin (N.Plus n m)
weakenLeft _ = getWeakenLeft (N.inlineInduction base next :: WeakenLeft m n) where
    -- There is nothing to embed from the empty type.
    base :: WeakenLeft m 'Z
    base = WeakenLeft absurd

    -- Rebuild the same constructor at the larger index.
    next :: WeakenLeft m p -> WeakenLeft m ('S p)
    next (WeakenLeft rec) = WeakenLeft (\x -> case x of
        FS x' -> FS (rec x')
        FZ    -> FZ)

-- | Induction motive for 'weakenLeft'.
newtype WeakenLeft m n = WeakenLeft { getWeakenLeft :: Fin n -> Fin (N.Plus n m) }

-- | Embed @'Fin' m@ into the larger @'Fin' ('N.Plus' n m)@.
--
-- One 'FS' is wrapped around the value for every induction step over @n@,
-- so the numeric value is shifted up by @n@. The 'Proxy' argument only
-- selects @n@.
weakenRight :: forall n m. N.InlineInduction n => Proxy n -> Fin m -> Fin (N.Plus n m)
weakenRight _ = getWeakenRight (N.inlineInduction base next :: WeakenRight m n) where
    -- Adding zero leaves the value as it is.
    base :: WeakenRight m 'Z
    base = WeakenRight id

    -- Each further step shifts the value by one.
    next :: WeakenRight m p -> WeakenRight m ('S p)
    next (WeakenRight rec) = WeakenRight (FS . rec)

-- | Induction motive for 'weakenRight'.
newtype WeakenRight m n = WeakenRight { getWeakenRight :: Fin m -> Fin (N.Plus n m) }

-- | Append two 'Fin's together.
--
-- A 'Left' value keeps its number (see 'weakenLeft'); a 'Right' value is
-- placed after all the left ones (see 'weakenRight').
--
-- >>> append (Left fin2 :: Either (Fin N.Nat5) (Fin N.Nat4))
-- 2
--
-- >>> append (Right fin2 :: Either (Fin N.Nat5) (Fin N.Nat4))
-- 7
--
append :: forall n m. N.InlineInduction n => Either (Fin n) (Fin m) -> Fin (N.Plus n m)
append = either
    (weakenLeft  (Proxy :: Proxy m))
    (weakenRight (Proxy :: Proxy n))

-- | Inverse of 'append': decide whether a number belongs to the left
-- @'Fin' n@ part or to the right @'Fin' m@ part.
--
-- >>> split fin2 :: Either (Fin N.Nat2) (Fin N.Nat3)
-- Right 0
--
-- >>> split fin1 :: Either (Fin N.Nat2) (Fin N.Nat3)
-- Left 1
--
-- >>> map split universe :: [Either (Fin N.Nat2) (Fin N.Nat3)]
-- [Left 0,Left 1,Right 0,Right 1,Right 2]
--
split :: forall n m. N.InlineInduction n => Fin (N.Plus n m) -> Either (Fin n) (Fin m)
split = getSplit (N.inlineInduction base next) where
    -- With an empty left part everything belongs to the right.
    base :: Split m 'Z
    base = Split Right

    -- Zero belongs to the left part. Otherwise split the predecessor and
    -- re-wrap a left result in 'FS'; right results pass through untouched.
    next :: Split m p -> Split m ('S p)
    next (Split rec) = Split (\x -> case x of
        FZ    -> Left FZ
        FS x' -> case rec x' of
            Left l  -> Left (FS l)
            Right r -> Right r)

-- | Induction motive for 'split'.
newtype Split m n = Split { getSplit :: Fin (N.Plus n m) -> Either (Fin n) (Fin m) }

-------------------------------------------------------------------------------
-- Aliases
-------------------------------------------------------------------------------

-- | The number @0@, at any 'Fin' index large enough to contain it.
fin0 :: Fin (N.Plus N.Nat0 ('S n))
fin0 = FZ

-- | The number @1@, at any 'Fin' index large enough to contain it.
fin1 :: Fin (N.Plus N.Nat1 ('S n))
fin1 = FS fin0

-- | The number @2@, at any 'Fin' index large enough to contain it.
fin2 :: Fin (N.Plus N.Nat2 ('S n))
fin2 = FS fin1

-- | The number @3@, at any 'Fin' index large enough to contain it.
fin3 :: Fin (N.Plus N.Nat3 ('S n))
fin3 = FS fin2

-- | The number @4@, at any 'Fin' index large enough to contain it.
fin4 :: Fin (N.Plus N.Nat4 ('S n))
fin4 = FS fin3

-- | The number @5@, at any 'Fin' index large enough to contain it.
fin5 :: Fin (N.Plus N.Nat5 ('S n))
fin5 = FS fin4

-- | The number @6@, at any 'Fin' index large enough to contain it.
fin6 :: Fin (N.Plus N.Nat6 ('S n))
fin6 = FS fin5

-- | The number @7@, at any 'Fin' index large enough to contain it.
fin7 :: Fin (N.Plus N.Nat7 ('S n))
fin7 = FS fin6

-- | The number @8@, at any 'Fin' index large enough to contain it.
fin8 :: Fin (N.Plus N.Nat8 ('S n))
fin8 = FS fin7

-- | The number @9@, at any 'Fin' index large enough to contain it.
fin9 :: Fin (N.Plus N.Nat9 ('S n))
fin9 = FS fin8