{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

{-|
Copyright        : (c) Galois, Inc 2021

@'Fin' n@ is a finite type with exactly @n@ elements. Essentially, each value
bundles a 'NatRepr' whose type parameter is existentially quantified together
with a proof that this parameter is strictly less than the fixed natural @n@.

'Fin' values are useful in combination with types of a fixed size. For example,
'Fin' is used as the index type in the
'Data.Functor.WithIndex.FunctorWithIndex' instance for the vector type in
"Data.Parameterized.Vector". As another example, a @Map ('Fin' n) a@ is a
@Map@ that can hold at most @n@ entries.
-}
module Data.Parameterized.Fin
  ( Fin
  , mkFin
  , buildFin
  , countFin
  , viewFin
  , finToNat
  , embed
  , tryEmbed
  , minFin
  , incFin
  , fin0Void
  , fin1Unit
  , fin2Bool
  ) where

import Control.Lens.Iso (Iso', iso)
import GHC.TypeNats (KnownNat)
import Numeric.Natural (Natural)
import Data.Void (Void, absurd)

import Data.Parameterized.NatRepr

-- | The type @'Fin' n@ has exactly @n@ inhabitants.
--
-- A value wraps a @'NatRepr' i@ for an existentially-quantified @i@, together
-- with the type-level evidence @i + 1 <= n@ (that is, @i < n@). Only the
-- runtime value of @i@ (see 'finToNat') and that evidence are observable.
data Fin n =
  -- GHC 8.6 and 8.4 require parentheses around 'i + 1 <= n'
  forall i. (i + 1 <= n) => Fin { _getFin :: NatRepr i }

-- | Two elements of @'Fin' n@ are equal exactly when they wrap the same
-- natural number (as reported by 'finToNat').
instance Eq (Fin n) where
  i == j = finToNat i == finToNat j

-- | Elements of @'Fin' n@ are ordered by the natural number they wrap
-- (as reported by 'finToNat'), consistently with the 'Eq' instance.
instance Ord (Fin n) where
  compare i j = compare (finToNat i) (finToNat j)

-- | The least element wraps @0@ (it is 'minFin') and the greatest wraps
-- @n - 1@. The @1 <= n@ constraint is required because @'Fin' 0@ has no
-- elements; @'KnownNat' n@ is used to construct the 'NatRepr' for @n - 1@.
instance (1 <= n, KnownNat n) => Bounded (Fin n) where
  minBound = minFin
  maxBound =
    -- (n - 1) + 1 = n, which supplies the ((n - 1) + 1 <= n) bound that
    -- 'Fin' needs for the index n - 1.
    case minusPlusCancel (knownNat @n) (knownNat @1) of
      Refl -> Fin (decNat (knownNat @n))

-- | Non-lawful instance, intended only for testing.
--
-- Renders an element as @Fin k@, where @k@ is its 'finToNat' value. The
-- output does not round-trip through any 'Read' instance. 'showsPrec' adds
-- parentheses at application precedence, so nested values render
-- unambiguously (e.g. @Just (Fin 3)@ rather than @Just Fin 3@). The
-- top-level 'show' output is unchanged.
instance Show (Fin n) where
  showsPrec d i =
    showParen (d > 10) (showString "Fin " . shows (finToNat i))

-- | Construct an element of @'Fin' n@ from a 'NatRepr' whose value is
-- statically known to be strictly less than @n@ (the @i + 1 <= n@
-- constraint).
mkFin :: forall i n. (i + 1 <= n) => NatRepr i -> Fin n
mkFin = Fin
{-# INLINE mkFin #-}

-- | Internal (unexported) wrapper that shifts the bound of 'Fin' by one, so
-- that @'Fin' (n + 1)@ can be used as the type-level motive @f n@ of
-- 'natRecStrictlyBounded' in 'buildFin'.
newtype Fin' n = Fin' { getFin' :: Fin (n + 1) }

-- | Build an element of @'Fin' (m + 1)@ by repeatedly applying a step
-- function, starting from 'minFin'.
--
-- The step function is invoked for indices @n@ satisfying @n + 1 <= m@. It
-- receives @'NatRepr' n@ and the current accumulator of type
-- @'Fin' (n + 1)@, and must produce an element of @'Fin' (n + 2)@. The
-- recursion itself is delegated to 'natRecStrictlyBounded'.
buildFin ::
  forall m.
  NatRepr m ->
  (forall n. (n + 1 <= m) => NatRepr n -> Fin (n + 1) -> Fin (n + 1 + 1)) ->
  Fin (m + 1)
buildFin m f =
  getFin' (natRecStrictlyBounded m (Fin' minFin) step)
  where
    -- Adapt @f@ to the shape 'natRecStrictlyBounded' expects by unwrapping
    -- and rewrapping the shifted 'Fin''.
    step :: forall k. (k + 1 <= m) => NatRepr k -> Fin' k -> Fin' (k + 1)
    step n (Fin' fin) = Fin' (f n fin)

-- | Count all of the numbers up to @m@ that meet some condition.
--
-- Built on 'buildFin': the count starts at 'minFin' (zero). At each step
-- @n@, the predicate receives @'NatRepr' n@ and the running count. When it
-- returns 'True', the count is incremented with 'incFin'. Otherwise, the
-- same value is carried into the larger bound with 'embed'.
countFin ::
  NatRepr m ->
  (forall n. (n + 1 <= m) => NatRepr n -> Fin (n + 1) -> Bool) ->
  Fin (m + 1)
countFin m f =
  buildFin m $ \n count ->
    if f n count
      then incFin count
      else
        -- leqSucc provides (n + 1) <= (n + 1) + 1, which 'embed' needs.
        case leqSucc count of
          LeqProof -> embed count

-- | Eliminate a 'Fin' by passing its underlying 'NatRepr' to a continuation.
-- The evidence @i + 1 <= n@ is made available to the continuation.
viewFin :: (forall i. (i + 1 <= n) => NatRepr i -> r) -> Fin n -> r
viewFin k (Fin i) = k i

-- | The natural number wrapped by an element of @'Fin' n@. By construction,
-- it is strictly less than @n@.
finToNat :: Fin n -> Natural
finToNat (Fin i) = natValue i
{-# INLINABLE finToNat #-}

-- | Reinterpret an element of @'Fin' n@ as an element of @'Fin' m@, given
-- @n <= m@. The wrapped natural number is unchanged.
embed :: forall n m. (n <= m) => Fin n -> Fin m
embed =
  viewFin
    (\(x :: NatRepr o) ->
      -- o + 1 <= n and n <= m together give o + 1 <= m.
      case leqTrans (LeqProof :: LeqProof (o + 1) n) (LeqProof :: LeqProof n m) of
        LeqProof -> Fin x
    )

-- | Embed an element of @'Fin' n@ into @'Fin' m@ when 'testLeq' establishes
-- @n <= m@ at runtime; otherwise return 'Nothing'.
--
-- The check compares the bounds, not the element. An element whose value
-- would fit below @m@ is still rejected when @n > m@.
tryEmbed :: NatRepr n -> NatRepr m -> Fin n -> Maybe (Fin m)
tryEmbed n m i =
  case testLeq n m of
    Just LeqProof -> Just (embed i)
    Nothing -> Nothing

-- | The smallest element of @'Fin' n@, which wraps @0@. The @1 <= n@
-- constraint is required because @'Fin' 0@ has no elements.
minFin :: (1 <= n) => Fin n
minFin = Fin (knownNat @0)
{-# INLINABLE minFin #-}

-- | Send the element wrapping @k@ in @'Fin' n@ to the element wrapping
-- @k + 1@ in @'Fin' (n + 1)@.
incFin :: forall n. Fin n -> Fin (n + 1)
incFin (Fin (i :: NatRepr i)) =
  -- Adding 1 <= 1 to i + 1 <= n gives (i + 1) + 1 <= n + 1, the bound
  -- required for the incremented index.
  case leqAdd2 (LeqProof :: LeqProof (i + 1) n) (LeqProof :: LeqProof 1 1) of
    LeqProof -> mkFin (incNat i)

-- | @'Fin' 0@ is uninhabited, so it is isomorphic to 'Void'.
--
-- Any element would wrap some @o@ with @o + 1 <= 0@. From that, the forward
-- direction derives a @'LeqProof' 1 0@, which has no values, so the empty
-- case is exhaustive. The backward direction is 'absurd'.
fin0Void :: Iso' (Fin 0) Void
fin0Void =
  iso
    (viewFin
      (\(x :: NatRepr o) ->
        -- Rewrite o + 1 to 1 + o so that addIsLeqLeft1 applies.
        case plusComm x (knownNat @1) of
          Refl ->
            case addIsLeqLeft1 @1 @o @0 LeqProof of {}))
    absurd

-- | @'Fin' 1@ has a single element, 'minFin', so it is isomorphic to @()@.
fin1Unit :: Iso' (Fin 1) ()
fin1Unit = iso (const ()) (const minFin)

-- | @'Fin' 2@ is isomorphic to 'Bool'. The element wrapping @0@ corresponds
-- to 'False', and the element wrapping @1@ ('maxBound') corresponds to
-- 'True'.
fin2Bool :: Iso' (Fin 2) Bool
fin2Bool =
  iso
    (viewFin
      (\n ->
         case isZeroNat n of
           ZeroNat -> False
           NonZeroNat -> True))
    (\b -> if b then maxBound else minBound)