{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
#if __GLASGOW_HASKELL__ < 820
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
#endif
module Data.Massiv.Core.Index.Internal
( Sz(SafeSz)
, pattern Sz
, pattern Sz1
, type Sz1
, unSz
, zeroSz
, oneSz
, consSz
, unconsSz
, snocSz
, unsnocSz
, setSzM
, insertSzM
, pullOutSzM
, Dim(..)
, Dimension(DimN)
, pattern Dim1
, pattern Dim2
, pattern Dim3
, pattern Dim4
, pattern Dim5
, IsIndexDimension
, Lower
, Index(..)
, Ix0(..)
, type Ix1
, pattern Ix1
, IndexException(..)
, SizeException(..)
, ShapeException(..)
) where
import Control.DeepSeq
import Control.Exception (Exception(..))
import Control.Monad.Catch (MonadThrow(..))
import Data.Coerce
import Data.Massiv.Core.Iterator
import Data.Typeable
import GHC.TypeLits
-- | Size of an array, stored as a value of the index type @ix@.
--
-- The raw constructor 'SafeSz' stores its argument unchanged. The 'Sz' pattern
-- clamps negative components to @0@ when constructing, so prefer it unless the
-- value is already known to be non-negative.
newtype Sz ix = SafeSz ix
  deriving (Eq, Ord, NFData)
-- | Bidirectional pattern for 'Sz'.
--
-- Matching exposes the stored index as is. Constructing passes the index
-- through 'liftIndex' with @max 0@, so negative components are clamped to zero.
pattern Sz :: Index ix => ix -> Sz ix
pattern Sz ix <- SafeSz ix
  where
    Sz raw = SafeSz (liftIndex (\i -> max 0 i) raw)
{-# COMPLETE Sz #-}
-- | Size of a one-dimensional array.
type Sz1 = Sz Ix1

-- | Bidirectional pattern for 'Sz1'.
--
-- Matching returns the stored length unchanged. Constructing replaces a
-- negative length with @0@.
pattern Sz1 :: Ix1 -> Sz1
pattern Sz1 ix <- SafeSz ix
  where
    Sz1 len
      | len < 0 = SafeSz 0
      | otherwise = SafeSz len
{-# COMPLETE Sz1 #-}
-- | One-dimensional sizes render as @Sz1 n@; all others render as @Sz (ix)@.
-- The whole rendering is wrapped in parentheses for any non-zero precedence.
instance Index ix => Show (Sz ix) where
  showsPrec prec sz@(SafeSz inner) =
    showParen (prec /= 0) (showString body)
    where
      body
        | unDim (dimensions sz) == 1 = "Sz1 " ++ show inner
        | otherwise = "Sz (" ++ show inner ++ ")"
-- | Arithmetic on sizes.
--
-- '+', '-' and 'fromInteger' rebuild their result through the clamping 'Sz'
-- pattern, while '*' and 'signum' wrap with 'SafeSz' directly. 'abs' returns
-- its (forced) argument, and 'negate' forces its argument and then yields the
-- zero size.
instance (Num ix, Index ix) => Num (Sz ix) where
  SafeSz a + SafeSz b = Sz (a + b)
  {-# INLINE (+) #-}
  SafeSz a - SafeSz b = Sz (a - b)
  {-# INLINE (-) #-}
  SafeSz a * SafeSz b = SafeSz (a * b)
  {-# INLINE (*) #-}
  abs !sz = sz
  {-# INLINE abs #-}
  negate !_sz = 0
  {-# INLINE negate #-}
  signum (SafeSz a) = SafeSz (signum a)
  {-# INLINE signum #-}
  fromInteger n = Sz (fromInteger n)
  {-# INLINE fromInteger #-}
-- | Unwrap a size, giving back the underlying index (no validation).
unSz :: Sz ix -> ix
unSz = coerce
{-# INLINE unSz #-}

-- | The size built from @'pureIndex' 0@.
zeroSz :: Index ix => Sz ix
zeroSz = SafeSz $ pureIndex 0
{-# INLINE zeroSz #-}

-- | The size built from @'pureIndex' 1@.
oneSz :: Index ix => Sz ix
oneSz = SafeSz $ pureIndex 1
{-# INLINE oneSz #-}
-- | Combine a one-dimensional size with a lower-dimensional one via 'consDim'.
consSz :: Index ix => Sz1 -> Sz (Lower ix) -> Sz ix
consSz (SafeSz n) (SafeSz rest) = SafeSz $ consDim n rest
{-# INLINE consSz #-}

-- | Combine a lower-dimensional size with a one-dimensional one via 'snocDim'.
snocSz :: Index ix => Sz (Lower ix) -> Sz1 -> Sz ix
snocSz (SafeSz rest) (SafeSz n) = SafeSz $ snocDim rest n
{-# INLINE snocSz #-}
-- | Replace the size along the given dimension using 'setDimM'. Any failure
-- from 'setDimM' (for example an invalid dimension) is propagated in @m@.
setSzM :: (MonadThrow m, Index ix) => Sz ix -> Dim -> Sz Int -> m (Sz ix)
setSzM (SafeSz sz) dim (SafeSz n) = fmap SafeSz (setDimM sz dim n)
{-# INLINE setSzM #-}

-- | Insert a size at the given dimension of a lower-dimensional size using
-- 'insertDimM'. Any failure from 'insertDimM' is propagated in @m@.
insertSzM :: (MonadThrow m, Index ix) => Sz (Lower ix) -> Dim -> Sz Int -> m (Sz ix)
insertSzM (SafeSz lower) dim (SafeSz n) = fmap SafeSz (insertDimM lower dim n)
{-# INLINE insertSzM #-}
-- | Split a size with 'unconsDim', rewrapping both parts as sizes.
unconsSz :: Index ix => Sz ix -> (Sz1, Sz (Lower ix))
unconsSz (SafeSz ix) = coerce $ unconsDim ix
{-# INLINE unconsSz #-}

-- | Split a size with 'unsnocDim', rewrapping both parts as sizes.
unsnocSz :: Index ix => Sz ix -> (Sz (Lower ix), Sz1)
unsnocSz (SafeSz ix) = coerce $ unsnocDim ix
{-# INLINE unsnocSz #-}

-- | Remove the size along the given dimension with 'pullOutDimM', returning it
-- together with the remaining lower-dimensional size. Any failure from
-- 'pullOutDimM' is propagated in @m@.
pullOutSzM :: (MonadThrow m, Index ix) => Sz ix -> Dim -> m (Sz Ix1, Sz (Lower ix))
pullOutSzM (SafeSz ix) dim = coerce <$> pullOutDimM ix dim
{-# INLINE pullOutSzM #-}
-- | A dimension number. Its numeric instances are inherited from 'Int'.
newtype Dim = Dim
  { unDim :: Int
  } deriving (Eq, Ord, Num, Real, Integral, Enum)

-- | Always rendered in parentheses, e.g. @(Dim 2)@.
instance Show Dim where
  show dim = "(Dim " ++ shows (unDim dim) ")"
-- | A type-level dimension number @n@. Matching on 'DimN' brings into scope
-- the evidence that @n@ is at least @1@ and is a 'KnownNat'.
data Dimension (n :: Nat) where
  DimN :: (1 <= n, KnownNat n) => Dimension n

-- | Dimension number 1.
pattern Dim1 :: Dimension 1
pattern Dim1 = DimN

-- | Dimension number 2.
pattern Dim2 :: Dimension 2
pattern Dim2 = DimN

-- | Dimension number 3.
pattern Dim3 :: Dimension 3
pattern Dim3 = DimN

-- | Dimension number 4.
pattern Dim4 :: Dimension 4
pattern Dim4 = DimN

-- | Dimension number 5.
pattern Dim5 :: Dimension 5
pattern Dim5 = DimN
-- | Constraint stating that @n@ is a statically known dimension of the index
-- type @ix@, i.e. @1 <= n <= 'Dimensions' ix@.
type IsIndexDimension ix n
   = ( 1 <= n
     , n <= Dimensions ix
     , Index ix
     , KnownNat n
     )

-- | The index type obtained by dropping one dimension from @ix@; for example,
-- the instance in this module maps 'Int' to 'Ix0'.
type family Lower ix :: *
-- | Operations shared by all (multi-dimensional) index types.
--
-- Many methods have default implementations that recurse through 'Lower'.
-- They split off one 'Int' component with 'unconsDim' or 'unsnocDim', call the
-- same method on the remaining @'Lower' ix@, and recombine the results (often
-- with 'consDim'). These defaults require an @'Index' ('Lower' ix)@ instance,
-- so the recursion must bottom out in an instance that overrides them, such as
-- the 'Ix1' instance in this module.
class ( Eq ix
      , Ord ix
      , Show ix
      , NFData ix
      , Eq (Lower ix)
      , Ord (Lower ix)
      , Show (Lower ix)
      , NFData (Lower ix)
      ) =>
      Index ix
  where
  -- | Number of dimensions, as a type-level natural.
  type Dimensions ix :: Nat
  -- | Number of dimensions, as a runtime 'Dim'. (The 'Ix1' instance ignores
  -- the proxy.)
  dimensions :: proxy ix -> Dim
  -- | Total number of elements for a size. For 'Ix1' this is the size itself.
  totalElem :: Sz ix -> Int
  -- | Build an index from one leading component and a lower-dimensional index.
  consDim :: Int -> Lower ix -> ix
  -- | Split an index into one leading component and the lower-dimensional rest.
  unconsDim :: ix -> (Int, Lower ix)
  -- | Build an index from a lower-dimensional index and one trailing component.
  snocDim :: Lower ix -> Int -> ix
  -- | Split an index into the lower-dimensional rest and one trailing component.
  unsnocDim :: ix -> (Lower ix, Int)
  -- | Remove the component at the given dimension, returning it together with
  -- the remaining lower-dimensional index. An invalid dimension is reported
  -- through 'MonadThrow' (the 'Ix1' instance throws 'IndexDimensionException').
  pullOutDimM :: MonadThrow m => ix -> Dim -> m (Int, Lower ix)
  -- | Insert a component at the given dimension of a lower-dimensional index.
  -- An invalid dimension is reported through 'MonadThrow'.
  insertDimM :: MonadThrow m => Lower ix -> Dim -> Int -> m ix
  -- | Read the component at the given dimension. An invalid dimension is
  -- reported through 'MonadThrow'.
  getDimM :: MonadThrow m => ix -> Dim -> m Int
  -- | Replace the component at the given dimension. An invalid dimension is
  -- reported through 'MonadThrow'.
  setDimM :: MonadThrow m => ix -> Dim -> Int -> m ix
  -- | Build an index from a single 'Int'. Presumably the value is used for
  -- every component; for 'Ix1' it is the value itself.
  pureIndex :: Int -> ix
  -- | Combine two indices with a binary function. Presumably this is
  -- component-wise; for 'Ix1' the function is applied directly.
  liftIndex2 :: (Int -> Int -> Int) -> ix -> ix -> ix
  -- | Map a function over an index. The default goes through 'liftIndex2',
  -- pairing the index with a dummy @'pureIndex' 0@ whose values are ignored.
  liftIndex :: (Int -> Int) -> ix -> ix
  liftIndex f = liftIndex2 (\_ i -> f i) (pureIndex 0)
  {-# INLINE [1] liftIndex #-}
  -- | Left fold over the components of an index. The default folds the
  -- 'unconsDim' component first and then the rest, forcing the accumulator
  -- (to WHNF) at each step.
  foldlIndex :: (a -> Int -> a) -> a -> ix -> a
  default foldlIndex :: Index (Lower ix) =>
    (a -> Int -> a) -> a -> ix -> a
  foldlIndex f !acc !ix = foldlIndex f (f acc i0) ixL
    where
      !(i0, ixL) = unconsDim ix
  {-# INLINE [1] foldlIndex #-}
  -- | Check that an index lies within a size. The default requires every
  -- component to be safe for the matching size component. For 'Ix1' this is
  -- @0 <= i && i < n@.
  isSafeIndex ::
       Sz ix -- ^ Size
    -> ix -- ^ Index to check
    -> Bool
  default isSafeIndex :: Index (Lower ix) =>
    Sz ix -> ix -> Bool
  isSafeIndex sz !ix = isSafeIndex n0 i0 && isSafeIndex szL ixL
    where
      !(n0, szL) = unconsSz sz
      !(i0, ixL) = unconsDim ix
  {-# INLINE [1] isSafeIndex #-}
  -- | Convert an index into a linear offset for the given size. No bounds
  -- checking is done. The default splits off the trailing component with
  -- 'unsnocDim' and computes @offset(rest) * n + i@, so the trailing component
  -- varies fastest.
  toLinearIndex ::
       Sz ix -- ^ Size
    -> ix -- ^ Index
    -> Int
  default toLinearIndex :: Index (Lower ix) =>
    Sz ix -> ix -> Int
  toLinearIndex (SafeSz sz) !ix = toLinearIndex (SafeSz szL) ixL * n + i
    where
      !(szL, n) = unsnocDim sz
      !(ixL, i) = unsnocDim ix
  {-# INLINE [1] toLinearIndex #-}
  -- | Accumulator-passing variant of 'toLinearIndex' that takes the size as a
  -- raw index. At each step it computes @acc * n + i@ from the leading
  -- ('unconsDim') components of the size and the index, then recurses on the
  -- rest.
  toLinearIndexAcc :: Int -> ix -> ix -> Int
  default toLinearIndexAcc :: Index (Lower ix) =>
    Int -> ix -> ix -> Int
  toLinearIndexAcc !acc !sz !ix = toLinearIndexAcc (acc * n + i) szL ixL
    where
      !(n, szL) = unconsDim sz
      !(i, ixL) = unconsDim ix
  {-# INLINE [1] toLinearIndexAcc #-}
  -- | Convert a linear offset back into an index for the given size; this is
  -- intended as the inverse of 'toLinearIndex'. The default never looks at the
  -- leading size component: whatever quotient remains after dividing out the
  -- lower dimensions becomes the leading index component.
  fromLinearIndex :: Sz ix -> Int -> ix
  default fromLinearIndex :: Index (Lower ix) =>
    Sz ix -> Int -> ix
  fromLinearIndex (SafeSz sz) k = consDim q ixL
    where
      !(q, ixL) = fromLinearIndexAcc (snd (unconsDim sz)) k
  {-# INLINE [1] fromLinearIndex #-}
  -- | Accumulator-passing helper for 'fromLinearIndex'. Given a raw size and a
  -- linear offset, it returns the quotient left after dividing out every
  -- dimension, together with the recovered index. The default decodes the
  -- lower dimensions first. The leading component is then the remainder of
  -- what is left modulo the leading size, and the quotient is passed up.
  fromLinearIndexAcc :: ix -> Int -> (Int, ix)
  default fromLinearIndexAcc :: Index (Lower ix) =>
    ix -> Int -> (Int, ix)
  fromLinearIndexAcc ix' !k = (q, consDim r ixL)
    where
      !(m, ix) = unconsDim ix'
      !(kL, ixL) = fromLinearIndexAcc ix k
      !(q, r) = quotRem kL m
  {-# INLINE [1] fromLinearIndexAcc #-}
  -- | Bring an index back into range, one component at a time. A negative
  -- component is replaced by the first function's result, a component at or
  -- beyond its size by the second function's result, and in-range components
  -- are kept. Both functions receive that component's size and value.
  repairIndex ::
       Sz ix -- ^ Size
    -> ix -- ^ Index to repair
    -> (Sz Int -> Int -> Int) -- ^ Applied to a component below @0@
    -> (Sz Int -> Int -> Int) -- ^ Applied to a component at or above its size
    -> ix
  default repairIndex :: Index (Lower ix) =>
    Sz ix -> ix -> (Sz Int -> Int -> Int) -> (Sz Int -> Int -> Int) -> ix
  repairIndex sz !ix rBelow rOver =
    consDim (repairIndex n i rBelow rOver) (repairIndex szL ixL rBelow rOver)
    where
      !(n, szL) = unconsSz sz
      !(i, ixL) = unconsDim ix
  {-# INLINE [1] repairIndex #-}
  -- | Monadic loop over indices that threads an accumulator. Each component is
  -- driven by 'loopM' from its start value, with the update @(+ step)@ and the
  -- predicate @\\i -> cond i end@ (see 'loopM' for the exact iteration
  -- semantics). The default makes the leading ('unconsDim') component the
  -- outermost loop and recurses into the lower dimensions for each of its
  -- values.
  iterM ::
       Monad m
    => ix -- ^ Start index
    -> ix -- ^ End index (second argument of the condition)
    -> ix -- ^ Step added to each component
    -> (Int -> Int -> Bool) -- ^ Condition, called as @cond current end@
    -> a -- ^ Initial accumulator
    -> (ix -> a -> m a) -- ^ Action run for each index
    -> m a
  default iterM :: (Index (Lower ix), Monad m) =>
    ix -> ix -> ix -> (Int -> Int -> Bool) -> a -> (ix -> a -> m a) -> m a
  iterM !sIx eIx !incIx cond !acc f =
    loopM s (`cond` e) (+ inc) acc $ \ !i !acc0 ->
      iterM sIxL eIxL incIxL cond acc0 $ \ !ix -> f (consDim i ix)
    where
      !(s, sIxL) = unconsDim sIx
      !(e, eIxL) = unconsDim eIx
      !(inc, incIxL) = unconsDim incIx
  {-# INLINE iterM #-}
  -- | Like 'iterM' but without an accumulator, discarding the action's results.
  -- Arguments: start, end, step, condition (@cond current end@), action.
  iterM_ :: Monad m => ix -> ix -> ix -> (Int -> Int -> Bool) -> (ix -> m a) -> m ()
  default iterM_ :: (Index (Lower ix), Monad m) =>
    ix -> ix -> ix -> (Int -> Int -> Bool) -> (ix -> m a) -> m ()
  iterM_ !sIx eIx !incIx cond f =
    loopM_ s (`cond` e) (+ inc) $ \ !i -> iterM_ sIxL eIxL incIxL cond $ \ !ix -> f (consDim i ix)
    where
      !(s, sIxL) = unconsDim sIx
      !(e, eIxL) = unconsDim eIx
      !(inc, incIxL) = unconsDim incIx
  {-# INLINE iterM_ #-}
-- | Zero-dimensional index with no data; it is the 'Lower' of a one-dimensional
-- ('Int') index.
data Ix0 =
  Ix0
  deriving (Eq, Ord, Show)

instance NFData Ix0 where
  rnf ix0 = ix0 `seq` ()
-- | One-dimensional index: a plain 'Int'.
type Ix1 = Int

-- | Identity pattern, so one-dimensional indices can be written as @Ix1 i@.
pattern Ix1 :: Int -> Ix1
pattern Ix1 n = n

-- | Dropping the only dimension of an 'Int' index leaves 'Ix0'.
type instance Lower Int = Ix0
-- | Base case of the dimension recursion: a one-dimensional index is a plain
-- 'Int'. Most methods are the identity or a single arithmetic expression.
-- Methods addressed by 'Dim' accept only dimension @1@ and throw
-- 'IndexDimensionException' for any other dimension.
instance Index Ix1 where
  type Dimensions Ix1 = 1
  dimensions _ = 1
  {-# INLINE [1] dimensions #-}
  totalElem (SafeSz n) = n
  {-# INLINE [1] totalElem #-}
  isSafeIndex (SafeSz len) !i = i >= 0 && len > i
  {-# INLINE [1] isSafeIndex #-}
  toLinearIndex _ i = i
  {-# INLINE [1] toLinearIndex #-}
  toLinearIndexAcc !acc len i = i + acc * len
  {-# INLINE [1] toLinearIndexAcc #-}
  fromLinearIndex _ k = k
  {-# INLINE [1] fromLinearIndex #-}
  fromLinearIndexAcc n k = quotRem k n
  {-# INLINE [1] fromLinearIndexAcc #-}
  repairIndex sz@(SafeSz len) !i rBelow rOver =
    if i < 0
      then rBelow sz i
      else if i < len then i else rOver sz i
  {-# INLINE [1] repairIndex #-}
  consDim i _lower = i
  {-# INLINE [1] consDim #-}
  unconsDim i = (i, Ix0)
  {-# INLINE [1] unconsDim #-}
  snocDim _lower i = i
  {-# INLINE [1] snocDim #-}
  unsnocDim i = (Ix0, i)
  {-# INLINE [1] unsnocDim #-}
  getDimM i d
    | d == 1 = pure i
    | otherwise = throwM $ IndexDimensionException i d
  {-# INLINE [1] getDimM #-}
  setDimM ix d i
    | d == 1 = pure i
    | otherwise = throwM $ IndexDimensionException ix d
  {-# INLINE [1] setDimM #-}
  pullOutDimM i d
    | d == 1 = pure (i, Ix0)
    | otherwise = throwM $ IndexDimensionException i d
  {-# INLINE [1] pullOutDimM #-}
  insertDimM Ix0 d i
    | d == 1 = pure i
    | otherwise = throwM $ IndexDimensionException Ix0 d
  {-# INLINE [1] insertDimM #-}
  pureIndex i = i
  {-# INLINE [1] pureIndex #-}
  liftIndex f = f
  {-# INLINE [1] liftIndex #-}
  liftIndex2 f = f
  {-# INLINE [1] liftIndex2 #-}
  foldlIndex f = f
  {-# INLINE [1] foldlIndex #-}
  iterM start end step cond = loopM start (\i -> cond i end) (+ step)
  {-# INLINE iterM #-}
  iterM_ start end step cond = loopM_ start (\i -> cond i end) (+ step)
  {-# INLINE iterM_ #-}
-- | Exceptions raised for invalid indices.
data IndexException where
  -- Carries an offending index; the name suggests a zero index was not
  -- allowed (TODO confirm against the throw sites).
  IndexZeroException :: Index ix => !ix -> IndexException
  -- The requested 'Dim' is not valid for the given index.
  IndexDimensionException :: (Show ix, Typeable ix) => !ix -> Dim -> IndexException
  -- The index is not safe for the given size.
  IndexOutOfBoundsException :: Index ix => !(Sz ix) -> !ix -> IndexException

instance Show IndexException where
  show exc =
    case exc of
      IndexZeroException ix -> "IndexZeroException: " ++ show ix
      IndexDimensionException ix dim ->
        "IndexDimensionException: " ++ show dim ++ " for " ++ show ix
      IndexOutOfBoundsException sz ix ->
        "IndexOutOfBoundsException: " ++
        showsPrec 1 ix " not safe for (" ++ show sz ++ ")"
  -- Parenthesise the whole message for any non-zero precedence.
  showsPrec prec exc = showParen (prec /= 0) (showString (show exc))

instance Exception IndexException
-- | Exceptions raised when a size is unsuitable for an operation.
data SizeException where
  -- Two sizes that were compared differ (rendered as "a vs b").
  SizeMismatchException :: Index ix => !(Sz ix) -> !(Sz ix) -> SizeException
  -- Two sizes, possibly of different index types, do not match; presumably
  -- their total element counts differ (TODO confirm at the throw sites).
  SizeElementsMismatchException :: (Index ix, Index ix') => !(Sz ix) -> !(Sz ix') -> SizeException
  -- The outer size (first field) is too small for the region described by
  -- the index (second field) and the size (third field).
  SizeSubregionException :: Index ix => !(Sz ix) -> !ix -> !(Sz ix) -> SizeException
  -- The size describes an empty array.
  SizeEmptyException :: Index ix => !(Sz ix) -> SizeException

instance Exception SizeException

instance Show SizeException where
  show (SizeMismatchException sz sz') =
    "SizeMismatchException: (" ++ show sz ++ ") vs (" ++ show sz' ++ ")"
  show (SizeElementsMismatchException sz sz') =
    "SizeElementsMismatchException: (" ++ show sz ++ ") vs (" ++ show sz' ++ ")"
  show (SizeSubregionException sz' ix sz) =
    -- Fixed typo in the message: "to small" -> "too small".
    "SizeSubregionException: (" ++
    show sz' ++ ") is too small for " ++ show ix ++ " (" ++ show sz ++ ")"
  show (SizeEmptyException sz) =
    "SizeEmptyException: (" ++ show sz ++ ") corresponds to an empty array"
  -- Parenthesise the whole message for any non-zero precedence.
  showsPrec 0 arr s = show arr ++ s
  showsPrec _ arr s = '(' : show arr ++ ")" ++ s
-- | Exceptions describing a dimension whose length is not what was expected.
data ShapeException
  = DimTooShortException !Sz1 !Sz1
    -- ^ Expected length, then the length actually encountered.
  | DimTooLongException
    -- ^ Carries no size information.
  deriving Eq

instance Show ShapeException where
  show exc =
    case exc of
      DimTooShortException expected actual ->
        concat
          [ "DimTooShortException: expected ("
          , show expected
          , "), got ("
          , show actual
          , ")"
          ]
      DimTooLongException -> "DimTooLongException"
  -- Parenthesise the whole message for any non-zero precedence.
  showsPrec prec exc = showParen (prec /= 0) (showString (show exc))

instance Exception ShapeException