{-|
Copyright        : (c) Galois, Inc 2022

See "Data.Parameterized.FinMap".
-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

module Data.Parameterized.FinMap.Unsafe
  ( FinMap
  -- * Query
  , null
  , lookup
  , size
  -- * Construction
  , incMax
  , embed
  , empty
  , singleton
  , insert
  , buildFinMap
  , append
  , fromVector
  -- * Operations
  , delete
  , decMax
  , mapWithKey
  , unionWithKey
  , unionWith
  , union
  ) where

import           Prelude hiding (lookup, null)

import           Data.Functor.WithIndex (FunctorWithIndex(imap))
import           Data.Foldable.WithIndex (FoldableWithIndex(ifoldMap))
import           Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import           GHC.TypeLits (KnownNat, Nat)
import           Numeric.Natural (Natural)
import           Unsafe.Coerce (unsafeCoerce)

import           Data.Parameterized.Fin (Fin, mkFin)
import qualified Data.Parameterized.Fin as Fin
import           Data.Parameterized.NatRepr (LeqProof, NatRepr, type (+), type (<=))
import qualified Data.Parameterized.NatRepr as NatRepr
import           Data.Parameterized.Some (Some(Some))
import           Data.Parameterized.Vector (Vector)
import qualified Data.Parameterized.Vector as Vec

-- This is pulled out as a function so that it's obvious that its use cannot
-- overflow: 'Natural' is unbounded, so every non-negative 'Int' fits. It is
-- NOT total, though: converting a negative 'Int' to 'Natural' throws an
-- arithmetic underflow exception. Callers are expected to pass only keys or
-- sizes taken from a 'FinMap', which are never negative.
intToNat :: Int -> Natural
intToNat = fromIntegral
{-# INLINE intToNat #-}

-- These are pulled out as functions so that it's obvious that their use is
-- unsafe: 'Natural' is unbounded but 'Int' is not, so a value that does not
-- fit in an 'Int' is not converted faithfully.

-- | The 'Int' encoding of a 'Fin' index, used as the 'IntMap' key.
-- Only meaningful when the index fits in an 'Int' (a 'FinMap' invariant).
unsafeFinToInt :: Fin n -> Int
unsafeFinToInt i = fromIntegral (Fin.finToNat i)
{-# INLINE unsafeFinToInt #-}

-- | The 'Int' value of a 'NatRepr'. Only meaningful when that value fits in
-- an 'Int'.
unsafeNatReprToInt :: NatRepr n -> Int
unsafeNatReprToInt repr = fromIntegral (NatRepr.natValue repr)
{-# INLINE unsafeNatReprToInt #-}

------------------------------------------------------------------------
-- Type

-- This datatype has two important invariants:
--
-- * Its keys must be strictly less than the nat in its type.
-- * Its size must be less than the maximum Int.
--
-- If these invariants hold, all of the unsafe operations in this module
-- (the unsafeCoerce in unsafeFin, and the Natural/Int conversions) will work
-- as intended.

-- | @'FinMap' n a@ is a map with @'Fin' n@ keys and @a@ values.
newtype FinMap (n :: Nat) a =
  FinMap
    { getFinMap :: IntMap a
      -- ^ Underlying map, keyed by the 'Int' encoding of each @'Fin' n@ key.
    }

-- | Two maps are equal when their underlying 'IntMap's are equal.
instance Eq a => Eq (FinMap n a) where
  FinMap lhs == FinMap rhs = lhs == rhs
  {-# INLINABLE (==) #-}

-- | Combines maps with 'union', so on overlapping keys the left value wins.
instance Semigroup (FinMap n a) where
  x <> y = union x y
  {-# INLINE (<>) #-}

-- | The identity element is the map with no entries.
instance KnownNat n => Monoid (FinMap n a) where
  mempty = FinMap IntMap.empty
  {-# INLINE mempty #-}

-- | Maps over the values; keys are untouched.
instance Functor (FinMap n) where
  fmap f (FinMap entries) = FinMap (IntMap.map f entries)
  {-# INLINABLE fmap #-}

-- | Folds over the values in ascending key order, as the underlying 'IntMap'
-- does.
instance Foldable (FinMap n) where
  foldMap f (FinMap entries) = foldMap f entries
  {-# INLINABLE foldMap #-}

-- | Traverses the values of the underlying 'IntMap', rewrapping the result.
instance Traversable (FinMap n) where
  traverse f (FinMap entries) = fmap FinMap (traverse f entries)

-- | Like 'fmap', but the function also receives each entry's @'Fin' n@ key.
instance FunctorWithIndex (Fin n) (FinMap n) where
  imap f = FinMap . IntMap.mapWithKey withFinKey . getFinMap
    where
      -- Keys of a FinMap are always valid @Fin n@ indices (type invariant).
      withFinKey k = f (unsafeFin k)
  -- Inline so that RULES for IntMap.mapWithKey can fire
  {-# INLINE imap #-}

-- | Like 'foldMap', but the function also receives each entry's @'Fin' n@ key.
instance FoldableWithIndex (Fin n) (FinMap n) where
  ifoldMap f (FinMap entries) =
    IntMap.foldMapWithKey (\k -> f (unsafeFin k)) entries

-- | Non-lawful instance, provided for testing: it renders the underlying
-- 'IntMap' (with 'Int' keys), not a Haskell expression of type 'FinMap'.
instance Show a => Show (FinMap n a) where
  show (FinMap entries) = show entries
  {-# INLINABLE show #-}

------------------------------------------------------------------------
-- Query

-- | /O(1)/. Is the map empty, i.e. does it hold no key-value pairs?
null :: FinMap n a -> Bool
null (FinMap entries) = IntMap.null entries
{-# INLINABLE null #-}

-- | /O(min(n,W))/. Fetch the value at the given key in the map, or 'Nothing'
-- if the key has no entry.
lookup :: Fin n -> FinMap n a -> Maybe a
lookup k (FinMap entries) = IntMap.lookup (unsafeFinToInt k) entries
{-# INLINABLE lookup #-}

-- | Unsafely create a @'Fin' n@ from an 'Int' which is known to be less than
-- @n@ for reasons not visible to the type system.
--
-- The input must also be non-negative (see 'intToNat'). The bound
-- @m + 1 <= n@ needed by 'mkFin' is manufactured with 'unsafeCoerce' rather
-- than checked, so an out-of-range input yields an ill-typed 'Fin'.
unsafeFin :: forall n. Int -> Fin n
unsafeFin i =
  case NatRepr.mkNatRepr (intToNat i) of
    Some (repr :: NatRepr m) ->
      let assumedBound :: LeqProof (m + 1) n
          assumedBound = unsafeCoerce (NatRepr.LeqProof :: LeqProof 0 0)
      in case assumedBound of
           NatRepr.LeqProof -> mkFin @m @n repr

-- | /O(n)/. Number of elements in the map.
--
-- This is linear, not constant, because the underlying 'IntMap' does not
-- cache its size. Every key is a distinct @'Fin' n@, so there are at most
-- @n@ entries, and the count always fits in @'Fin' (n + 1)@.
size :: forall n a. FinMap n a -> Fin (n + 1)
size = unsafeFin . IntMap.size . getFinMap
{-# INLINABLE size #-}

------------------------------------------------------------------------
-- Construction

-- | /O(1)/. Increase maximum key/size by 1.
--
-- The key-value pairs are left exactly as they are; only the type-level bound
-- on how many pairs the map may hold grows. See "Data.Parameterized.FinMap"
-- for more information.
--
-- Requires @n + 1 < (maxBound :: Int)@.
incMax :: FinMap n a -> FinMap (n + 1) a
incMax (FinMap entries) = FinMap entries
{-# INLINE incMax #-}

-- | /O(1)/. Increase maximum key/size to @m@.
--
-- The 'NatRepr' argument only fixes the target bound; its value is not
-- inspected. The entries are carried over unchanged.
--
-- Requires @m < (maxBound :: Int)@.
embed :: (n <= m) => NatRepr m -> FinMap n a -> FinMap m a
embed _ (FinMap entries) = FinMap entries
{-# INLINE embed #-}

-- | /O(1)/. The map with no entries.
--
-- NOTE(review): the 'KnownNat' constraint is not used by this implementation;
-- presumably it is kept so the signature matches the safe FinMap module.
-- Confirm before removing.
empty :: KnownNat n => FinMap n a
empty = FinMap mempty
{-# INLINE empty #-}

-- | /O(1)/. A map with one element, stored at key @0@ (the only valid key of
-- a @'FinMap' 1@).
singleton :: a -> FinMap 1 a
singleton x = FinMap (IntMap.singleton 0 x)
{-# INLINABLE singleton #-}

-- | /O(min(n,W))/. Store a value at the given key, replacing any value that
-- was already there.
insert :: Fin n -> a -> FinMap n a -> FinMap n a
insert k v (FinMap entries) = FinMap (IntMap.insert (unsafeFinToInt k) v entries)
{-# INLINABLE insert #-}

-- buildFinMap, append, and fromVector are duplicated between the safe and
-- unsafe modules because they are used in comparative testing (and so
-- implementations must be available for both types). Keep the two copies
-- behaviorally identical.

-- | 'FinMap' with its type arguments swapped, so the size index comes last.
-- This lets it serve as the motive of 'NatRepr.natRecStrictlyBounded'.
newtype FinMap' a (n :: Nat) =
  FinMap'
    { unFinMap' :: FinMap n a
    }

-- | Build a @'FinMap' m a@ by starting from the empty map and applying the
-- step function once for each @n@ from @0@ up to @m - 1@. Each step turns a
-- @'FinMap' n a@ into a @'FinMap' (n + 1) a@.
buildFinMap ::
  forall m a.
  NatRepr m ->
  (forall n. (n + 1 <= m) => NatRepr n -> FinMap n a -> FinMap (n + 1) a) ->
  FinMap m a
buildFinMap m f =
  unFinMap' (NatRepr.natRecStrictlyBounded m (FinMap' empty) step)
  where
    -- The user step function, lifted through the flipped newtype.
    step :: forall k. (k + 1 <= m) => NatRepr k -> FinMap' a k -> FinMap' a (k + 1)
    step k (FinMap' acc) = FinMap' (f k acc)

-- | /O(min(n,W))/. Grow the maximum key/size by one and store the value at
-- key @n@, the new largest key.
append :: NatRepr n -> a -> FinMap n a -> FinMap (n + 1) a
append k v fm =
  -- 'leqSucc' supplies the proof @n <= n + 1@ to the type checker.
  case NatRepr.leqSucc k of
    NatRepr.LeqProof ->
      let widened = incMax fm
      in insert (mkFin k) v widened

-- | Build a map from a vector of optional values. Index @i@ of the vector
-- becomes key @i@ of the map when it holds @'Just' x@. It has no entry when
-- the element is 'Nothing'.
fromVector :: forall n a. Vector n (Maybe a) -> FinMap n a
fromVector v = buildFinMap (Vec.length v) step
  where
    step :: forall k. (k + 1 <= n) => NatRepr k -> FinMap k a -> FinMap (k + 1) a
    step k acc = maybe (incMax acc) (\e -> append k e acc) (Vec.elemAt k v)

------------------------------------------------------------------------
-- Operations

-- | /O(min(n,W))/. Remove the entry at the given key. The map is unchanged if
-- the key has no entry.
delete :: Fin n -> FinMap n a -> FinMap n a
delete k (FinMap entries) = FinMap (IntMap.delete (unsafeFinToInt k) entries)
{-# INLINABLE delete #-}

-- | /O(min(n,W))/. Decrement the maximum key/size by one, removing the item
-- at key @n@ if present.
--
-- Key @n@ is the largest key a @'FinMap' (n + 1)@ can hold, and the only key
-- that would be out of range for the resulting @'FinMap' n@.
decMax :: NatRepr n -> FinMap (n + 1) a -> FinMap n a
decMax k = FinMap . IntMap.delete (unsafeNatReprToInt k) . getFinMap
{-# INLINABLE decMax #-}

-- | Apply a function to every value, also passing it the entry's key. The
-- set of keys is unchanged.
mapWithKey :: (Fin n -> a -> b) -> FinMap n a -> FinMap n b
mapWithKey f = FinMap . IntMap.mapWithKey withFinKey . getFinMap
  where
    -- Keys of a FinMap are always valid @Fin n@ indices (type invariant).
    withFinKey k = f (unsafeFin k)
-- Inline so that RULES for IntMap.mapWithKey can fire
{-# INLINE mapWithKey #-}

-- | /O(n+m)/. Union of two maps. For keys present in both, the combining
-- function receives the key, the value from the first map, and the value
-- from the second map.
unionWithKey :: (Fin n -> a -> a -> a) -> FinMap n a -> FinMap n a -> FinMap n a
unionWithKey f (FinMap left) (FinMap right) =
  FinMap (IntMap.unionWithKey (\k -> f (unsafeFin k)) left right)

-- | /O(n+m)/. Like 'unionWithKey', but the combining function does not see
-- the key.
unionWith :: (a -> a -> a) -> FinMap n a -> FinMap n a -> FinMap n a
unionWith f = unionWithKey (const f)

-- | /O(n+m)/. Left-biased union, i.e. (@'union' == 'unionWith' 'const'@).
-- For keys present in both maps, the value from the first map is kept.
--
-- Implemented directly with 'IntMap.union', which is itself left-biased. This
-- avoids allocating a combining closure and an 'unsafeFin' key per collision.
union :: FinMap n a -> FinMap n a -> FinMap n a
union (FinMap left) (FinMap right) = FinMap (IntMap.union left right)