{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language ExplicitNamespaces #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language ScopedTypeVariables #-}
{-# language TypeApplications #-}
{-# language TypeOperators #-}
module Arithmetic.Fin
(
incrementL
, incrementR
, weaken
, weakenL
, weakenR
, ascend
, ascend'
, ascendM
, ascendM_
, descend
, descend'
, descendM
, descendM_
, ascending
, descending
, ascendingSlice
, descendingSlice
, absurd
, demote
) where
import Prelude hiding (last)
import Arithmetic.Nat ((<?))
import Arithmetic.Types (Fin(..),Difference(..),Nat,type (<), type (<=), type (:=:))
import GHC.TypeNats (type (+))
import qualified Arithmetic.Lt as Lt
import qualified Arithmetic.Lte as Lte
import qualified Arithmetic.Equal as Eq
import qualified Arithmetic.Nat as Nat
import qualified Arithmetic.Plus as Plus
-- | Shift an index up by @m@ while growing the bound by @m@ on the
-- right: index @i@ of @Fin n@ becomes index @i + m@ of @Fin (n + m)@.
incrementR :: forall n m. Nat m -> Fin n -> Fin (n + m)
incrementR m fin = case fin of
  Fin ix ixLtN -> Fin (Nat.plus ix m) (Lt.incrementR @m ixLtN)
-- | Shift an index up by @m@ while growing the bound by @m@ on the
-- left: index @i@ of @Fin n@ becomes index @m + i@ of @Fin (m + n)@.
incrementL :: forall n m. Nat m -> Fin n -> Fin (m + n)
incrementL m fin = case fin of
  Fin ix ixLtN -> Fin (Nat.plus m ix) (Lt.incrementL @m ixLtN)
-- | Keep the index unchanged but widen its bound from @n@ to @m + n@.
weakenL :: forall n m. Fin n -> Fin (m + n)
weakenL fin = case fin of
  Fin ix ixLtN ->
    -- Add @0 <= m@ to get @i + 0 < n + m@, then commute the bound to @m + n@.
    let ixLtSum = Lt.substituteR
          (Plus.commutative @n @m)
          (Lt.plus ixLtN (Lte.zero @m))
     in Fin ix ixLtSum
-- | Keep the index unchanged but widen its bound from @n@ to @n + m@.
weakenR :: forall n m. Fin n -> Fin (n + m)
weakenR fin = case fin of
  Fin ix ixLtN -> Fin ix (Lt.plus ixLtN Lte.zero)
-- | Keep the index unchanged but relax its bound from @n@ to any @m@
-- for which @n <= m@ has been proven.
weaken :: forall n m. (n <= m) -> Fin n -> Fin m
weaken nLteM fin = case fin of
  Fin ix ixLtN -> Fin ix (Lt.transitiveNonstrictR ixLtN nLteM)
-- | A @Fin 0@ would need an index strictly below zero, which cannot
-- exist, so it can be eliminated into any result type.
absurd :: Fin 0 -> void
absurd fin = case fin of
  Fin _ ixLtZero -> Lt.absurd ixLtZero
-- | Lazy right fold over the indices @0 .. n-1@. The outermost
-- application of @f@ receives @0@, and the initial accumulator is
-- combined with @n-1@ first. The accumulator therefore sees the indices
-- in descending order. The parameter order differs from @foldr@.
--
-- > descend 3 z f = f 0 (f 1 (f 2 z))
descend :: forall a n.
     Nat n -- ^ Exclusive upper bound
  -> a -- ^ Initial accumulator (not forced)
  -> (Fin n -> a -> a) -- ^ Combine an index with the accumulator
  -> a
{-# inline descend #-}
descend !n b0 f = loop Nat.zero
  where
  -- Count upward from zero. The recursive call is passed as @f@'s second
  -- argument, so how far the recursion goes depends on how @f@ uses it.
  loop :: Nat k -> a
  loop !k = maybe b0 (\kLtN -> f (Fin k kLtN) (loop (Nat.succ k))) (k <? n)
-- | Strict counterpart of 'descend'. The accumulator is forced to WHNF
-- at every step, and indices are handed to @f@ from @n-1@ down to @0@.
--
-- > descend' 3 z f = f 0 (f 1 (f 2 z))
descend' :: forall a n.
     Nat n -- ^ Exclusive upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> a) -- ^ Combine an index with the accumulator
  -> a
{-# inline descend' #-}
descend' !n !b0 f = loop n Lte.reflexive b0
  where
  -- @remaining@ counts down from @n@. The proof @remaining <= n@ is what
  -- lets us bound the index @remaining - 1@ strictly below @n@.
  loop :: Nat p -> p <= n -> a -> a
  loop !remaining remainingLteN !acc = case Nat.monus remaining Nat.one of
    Just (Difference below succBelowEqRemaining) ->
      let !belowLtN = descendLemma succBelowEqRemaining remainingLteN
       in loop below (Lte.fromStrict belowLtN) (f (Fin below belowLtN) acc)
    Nothing -> acc
-- | Lazy fold over the indices @0 .. n-1@ that feeds them to the
-- accumulator in ascending order. The initial accumulator is combined
-- with @0@ first, and the outermost application of @f@ receives @n-1@.
--
-- > ascend 3 z f = f 2 (f 1 (f 0 z))
--
-- Like 'descend', this is lazy in the accumulator. The initial value is
-- not forced either, so @f@ may ignore it: for @n > 0@,
-- @ascend n undefined (\\_ _ -> r) = r@. Use 'ascend'' for a strict fold.
ascend :: forall a n.
     Nat n -- ^ Exclusive upper bound
  -> a -- ^ Initial accumulator (not forced)
  -> (Fin n -> a -> a) -- ^ Combine an index with the accumulator
  -> a
{-# inline ascend #-}
-- No bang on b0: a lazy fold must not force an accumulator that @f@ may
-- never demand. This matches 'descend'.
ascend !n b0 f = go n Lte.reflexive
  where
  -- Count down from @n@. The outermost call handles the largest index,
  -- and the base case returns @b0@ untouched.
  go :: Nat p -> (p <= n) -> a
  go !m pLteEn = case Nat.monus m Nat.one of
    Nothing -> b0
    Just (Difference (mpred :: Nat c) cPlusOneEqP) ->
      let !cLtEn = descendLemma cPlusOneEqP pLteEn
       in f (Fin mpred cLtEn) (go mpred (Lte.fromStrict cLtEn))
-- | Strict left fold over the indices @0 .. n-1@ in ascending order.
-- The accumulator is forced to WHNF at every step. The parameter order
-- differs from @foldl'@.
--
-- > ascend' 3 z f = f 2 (f 1 (f 0 z))
ascend' :: forall a n.
     Nat n -- ^ Exclusive upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> a) -- ^ Combine an index with the accumulator
  -> a
{-# inline ascend' #-}
ascend' !n !b0 f = loop Nat.zero b0
  where
  loop :: Nat k -> a -> a
  loop !k !acc =
    maybe acc (\kLtN -> loop (Nat.succ k) (f (Fin k kLtN) acc)) (k <? n)
-- | Monadic left fold over the indices @0 .. n-1@ in ascending order.
-- Successive steps are chained with '>>='.
ascendM :: forall m a n. Monad m
  => Nat n -- ^ Exclusive upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> m a) -- ^ Effectful step
  -> m a
{-# inline ascendM #-}
ascendM !n !b0 f = loop Nat.zero b0
  where
  loop :: Nat k -> a -> m a
  loop !k !acc = case k <? n of
    Just kLtN -> f (Fin k kLtN) acc >>= loop (Nat.succ k)
    Nothing -> pure acc
-- | Run an action at every index @0 .. n-1@ in ascending order and
-- discard the results.
ascendM_ :: forall m a n. Applicative m
  => Nat n -- ^ Exclusive upper bound
  -> (Fin n -> m a) -- ^ Action to run at each index
  -> m ()
{-# inline ascendM_ #-}
ascendM_ !n f = loop Nat.zero
  where
  loop :: Nat k -> m ()
  loop !k = case k <? n of
    Just kLtN -> f (Fin k kLtN) *> loop (Nat.succ k)
    Nothing -> pure ()
-- | From @a + 1 = b@ and @b <= c@, conclude @a < c@. The traversals that
-- count down from @n@ use this to bound the index one below the
-- current position.
descendLemma :: forall a b c. a + 1 :=: b -> b <= c -> a < c
{-# inline descendLemma #-}
descendLemma !succAEqB !bLteC =
  Lt.transitiveNonstrictR
    -- a < a + 1: add a to both sides of Lt.zero, then commute 1 + a.
    (Lt.substituteR (Plus.commutative @1 @a) (Lt.plus Lt.zero Lte.reflexive))
    -- a + 1 <= c: rewrite b in the hypothesis using a + 1 = b.
    (Lte.substituteL (Eq.symmetric succAEqB) bLteC)
-- | Monadic fold that visits the indices from @n-1@ down to @0@.
-- Successive steps are chained with '>>='.
descendM :: forall m a n. Monad m
  => Nat n -- ^ Exclusive upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> m a) -- ^ Effectful step
  -> m a
{-# inline descendM #-}
descendM !n !b0 f = loop n Lte.reflexive b0
  where
  loop :: Nat p -> p <= n -> a -> m a
  loop !remaining remainingLteN !acc = case Nat.monus remaining Nat.one of
    Just (Difference below succBelowEqRemaining) ->
      let !belowLtN = descendLemma succBelowEqRemaining remainingLteN
       in f (Fin below belowLtN) acc >>= loop below (Lte.fromStrict belowLtN)
    Nothing -> pure acc
-- | Run an action at every index from @n-1@ down to @0@ and discard
-- the results.
descendM_ :: forall m a n. Applicative m
  => Nat n -- ^ Exclusive upper bound
  -> (Fin n -> m a) -- ^ Action to run at each index
  -> m ()
{-# inline descendM_ #-}
descendM_ !n f = loop n Lte.reflexive
  where
  loop :: Nat p -> p <= n -> m ()
  loop !remaining !remainingLteN = case Nat.monus remaining Nat.one of
    Just (Difference below succBelowEqRemaining) ->
      let !belowLtN = descendLemma succBelowEqRemaining remainingLteN
       in f (Fin below belowLtN) *> loop below (Lte.fromStrict belowLtN)
    Nothing -> pure ()
-- | Every index of @Fin n@, from @0@ up to @n-1@.
ascending :: forall n. Nat n -> [Fin n]
ascending !n = unfold Nat.zero
  where
  unfold :: Nat k -> [Fin n]
  unfold !k = maybe [] (\kLtN -> Fin k kLtN : unfold (Nat.succ k)) (k <? n)
-- | Every index of @Fin n@, from @n-1@ down to @0@.
descending :: forall n. Nat n -> [Fin n]
descending !n = unfold n Lte.reflexive
  where
  unfold :: Nat p -> (p <= n) -> [Fin n]
  unfold !remaining !remainingLteN = case Nat.monus remaining Nat.one of
    Just (Difference below succBelowEqRemaining) ->
      let !belowLtN = descendLemma succBelowEqRemaining remainingLteN
       in Fin below belowLtN : unfold below (Lte.fromStrict belowLtN)
    Nothing -> []
-- | The indices @off .. off + len - 1@ of a @Fin n@, in ascending
-- order. The proof @off + len <= n@ is what puts each element in bounds.
ascendingSlice
  :: forall n off len
   . Nat off -- ^ First index of the slice
  -> Nat len -- ^ Number of indices in the slice
  -> off + len <= n -- ^ The slice fits below @n@
  -> [Fin n]
{-# inline ascendingSlice #-}
ascendingSlice off len !offPlusLenLteN = unfold Nat.zero
  where
  -- k is the position within the slice; off + k is the emitted index.
  unfold :: Nat k -> [Fin n]
  unfold !k = case k <? len of
    Just kLtLen ->
      let !offPlusKLtOffPlusLen = Lt.incrementL @off kLtLen
          !offPlusKLtN = Lt.transitiveNonstrictR offPlusKLtOffPlusLen offPlusLenLteN
       in Fin (Nat.plus off k) offPlusKLtN : unfold (Nat.succ k)
    Nothing -> []
-- | The indices @off .. off + len - 1@ of a @Fin n@, in descending
-- order (@off + len - 1@ first, @off@ last). The proof
-- @off + len <= n@ is what puts each element in bounds.
descendingSlice
  :: forall n off len
   . Nat off -- ^ Lowest index of the slice
  -> Nat len -- ^ Number of indices in the slice
  -> off + len <= n -- ^ The slice fits below @n@
  -> [Fin n]
{-# inline descendingSlice #-}
descendingSlice !off !len !offPlusLenLteN =
  go len Lte.reflexive
  where
  -- m counts down from len; each step emits (m - 1) + off.
  go :: Nat m -> m <= len -> [Fin n]
  go !m !mLteLen = case Nat.monus m Nat.one of
    Nothing -> []
    Just (Difference (mpred :: Nat c) cPlusOneEqM) ->
      -- c + 1 = m and m <= len give c < len. Reuse the shared lemma
      -- rather than re-deriving the same proof inline.
      let !cLtLen = descendLemma cPlusOneEqM mLteLen
          -- c + off < len + off, commuted to off + len, then bounded by n.
          !cPlusOffLtN = Lt.transitiveNonstrictR
            (Lt.substituteR
              (Plus.commutative @len @off)
              (Lt.plus cLtLen (Lte.reflexive @off)))
            offPlusLenLteN
       in Fin (mpred `Nat.plus` off) cPlusOffLtN : go mpred (Lte.fromStrict cLtLen)
-- | Discard the type-level bound and return the index as an 'Int'.
demote :: Fin n -> Int
demote fin = case fin of
  Fin ix _ -> Nat.demote ix