{-# LANGUAGE AllowAmbiguousTypes #-}

module Telescope.Asdf.NDArray
  ( NDArrayData (..)
  , FromNDArray (..)
  , ToNDArray (..)
  , DataType (..)
  , IsDataType (..)
  , parseGet
  , ndArrayPut
  , ndArrayMassiv
  , parseMassiv
  , parseNDArray
  , ByteOrder (..)
  , getUcs4
  , putUcs4
  , Parser
  )
where

import Control.Monad (replicateM)
import Control.Monad.Catch (try)
import Data.Binary.Get hiding (getBytes)
import Data.Binary.Put
import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BL
import Data.Massiv.Array (Array, D, Prim, Sz (..))
import Data.Massiv.Array qualified as M
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Effectful
import Telescope.Asdf.NDArray.Types
import Telescope.Asdf.Node
import Telescope.Data.Array
import Telescope.Data.Axes
import Telescope.Data.Binary
import Telescope.Data.Parser


-- import Telescope.Asdf.Node

{- | Types that can be decoded from the raw 'NDArrayData' of an ASDF
@ndarray@. Decoding runs in the 'Parser' effect, so a failure is reported
through that effect rather than thrown.

https://asdf-standard.readthedocs.io/en/latest/generated/stsci.edu/asdf/core/ndarray-1.1.0.html
-}
class FromNDArray a where
  fromNDArray :: forall es. (Parser :> es) => NDArrayData -> Eff es a


{- | Types that can be encoded as the raw 'NDArrayData' of an ASDF
@ndarray@. Encoding is pure and cannot fail.

https://asdf-standard.readthedocs.io/en/latest/generated/stsci.edu/asdf/core/ndarray-1.1.0.html
-}
class ToNDArray a where
  toNDArray
    :: a
    -- ^ value to encode
    -> NDArrayData


{- | A flat list is encoded as a 1-D array whose single axis is the list
length. Elements are written big-endian, in list order.
-}
instance {-# OVERLAPPABLE #-} (BinaryValue a, IsDataType a) => ToNDArray [a] where
  toNDArray xs = ndArrayPut vectorShape writeElems xs
   where
    -- one axis, sized by the number of elements
    vectorShape :: [a] -> Axes Row
    vectorShape ys = axesRowMajor [length ys]

    -- serialise every element big-endian, one after another
    writeElems :: [a] -> Put
    writeElems = mapM_ (put BigEndian)


{- | A list of rows is encoded as a 2-D array with shape
@[number of rows, length of the first row]@. Elements are written
big-endian, row by row (row-major order).

An empty list is encoded with shape @[0, 0]@. Otherwise its shape would be
empty (0-dimensional, i.e. a single scalar) while no bytes are written,
which does not describe a valid array.

NOTE(review): every row is assumed to have the same length as the first.
For jagged input, the number of bytes written will not match the shape.
-}
instance {-# OVERLAPPING #-} (BinaryValue a, IsDataType a) => ToNDArray [[a]] where
  toNDArray = ndArrayPut rowMajorShape putRows
   where
    -- rows in order, each row's elements in order, big-endian
    putRows :: [[a]] -> Put
    putRows = mapM_ (mapM_ (put BigEndian))

    rowMajorShape :: [[a]] -> Axes Row
    rowMajorShape = axesRowMajor . dimensions

    -- [rows, columns]; the column count is taken from the first row
    dimensions [] = [0, 0]
    dimensions rows@(firstRow : _) = [length rows, length firstRow]


{- | Decode any array as a flat list. The function reads @totalItems@ of the
shape many elements, using the byte order recorded in the 'NDArrayData'.
The @datatype@ field is not consulted.
-}
instance {-# OVERLAPPABLE #-} (BinaryValue a) => FromNDArray [a] where
  fromNDArray arr = parseGet readAll arr.bytes
   where
    readAll :: Get [a]
    readAll = replicateM (totalItems arr.shape) (get arr.byteorder)


{- | Decode a fixed-width @ucs4@ string array as a flat list of 'Text'.
The datatype must be 'Ucs4'; its width (in code points) is the size of every
element. Any other datatype is rejected with 'expected'.
-}
instance FromNDArray [Text] where
  fromNDArray arr = do
    width <- case arr.datatype of
      Ucs4 n -> pure (fromIntegral n)
      other -> expected "Ucs4" other
    let count = totalItems arr.shape
        readElem = getUcs4 arr.byteorder width
    parseGet (replicateM count readElem) arr.bytes


-- decode LittleEndian = T.decodeUtf32LE
-- decode BigEndian = T.decodeUtf32BE

{- | Decode a 2-D array as a list of rows.

The shape must be exactly @[nrows, ncols]@. The data is read in row-major
order as @nrows@ rows of @ncols@ elements each, using the byte order
recorded in the 'NDArrayData'. This matches the layout produced by the
@ToNDArray [[a]]@ instance.

Previously, each axis length was treated as the length of one row. A
@[2, 3]@ array was therefore decoded as one row of 2 elements followed by
one row of 3 elements.

Shapes that are not 2-dimensional fail through the 'Parser' effect.
-}
instance {-# OVERLAPPING #-} (BinaryValue a) => FromNDArray [[a]] where
  fromNDArray arr = parseGet (getRows arr.shape) arr.bytes
   where
    getRows :: Axes Row -> Get [[a]]
    getRows (Axes [nrows, ncols]) = replicateM nrows (getRow ncols)
    getRows (Axes dims) =
      fail $ "expected a 2-dimensional shape, but got " ++ show dims

    -- one row of n elements in the array's byte order
    getRow n = replicateM n (get arr.byteorder)


-- | Decode into a delayed massiv array; all of the work is done by 'parseMassiv'.
instance (BinaryValue a, Prim a, AxesIndex ix) => FromNDArray (Array D ix a) where
  fromNDArray nda = parseMassiv nda


-- | Encode a delayed massiv array; all of the work is done by 'ndArrayMassiv'.
instance (BinaryValue a, IsDataType a, Prim a, AxesIndex ix, PutArray ix) => ToNDArray (Array D ix a) where
  toNDArray arr = ndArrayMassiv arr


{- | Run a binary 'Get' decoder over a strict 'ByteString' inside the
'Parser' effect.

On failure, 'parseFail' is called with the byte offset the decoder reached,
the number of unconsumed bytes, and the decoder's own message. On success,
any bytes left over are ignored.
-}
parseGet :: (Parser :> es) => Get a -> ByteString -> Eff es a
parseGet gt bytes =
  either onFailure onSuccess $ runGetOrFail gt (BL.fromStrict bytes)
 where
  onSuccess (_, _, a) = pure a
  onFailure (rest, nused, err) =
    parseFail $
      "could not decode binary data at ("
        <> show nused
        <> ") (rest "
        <> show (BL.length rest)
        <> "): "
        <> err


{- | Build a big-endian 'NDArrayData' from a value.

The caller supplies two functions: one computes the value's row-major shape,
the other serialises its contents. The 'DataType' is taken from the
'IsDataType' instance of @a@ itself.
-}
ndArrayPut :: forall a. (IsDataType a) => (a -> Axes Row) -> (a -> Put) -> a -> NDArrayData
ndArrayPut toShape putA a =
  NDArrayData
    { bytes = BL.toStrict (runPut (putA a))
    , byteorder = BigEndian
    , datatype = dataType @a
    , shape = toShape a
    }


{- | Encode a delayed massiv array as 'NDArrayData'.

The bytes come from 'encodeArray', and the shape is derived from the
array's size via 'indexAxes'. The byte order is recorded as 'BigEndian'.
NOTE(review): this assumes 'encodeArray' writes big-endian; confirm in
"Telescope.Data.Array".
-}
ndArrayMassiv :: forall a ix. (IsDataType a, BinaryValue a, Prim a, AxesIndex ix, PutArray ix) => Array D ix a -> NDArrayData
ndArrayMassiv arr =
  NDArrayData
    { bytes = encoded
    , byteorder = BigEndian
    , datatype = dataType @a
    , shape = indexAxes ix
    }
 where
  encoded = encodeArray arr
  Sz ix = M.size arr


{- | Decode 'NDArrayData' into a delayed massiv array.

The work is done by 'decodeArrayOrder', using the ndarray's byte order,
shape and bytes. An 'ArrayError' thrown while decoding is caught and
re-reported through 'parseFail' with its 'show' text.
-}
parseMassiv :: (BinaryValue a, AxesIndex ix, Parser :> es) => NDArrayData -> Eff es (Array D ix a)
parseMassiv nda = do
  result <- try $ decodeArrayOrder nda.byteorder nda.shape nda.bytes
  either (\(e :: ArrayError) -> parseFail (show e)) pure result


{- | Write @t@ as a big-endian UCS-4 (UTF-32BE) field of @n@ code points.
Padding is added with zero bytes by 'justifyUcs4'.
-}
putUcs4 :: Int -> Text -> Put
putUcs4 n t =
  let encoded = T.encodeUtf32BE t
      field = justifyUcs4 n encoded
   in putByteString field


{- | Read one fixed-width UCS-4 string of @n@ code points (@4 * n@ bytes)
in the given byte order.

Trailing U+0000 code units (padding for strings shorter than the field) are
removed before decoding. They are stripped as whole 4-byte units, which
works the same for both byte orders.

The previous implementation stripped individual zero bytes, which caused
two bugs:

* little-endian stripped from the front, so the padding survived as
  @\\NUL@ characters;
* big-endian corrupted a final character whose encoding ends in a zero
  byte (e.g. U+0100), leaving a misaligned sequence.
-}
getUcs4 :: ByteOrder -> Int -> Get Text
getUcs4 bo n = decodeUcs4 . stripNullPadding <$> getByteString (n * 4)
 where
  decodeUcs4 :: ByteString -> Text
  decodeUcs4 = case bo of
    BigEndian -> T.decodeUtf32BE
    LittleEndian -> T.decodeUtf32LE

  -- Drop trailing 4-byte code units that are entirely zero (U+0000).
  stripNullPadding :: ByteString -> ByteString
  stripNullPadding bs
    | BS.length bs >= 4, BS.all (== 0) lastUnit = stripNullPadding rest
    | otherwise = bs
   where
    (rest, lastUnit) = BS.splitAt (BS.length bs - 4) bs


{- | Fit a UCS-4 encoded 'ByteString' into a fixed-width field of @len@ code
points (@4 * len@ bytes). Longer input is truncated to the field width;
shorter input is right-padded with zero bytes.

Truncation guarantees every element of a fixed-width @ucs4@ array is exactly
@4 * len@ bytes. Without it, an over-long string would shift the offsets of
all following elements.
-}
justifyUcs4 :: Int -> BS.ByteString -> BS.ByteString
justifyUcs4 len bs =
  let width = len * 4
      field = BS.take width bs
   in field <> BS.replicate (width - BS.length field) 0x0


{- | Parse an ASDF node 'Value' as any 'FromNDArray' type.
The value must be an 'NDArray'; anything else is rejected with 'expected'.
-}
parseNDArray :: (FromNDArray a, Parser :> es) => Value -> Eff es a
parseNDArray = \case
  NDArray dat -> fromNDArray dat
  other -> expected "NDArray" other