{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language ExplicitNamespaces #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language ScopedTypeVariables #-}
{-# language TypeApplications #-}
{-# language TypeOperators #-}
module Arithmetic.Fin
(
incrementL
, incrementR
, weakenL
, weakenR
, ascend
, ascendM
, ascendM_
, ascending
, descending
, ascendingSlice
, absurd
, demote
) where
import Prelude hiding (last)
import Arithmetic.Nat ((<?))
import Arithmetic.Types (Fin(..),Difference(..),Nat,type (<), type (<=), type (:=:))
import GHC.TypeNats (type (+))
import qualified Arithmetic.Lt as Lt
import qualified Arithmetic.Lte as Lte
import qualified Arithmetic.Nat as Nat
import qualified Arithmetic.Plus as Plus
-- | Shift an index up by @m@ while growing its bound by the same amount,
-- with the offset added on the right: an index @i@ with @i < n@ becomes
-- @i + m@ with @i + m < n + m@.
incrementR :: forall n m. Nat m -> Fin n -> Fin (n + m)
incrementR delta fin = case fin of
  Fin ix ixLtN -> Fin (Nat.plus ix delta) (Lt.incrementR @m ixLtN)
-- | Shift an index up by @m@ while growing its bound by the same amount,
-- with the offset added on the left: an index @i@ with @i < n@ becomes
-- @m + i@ with @m + i < m + n@.
incrementL :: forall n m. Nat m -> Fin n -> Fin (m + n)
incrementL delta fin = case fin of
  Fin ix ixLtN -> Fin (Nat.plus delta ix) (Lt.incrementL @m ixLtN)
-- | Relax the bound of an index from @n@ to @m + n@. The index value
-- itself is unchanged; only the proof that it is in range is rebuilt.
weakenL :: forall n m. Fin n -> Fin (m + n)
weakenL (Fin ix ixLtN) =
  -- Combine the existing bound with 0 <= m to get a bound of n + m,
  -- then commute it into the m + n form demanded by the result type.
  Fin ix
    $ Lt.substituteR (Plus.commutative @n @m)
    $ Lt.plus ixLtN (Lte.zero @m)
-- | Relax the bound of an index from @n@ to @n + m@. The index value
-- itself is unchanged; only the proof that it is in range is rebuilt.
weakenR :: forall n m. Fin n -> Fin (n + m)
weakenR fin = case fin of
  -- Combine the existing bound with 0 <= m to extend it by m.
  Fin ix ixLtN -> Fin ix (Lt.plus ixLtN (Lte.zero @m))
-- | An empty range has no indices. A @Fin 0@ would carry a proof that its
-- index is @< 0@, which 'Lt.absurd' eliminates into any result type.
absurd :: Fin 0 -> void
absurd fin = case fin of
  Fin _ impossible -> Lt.absurd impossible
-- | Strict left fold over the indices @0, 1, ..., n - 1@ in ascending
-- order. The accumulator is forced to weak head normal form before each
-- step. When @n@ is zero the (forced) initial accumulator is returned.
--
-- > ascend n z f = f (n - 1) (... (f 1 (f 0 z)))
ascend :: forall a n.
     Nat n
  -> a
  -> (Fin n -> a -> a)
  -> a
{-# inline ascend #-}
ascend !n !z f = step Nat.zero z
  where
    step :: Nat k -> a -> a
    step !ix !acc = case ix <? n of
      Just ixLtN -> step (Nat.succ ix) (f (Fin ix ixLtN) acc)
      Nothing -> acc
-- | Monadic left fold over the indices @0, 1, ..., n - 1@ in ascending
-- order. Each step's incoming accumulator is forced to weak head normal
-- form. When @n@ is zero this is @pure z@.
--
-- > ascendM n z f = f 0 z >>= f 1 >>= ... >>= f (n - 1)   -- modulo forcing
ascendM :: forall m a n. Monad m
  => Nat n
  -> a
  -> (Fin n -> a -> m a)
  -> m a
{-# inline ascendM #-}
ascendM !n !z f = step Nat.zero z
  where
    step :: Nat p -> a -> m a
    step !ix !acc = case ix <? n of
      Just ixLtN -> do
        acc' <- f (Fin ix ixLtN) acc
        step (Nat.succ ix) acc'
      Nothing -> pure acc
-- | Run an action for every index @0, 1, ..., n - 1@ in ascending order,
-- discarding each result. Only 'Applicative' is required because no
-- action depends on the result of an earlier one.
ascendM_ :: forall m a n. Applicative m
  => Nat n
  -> (Fin n -> m a)
  -> m ()
{-# inline ascendM_ #-}
ascendM_ !n f = step Nat.zero
  where
    step :: Nat p -> m ()
    step !ix = maybe
      (pure ())
      (\ixLtN -> f (Fin ix ixLtN) *> step (Nat.succ ix))
      (ix <? n)
-- | Every index below @n@ in ascending order: @[0, 1, ..., n - 1]@.
-- The tail is built lazily, and the list is empty when @n@ is zero.
ascending :: forall n. Nat n -> [Fin n]
ascending !n = upFrom Nat.zero
  where
    upFrom :: Nat k -> [Fin n]
    upFrom !ix = case ix <? n of
      Just ixLtN -> Fin ix ixLtN : upFrom (Nat.succ ix)
      Nothing -> []
-- | Every index below @n@ in descending order: @[n - 1, n - 2, ..., 0]@.
-- The list is empty when @n@ is zero.
descending :: forall n. Nat n -> [Fin n]
descending n = go n Lte.reflexive
  where
  -- Invariant: @m <= n@. While @m@ is nonzero, emit its predecessor and
  -- continue counting down from that predecessor.
  go :: forall m. Nat m -> (m <= n) -> [Fin n]
  go !m !lt = case Nat.monus m Nat.one of
    -- m - 1 does not exist, i.e. m is zero: nothing left to emit.
    Nothing -> []
    -- mpred + 1 = m, witnessed by eq.
    Just (Difference mpred eq) -> go2 lt mpred eq
  -- Emit @c@ (the predecessor of @m@) and recurse on it.
  go2 :: forall m c. (m <= n) -> Nat c -> (c + 1 :=: m) -> [Fin n]
  go2 !lt !c !eq =
    -- Proof that c < m. Add c to both sides of Lt.zero (presumably
    -- 0 < 1) to get c + 0 < c + 1. Then rewrite the left side to c and
    -- the right side to m using eq.
    -- With c < m and m <= n we get c < n, so c is a valid index. The
    -- recursive call needs c <= n, obtained as c <= c + 1 = m <= n.
    let ceeLtEm :: c < m
        ceeLtEm = id
          $ Lt.substituteR eq
          $ Lt.substituteL Plus.zeroL
          $ Lt.incrementL @c Lt.zero
     in Fin c (Lt.transitiveNonstrictR ceeLtEm lt) : go c
          (Lte.transitive (Lte.substituteR eq (Lte.weakenR @1 (Lte.reflexive @c))) lt)
-- | The @len@ consecutive indices starting at @off@, in ascending order:
-- @[off, off + 1, ..., off + len - 1]@. The list is empty when @len@ is
-- zero, and its tail is built lazily.
--
-- NOTE(review): the precondition @off + len < n@ is stricter than the
-- produced indices require (@off + len <= n@ would suffice). As a result,
-- the final index @n - 1@ can never appear in the output. Relaxing it
-- would change the proof type callers must supply, so it is kept as is.
ascendingSlice :: forall n off len.
     Nat off
  -> Nat len
  -> (off + len < n)
  -> [Fin n]
ascendingSlice off len !bound = walk Nat.zero
  where
    walk :: Nat k -> [Fin n]
    walk !k = case k <? len of
      Just kLtLen ->
        -- off + k < off + len, and off + len < n, so off + k < n.
        let !shifted = Lt.incrementL @off kLtLen
            !inBounds = Lt.transitive shifted bound
         in Fin (Nat.plus off k) inBounds : walk (Nat.succ k)
      Nothing -> []
-- | Drop the bound and return the index value as an 'Int'.
demote :: Fin n -> Int
demote fin = case fin of
  Fin ix _ -> Nat.demote ix