{-# LANGUAGE UndecidableInstances #-}

{- |
Module: TREXIO.CooArray
Description: Coordinate list sparse array representation for TREXIO
Copyright: Phillip Seeber 2024
License: BSD-3-Clause
Maintainer: phillip.seeber@uni-jena.de
Stability: experimental
Portability: POSIX
-}
module TREXIO.CooArray (
  CooArray,
  values,
  coords,
  cooSize,
  mkCooArrayF,
  mkCooArray,
)
where

import Data.Foldable
import Data.Massiv.Array as Massiv hiding (all, toList)
import GHC.Generics (Generic)

-- | A sparse array in coordinate-list (COO) format.
--
-- Entry @i@ of 'values_' is the value stored at the index held in entry @i@
-- of 'coords_'. The data constructor is not exported, so values are built
-- through 'mkCooArray' / 'mkCooArrayF'. Those constructors check every
-- coordinate against 'cooSize_'.
data CooArray r ix a = CooArray
  { values_ :: Vector r a
  -- ^ Stored values, parallel to 'coords_'.
  , coords_ :: Vector r ix
  -- ^ Index of each stored value, parallel to 'values_'.
  , cooSize_ :: Sz ix
  -- ^ Size of the COO array
  }
  deriving Generic

-- | Structural equality: the value vectors, the coordinate vectors and the
-- sizes must all be equal.
--
-- Two arrays holding the same entries in a different order therefore compare
-- unequal.
instance (Eq (Vector r a), Eq (Vector r ix), Eq ix) => Eq (CooArray r ix a) where
  (==) (CooArray vals1 inds1 size1) (CooArray vals2 inds2 size2) =
    and [vals1 == vals2, inds1 == inds2, size1 == size2]

-- | Lexicographic ordering: values are compared first, then coordinates, and
-- finally sizes.
instance (Ord (Vector r a), Ord (Vector r ix), Ord ix) => Ord (CooArray r ix a) where
  compare (CooArray vals1 inds1 size1) (CooArray vals2 inds2 size2) =
    mconcat
      [ compare vals1 vals2
      , compare inds1 inds2
      , compare size1 size2
      ]

-- | Renders as @CooArray \<values\> \<coords\> \<size\>@.
--
-- 'showsPrec' is defined rather than only 'show', so that a 'CooArray'
-- nested inside another constructor (precedence above 10, e.g. under 'Just')
-- is wrapped in parentheses. At precedence 0 (plain 'show') no parentheses
-- are added, and each field is rendered with 'shows'.
instance
  (Show (Vector r a), Show (Vector r ix), Index ix, Show ix) =>
  Show (CooArray r ix a)
  where
  showsPrec d (CooArray vals inds sz) =
    showParen (d > 10) $
      showString "CooArray "
        . shows vals
        . showString " "
        . shows inds
        . showString " "
        . shows sz

-- | The stored values. Entry @i@ belongs at the index given by entry @i@ of
-- 'coords'.
values :: CooArray r ix a -> Vector r a
values (CooArray vals _ _) = vals

-- | The coordinates of the stored entries, parallel to 'values'.
coords :: CooArray r ix a -> Vector r ix
coords (CooArray _ inds _) = inds

-- | The size of the array. Every coordinate was checked against it on
-- construction.
cooSize :: CooArray r ix a -> Sz ix
cooSize (CooArray _ _ sz) = sz

-- | Make a 'CooArray' from any 'Foldable' container of
-- @(coordinate, value)@ pairs.
--
-- The pairs are first gathered into a boxed vector (using the 'Par'
-- computation strategy). That vector is then split into the coordinate
-- vector and the value vector.
--
-- Throws (via 'throwM') an 'IndexOutOfBoundsException' carrying the first
-- coordinate that is not a safe index for the given size.
mkCooArrayF ::
  (Foldable f, Index ix, Manifest r a, Manifest r ix, MonadThrow m, Stream r Ix1 ix) =>
  Sz ix ->
  f (ix, a) ->
  m (CooArray r ix a)
mkCooArrayF cooSize_ coo =
  case Massiv.sfilter (not . isSafeIndex cooSize_) coords_ of
    outOfBounds
      | isNull outOfBounds -> return CooArray{..}
      | otherwise ->
          throwM (IndexOutOfBoundsException cooSize_ (shead' outOfBounds))
 where
  -- Boxed staging vector holding the original pairs in input order.
  pairs = Massiv.fromList @B Par (toList coo)
  values_ = Massiv.compute (Massiv.map snd pairs)
  coords_ = Massiv.compute (Massiv.map fst pairs)

-- | Make a 'CooArray' from a vector of coordinates and a parallel vector of
-- values.
--
-- Throws (via 'throwM'):
--
-- * 'SizeMismatchException' if the coordinate and value vectors differ in
--   length (checked first);
-- * 'IndexOutOfBoundsException' carrying the first coordinate that is not a
--   safe index for the given size.
mkCooArray ::
  (MonadThrow m, Index ix, Size r, Stream r Ix1 ix) =>
  Sz ix ->
  Vector r ix ->
  Vector r a ->
  m (CooArray r ix a)
mkCooArray cooSize_ coords_ values_
  | nCoords /= nValues = throwM (SizeMismatchException nCoords nValues)
  | otherwise =
      case Massiv.sfilter (not . isSafeIndex cooSize_) coords_ of
        outOfBounds
          | isNull outOfBounds -> return CooArray{..}
          | otherwise ->
              throwM (IndexOutOfBoundsException cooSize_ (shead' outOfBounds))
 where
  nCoords = Massiv.size coords_
  nValues = Massiv.size values_