{-|
Copyright        : (c) Galois, Inc 2022

See "Data.Parameterized.FinMap".
-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

module Data.Parameterized.FinMap.Safe
  ( FinMap
  -- * Query
  , null
  , lookup
  , size
  -- * Construction
  , incMax
  , embed
  , empty
  , singleton
  , insert
  , buildFinMap
  , append
  , fromVector
  -- * Operations
  , delete
  , decMax
  , mapWithKey
  , unionWithKey
  , unionWith
  , union
  ) where

import           Prelude hiding (lookup, null)

import           Data.Foldable.WithIndex (FoldableWithIndex(ifoldMap))
import           Data.Functor.WithIndex (FunctorWithIndex(imap))
import           Data.Maybe (isJust)
import           Data.Proxy (Proxy(Proxy))
import           Data.Map (Map)
import qualified Data.Map as Map
import           GHC.TypeLits (KnownNat, Nat)

import           Data.Parameterized.Fin (Fin)
import qualified Data.Parameterized.Fin as Fin
import           Data.Parameterized.NatRepr (NatRepr, type (+), type (<=))
import qualified Data.Parameterized.NatRepr as NatRepr
import           Data.Parameterized.Vector (Vector)
import qualified Data.Parameterized.Vector as Vec

------------------------------------------------------------------------
-- Type

-- | @'FinMap' n a@ is a map with @'Fin' n@ keys and @a@ values.
data FinMap (n :: Nat) a =
  FinMap
    { getFinMap :: Map (Fin n) a
      -- ^ The key-value pairs currently stored in the map.
    , maxSize :: NatRepr n
      -- ^ Runtime representation of @n@, the maximum number of key-value
      -- pairs the map can hold (one per possible @'Fin' n@ key).
    }

-- | Two maps are equal when they hold the same key-value pairs.
instance Eq a => Eq (FinMap n a) where
  FinMap m1 _ == FinMap m2 _ = m1 == m2
  {-# INLINABLE (==) #-}

-- | Combines two maps with the left-biased 'union'.
instance Semigroup (FinMap n a) where
  fm1 <> fm2 = union fm1 fm2
  {-# INLINE (<>) #-}

-- | The identity element is the 'empty' map, whose maximum size comes from
-- the 'KnownNat' constraint.
instance KnownNat n => Monoid (FinMap n a) where
  mempty = empty
  {-# INLINE mempty #-}

-- | Maps over the stored values; keys and the maximum size are unchanged.
instance Functor (FinMap n) where
  fmap f (FinMap m sz) = FinMap (fmap f m) sz
  {-# INLINABLE fmap #-}

-- | Folds over the stored values in ascending key order (the order of the
-- underlying 'Map').
--
-- 'foldr' and 'length' delegate directly to the underlying 'Map' instead of
-- going through the default 'foldMap'-based definitions. In particular,
-- 'length' is /O(1)/ via 'Map.size' rather than a traversal of every element.
instance Foldable (FinMap n) where
  foldMap f = foldMap f . getFinMap
  {-# INLINABLE foldMap #-}

  foldr f z fm = foldr f z (getFinMap fm)
  {-# INLINABLE foldr #-}

  length = Map.size . getFinMap
  {-# INLINABLE length #-}

-- | Traverses the stored values in ascending key order; the maximum size is
-- carried over unchanged.
instance Traversable (FinMap n) where
  -- Only the values involve effects, so 'fmap' suffices to re-wrap the
  -- traversed map (no need for @<*> pure (maxSize fm)@).
  traverse f fm = rebuild <$> traverse f (getFinMap fm)
    where
      rebuild m = FinMap m (maxSize fm)
  {-# INLINABLE traverse #-}

-- | Maps over the stored values, also passing each value's key; the maximum
-- size is unchanged.
instance FunctorWithIndex (Fin n) (FinMap n) where
  imap f (FinMap m sz) = FinMap (Map.mapWithKey f m) sz
  -- Inline so that RULES for Map.mapWithKey can fire
  {-# INLINE imap #-}

-- | Folds over the key-value pairs in the order of the underlying 'Map'.
instance FoldableWithIndex (Fin n) (FinMap n) where
  ifoldMap f fm = Map.foldMapWithKey f (getFinMap fm)
  {-# INLINABLE ifoldMap #-}

-- | Non-lawful instance, provided for testing. Only the key-value pairs are
-- shown (as the underlying 'Map'); the maximum size is omitted.
instance Show a => Show (FinMap n a) where
  show = show . getFinMap
  {-# INLINABLE show #-}

------------------------------------------------------------------------
-- Query

-- | /O(1)/. Is the map empty, i.e. does it hold no key-value pairs?
null :: FinMap n a -> Bool
null fm = Map.null (getFinMap fm)
{-# INLINABLE null #-}

-- | /O(log n)/. Fetch the value at the given key in the map, or 'Nothing' if
-- no value is stored there.
lookup :: Fin n -> FinMap n a -> Maybe a
lookup k fm = Map.lookup k (getFinMap fm)
{-# INLINABLE lookup #-}

-- | /O(n log n)/. Number of elements in the map.
--
-- Here @n@ is the maximum size ('maxSize'): each candidate key below it is
-- looked up in the map, and the hits are counted.
--
-- This operation is much slower than 'Data.Parameterized.FinMap.Unsafe.size'
-- because its implementation must provide significant evidence to the
-- type-checker, and the easiest way to do that is fairly inefficient.
-- If speed is a concern, use "Data.Parameterized.FinMap.Unsafe".
size :: forall n a. FinMap n a -> Fin (n + 1)
size fm = Fin.countFin (maxSize fm) isKey
  where
    -- Is the candidate key @k@ present in the map? The running count that
    -- 'Fin.countFin' also supplies is not needed.
    isKey :: forall k. (k + 1 <= n) => NatRepr k -> Fin (k + 1) -> Bool
    isKey k _ = isJust (lookup (Fin.mkFin k) fm)

------------------------------------------------------------------------
-- Construction

-- | /O(n log n)/. Increase maximum key/size by 1.
--
-- This does not alter the key-value pairs in the map, but rather increases the
-- maximum number of key-value pairs that the map can hold. See
-- "Data.Parameterized.FinMap" for more information.
--
-- Requires @n + 1 < (maxBound :: Int)@.
incMax :: forall n a. FinMap n a -> FinMap (n + 1) a
incMax fm =
  case succProof of
    NatRepr.LeqProof -> embed (NatRepr.incNat (maxSize fm)) fm
  where
    -- Evidence that @n <= n + 1@, which 'embed' requires.
    succProof :: NatRepr.LeqProof n (n + 1)
    succProof = NatRepr.leqSucc (Proxy @n)

-- | /O(n log n)/. Increase maximum key/size.
--
-- Every key is re-embedded into the larger key type with 'Fin.embed'; the
-- stored values are left untouched.
--
-- Requires @m < (maxBound :: Int)@.
embed :: (n <= m) => NatRepr m -> FinMap n a -> FinMap m a
embed m fm = FinMap widened m
  where
    widened = Map.mapKeys Fin.embed (getFinMap fm)

-- | /O(1)/. The empty map. Its maximum size is taken from the 'KnownNat'
-- constraint.
empty :: KnownNat n => FinMap n a
empty = FinMap { getFinMap = Map.empty, maxSize = NatRepr.knownNat }
{-# INLINABLE empty #-}

-- | /O(1)/. A map with one element, stored at key @0@, and maximum size @1@.
singleton :: a -> FinMap 1 a
singleton item = FinMap (Map.singleton key0 item) NatRepr.knownNat
  where
    key0 :: Fin 1
    key0 = Fin.mkFin (NatRepr.knownNat :: NatRepr 0)

-- | /O(log n)/. Store a value at the given key, replacing any value already
-- stored there. The maximum size is unchanged.
insert :: Fin n -> a -> FinMap n a -> FinMap n a
insert k v (FinMap m sz) = FinMap (Map.insert k v m) sz
{-# INLINABLE insert #-}

-- buildFinMap, append, and fromVector are duplicated exactly between the safe
-- and unsafe modules because they are used in comparative testing (and so
-- implementations must be available for both types).

-- | 'FinMap' with its type parameters flipped so that the size index comes
-- last. This lets it serve as the @f :: Nat -> Type@ argument of
-- 'NatRepr.natRecStrictlyBounded' in 'buildFinMap'.
newtype FinMap' a (n :: Nat) =
  FinMap' { unFinMap' :: FinMap n a }

-- | Build a @'FinMap' m a@ by induction on the size, starting from 'empty'.
--
-- The step function is applied for each @n@ with @n + 1 <= m@, going from
-- size @0@ up to size @m@. Each application turns the map of maximum size @n@
-- into one of maximum size @n + 1@ (e.g. via 'append' or 'incMax', as
-- 'fromVector' does).
buildFinMap ::
  forall m a.
  NatRepr m ->
  (forall n. (n + 1 <= m) => NatRepr n -> FinMap n a -> FinMap (n + 1) a) ->
  FinMap m a
buildFinMap m f =
  -- 'FinMap'' puts the size index last so it fits natRecStrictlyBounded's
  -- @f n@ shape; wrap/unwrap around each step.
  let f' :: forall k. (k + 1 <= m) => NatRepr k -> FinMap' a k -> FinMap' a (k + 1)
      f' = (\n (FinMap' fin) -> FinMap' (f n fin))
  in unFinMap' (NatRepr.natRecStrictlyBounded m (FinMap' empty) f')

-- | /O(n log n)/. Add the value @v@ at the new largest key @n@, increasing the
-- maximum key/size by one.
--
-- The cost is dominated by 'incMax', which re-keys every existing entry; the
-- final 'insert' is only /O(log n)/.
append :: NatRepr n -> a -> FinMap n a -> FinMap (n + 1) a
append k v fm =
  case NatRepr.leqSucc k of
    NatRepr.LeqProof -> insert (Fin.mkFin k) v (incMax fm)

-- | Build a map from a vector of optional values. The entry at key @i@ is
-- present (with value @x@) exactly when the vector holds @'Just' x@ at index
-- @i@. The maximum size of the result is the length of the vector.
--
-- Each of the @n@ steps of 'buildFinMap' calls 'append' or 'incMax', both
-- /O(k log k)/ on a map of size @k@. The whole construction is therefore at
-- most /O(n^2 log n)/.
fromVector :: forall n a. Vector n (Maybe a) -> FinMap n a
fromVector v =
  buildFinMap
    (Vec.length v)
    (\k m ->
      case Vec.elemAt k v of
        Just e -> append k e m
        Nothing -> incMax m)

------------------------------------------------------------------------
-- Operations

-- | /O(log n)/. Remove the value stored at the given key, if any. The maximum
-- size is unchanged.
delete :: Fin n -> FinMap n a -> FinMap n a
delete k (FinMap m sz) = FinMap (Map.delete k m) sz
{-# INLINABLE delete #-}

-- | /O(n log n)/. Decrement the maximum key/size, removing the item at key @n@
-- if present.
--
-- The keys of a @'FinMap' (n + 1) a@ are @0 .. n@. Every key that fits in
-- @'Fin' n@ is re-embedded with 'Fin.tryEmbed'. The entry at key @n@ (the
-- largest key, which has no counterpart in @'Fin' n@) is dropped.
decMax :: NatRepr n -> FinMap (n + 1) a -> FinMap n a
decMax n fm =
  FinMap
    { getFinMap = maybeMapKeys (Fin.tryEmbed sz n) (getFinMap fm)
    , maxSize = n
    }
  where
    sz = maxSize fm

    -- Like 'Map.mapKeys', but drops every entry whose key maps to 'Nothing'.
    maybeMapKeys :: Ord k2 => (k1 -> Maybe k2) -> Map k1 a -> Map k2 a
    maybeMapKeys f m =
      Map.foldrWithKey
        (\k v accum ->
           case f k of
             Just k' -> Map.insert k' v accum
             Nothing -> accum)
        Map.empty
        m

-- | /O(n)/. Apply a function to every stored value, also passing it the value's
-- key. The maximum size is unchanged.
mapWithKey :: (Fin n -> a -> b) -> FinMap n a -> FinMap n b
mapWithKey f (FinMap m sz) = FinMap (Map.mapWithKey f m) sz
-- Inline so that RULES for Map.mapWithKey can fire
{-# INLINE mapWithKey #-}

-- | /O(n+m)/. Union of two maps. For keys present in both, the values are
-- combined with the supplied function, which also receives the key.
unionWithKey :: (Fin n -> a -> a -> a) -> FinMap n a -> FinMap n a -> FinMap n a
unionWithKey f fm1 fm2 = FinMap merged (maxSize fm1)
  where
    -- Both arguments have the same size index @n@; the left map's size is
    -- kept.
    merged = Map.unionWithKey f (getFinMap fm1) (getFinMap fm2)

-- | /O(n+m)/. Union of two maps. For keys present in both, the values are
-- combined with the supplied function.
unionWith :: (a -> a -> a) -> FinMap n a -> FinMap n a -> FinMap n a
unionWith f = unionWithKey ignoreKey
  where
    ignoreKey _ v1 v2 = f v1 v2

-- | /O(n+m)/. Left-biased union, i.e. (@'union' == 'unionWith' 'const'@). For
-- keys present in both maps, the value from the first map is kept.
union :: FinMap n a -> FinMap n a -> FinMap n a
union fm1 fm2 = unionWith (\l _ -> l) fm1 fm2