{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Comfort.Shape (
C(..),
ZeroBased(..),
OneBased(..),
Range(..),
Shifted(..),
(:+:)(..),
) where
import Foreign.Storable
(Storable, sizeOf, alignment, poke, peek, pokeElemOff, peekElemOff)
import Foreign.Ptr (Ptr, castPtr)
import qualified GHC.Arr as Ix
import qualified Control.Monad.HT as Monad
import Data.Tuple.HT (mapFst, mapPair)
-- | Class of array shapes.
--
-- A shape @sh@ fixes an index type @'Index' sh@, can enumerate its indices,
-- test an index for membership and map an index to an 'Int' offset.
-- Every @unchecked...@ method defaults to its checked counterpart;
-- an instance may replace it by a variant that omits range tests.
class C sh where
  -- 'offset' and 'sizeOffset' are defined in terms of each other,
  -- hence an instance has to provide at least one of them.
  {-# MINIMAL indices, size, (sizeOffset|offset), inBounds #-}

  -- | Type of the indices that address positions in a shape of type @sh@.
  type Index sh :: *

  -- | List of the indices of the shape.
  indices :: sh -> [Index sh]

  -- | Tell whether the index belongs to the shape.
  inBounds :: sh -> Index sh -> Bool

  -- | Size of the shape as an 'Int'.
  size :: sh -> Int

  -- | Like 'size'; by default it is just 'size'.
  uncheckedSize :: sh -> Int
  uncheckedSize = size

  -- | 'Int' offset that belongs to the index.
  offset :: sh -> Index sh -> Int
  -- default: second component of 'sizeOffset'
  offset sh ix = snd (sizeOffset sh ix)

  -- | Like 'offset'; by default it is just 'offset'.
  uncheckedOffset :: sh -> Index sh -> Int
  uncheckedOffset = offset

  -- | Pair of 'size' and 'offset'.
  sizeOffset :: sh -> Index sh -> (Int,Int)
  -- default: compute both components separately
  sizeOffset sh ix = (size sh, offset sh ix)

  -- | Pair of 'uncheckedSize' and 'uncheckedOffset'.
  uncheckedSizeOffset :: sh -> Index sh -> (Int,Int)
  uncheckedSizeOffset sh ix = (uncheckedSize sh, uncheckedOffset sh ix)
-- | The unit shape: its only index is @()@, which sits at offset 0.
instance C () where
  type Index () = ()
  indices () = [()]
  inBounds () () = True
  -- exactly one element, in the checked and the unchecked variant
  size () = 1
  uncheckedSize () = 1
  offset () () = 0
  uncheckedOffset () () = 0
-- | Shape that consists of a length only.
-- Its 'C' instance accepts the indices from @0@ up to,
-- but excluding, 'zeroBasedSize'.
newtype ZeroBased n = ZeroBased
  { zeroBasedSize :: n
  }
  deriving (Eq, Show)
-- | Indices run from @0@ to @zeroBasedSize-1@.
-- Enumeration and the checked 'offset' are delegated
-- to the 'Shifted' shape with start index 0 and the same length.
instance (Integral n) => C (ZeroBased n) where
  type Index (ZeroBased n) = n
  indices shape = indices (Shifted 0 (zeroBasedSize shape))
  inBounds shape ix = 0<=ix && ix < zeroBasedSize shape
  size shape = fromIntegral (zeroBasedSize shape)
  uncheckedSize shape = fromIntegral (zeroBasedSize shape)
  offset shape = offset (Shifted 0 (zeroBasedSize shape))
  -- the index itself is the offset; no range test is performed
  uncheckedOffset _shape ix = fromIntegral ix
-- | Shape that consists of a length only.
-- Its 'C' instance accepts the indices from @1@ up to,
-- and including, 'oneBasedSize'.
newtype OneBased n = OneBased
  { oneBasedSize :: n
  }
  deriving (Eq, Show)
-- | Indices run from @1@ to @oneBasedSize@.
-- Enumeration and the checked 'offset' are delegated
-- to the 'Shifted' shape with start index 1 and the same length.
instance (Integral n) => C (OneBased n) where
  type Index (OneBased n) = n
  indices shape = indices (Shifted 1 (oneBasedSize shape))
  inBounds shape ix = 0<ix && ix <= oneBasedSize shape
  size shape = fromIntegral (oneBasedSize shape)
  uncheckedSize shape = fromIntegral (oneBasedSize shape)
  offset shape = offset (Shifted 1 (oneBasedSize shape))
  -- index 1 is mapped to offset 0; no range test is performed
  uncheckedOffset _shape ix = subtract 1 (fromIntegral ix)
-- | Shape given by a pair of bounds.
-- Its 'C' instance passes @(rangeFrom, rangeTo)@
-- as the bounds pair to the methods of the 'Ix.Ix' class.
data Range n = Range
  { rangeFrom :: n
  , rangeTo :: n
  }
  deriving (Eq, Show)
-- | Every method forwards to the corresponding 'Ix.Ix' function,
-- called with the pair @(rangeFrom, rangeTo)@ as bounds.
instance (Ix.Ix n) => C (Range n) where
  type Index (Range n) = n
  indices (Range lower upper) = Ix.range (lower, upper)
  inBounds (Range lower upper) ix = Ix.inRange (lower, upper) ix
  size (Range lower upper) = Ix.rangeSize (lower, upper)
  offset (Range lower upper) ix = Ix.index (lower, upper) ix
  -- the unchecked variants use the @unsafe...@ functions of "GHC.Arr"
  uncheckedSize (Range lower upper) = Ix.unsafeRangeSize (lower, upper)
  uncheckedOffset (Range lower upper) ix = Ix.unsafeIndex (lower, upper) ix
-- | Both bounds are stored one after the other
-- as elements 0 and 1 of an array of @n@.
instance Storable n => Storable (Range n) where
  {-# INLINE sizeOf #-}
  {-# INLINE alignment #-}
  {-# INLINE peek #-}
  {-# INLINE poke #-}
  -- The patterns are lazy, because 'sizeOf' and 'alignment'
  -- must not inspect their argument.
  sizeOf ~(Range lower upper) = lowerSize + padding + sizeOf upper
    where
      lowerSize = sizeOf lower
      -- number of bytes that round lowerSize up
      -- to a multiple of the alignment of the second field
      padding = mod (- lowerSize) (alignment upper)
  alignment ~(Range lower _) = alignment lower
  poke ptr (Range lower upper) = do
    poke elemPtr lower
    pokeElemOff elemPtr 1 upper
    where
      elemPtr = castToElemPtr ptr
  peek ptr = Monad.lift2 Range (peek elemPtr) (peekElemOff elemPtr 1)
    where
      elemPtr = castToElemPtr ptr
-- | Shape given by a start index and a length.
-- Its 'C' instance enumerates 'shiftedSize' consecutive indices
-- beginning with 'shiftedOffset'.
data Shifted n = Shifted
  { shiftedOffset :: n
  , shiftedSize :: n
  }
  deriving (Eq, Show)
-- | The valid indices are @offs, offs+1, ..., offs+len-1@
-- for a shape @Shifted offs len@.
instance (Integral n) => C (Shifted n) where
  type Index (Shifted n) = n

  -- Count the remaining length down while counting the index up;
  -- this yields the empty list for a non-positive length.
  indices (Shifted offs len) =
    map snd $
    takeWhile ((>0) . fst) $
    zip
      (iterate (subtract 1) len)
      (iterate (1+) offs)

  -- Checked offset: calls 'error' for an index outside of the shape.
  -- The difference k is only computed after ix>=offs is established.
  offset (Shifted offs len) ix
    | ix < offs = error "Shape.Shifted: array index too small"
    | k < len = fromIntegral k
    | otherwise = error "Shape.Shifted: array index too big"
    where
      k = ix - offs

  -- No range test at all.
  uncheckedOffset (Shifted offs _len) ix = fromIntegral $ ix-offs

  -- Both bounds must be tested.
  -- Checking only the upper bound would accept every index below offs,
  -- although 'offset' rejects it as too small.
  inBounds (Shifted offs len) ix = offs <= ix && ix < offs+len

  size (Shifted _offs len) = fromIntegral len
  uncheckedSize (Shifted _offs len) = fromIntegral len
-- | Start index and length are stored one after the other
-- as elements 0 and 1 of an array of @n@.
instance Storable n => Storable (Shifted n) where
  {-# INLINE sizeOf #-}
  {-# INLINE alignment #-}
  {-# INLINE peek #-}
  {-# INLINE poke #-}
  -- The patterns are lazy, because 'sizeOf' and 'alignment'
  -- must not inspect their argument.
  sizeOf ~(Shifted start len) = startSize + padding + sizeOf len
    where
      startSize = sizeOf start
      -- number of bytes that round startSize up
      -- to a multiple of the alignment of the second field
      padding = mod (- startSize) (alignment len)
  alignment ~(Shifted start _) = alignment start
  poke ptr (Shifted start len) = do
    poke elemPtr start
    pokeElemOff elemPtr 1 len
    where
      elemPtr = castToElemPtr ptr
  peek ptr = Monad.lift2 Shifted (peek elemPtr) (peekElemOff elemPtr 1)
    where
      elemPtr = castToElemPtr ptr
-- | Turn a pointer to a container @f a@ into a pointer to its element type.
-- This is 'castPtr' with a more restrictive type;
-- the address is not changed.
{-# INLINE castToElemPtr #-}
castToElemPtr :: Ptr (f a) -> Ptr a
castToElemPtr containerPtr = castPtr containerPtr
-- | Pair of shapes, indexed by a pair of indices.
-- The offset of the first component is multiplied
-- by the size of the second shape, then the second offset is added.
instance (C sh0, C sh1) => C (sh0,sh1) where
  type Index (sh0,sh1) = (Index sh0, Index sh1)

  indices (outer,inner) = Monad.lift2 (,) (indices outer) (indices inner)

  inBounds (outer,inner) (i,j) = inBounds outer i && inBounds inner j

  size (outer,inner) = size outer * size inner
  uncheckedSize (outer,inner) = uncheckedSize outer * uncheckedSize inner

  -- The size of the first shape is not needed for a plain offset.
  offset (outer,inner) (i,j) =
    combineOffset (offset outer i) (sizeOffset inner j)
  uncheckedOffset (outer,inner) (i,j) =
    combineOffset (uncheckedOffset outer i) (uncheckedSizeOffset inner j)

  sizeOffset (outer,inner) (i,j) =
    combineSizeOffset (sizeOffset outer i) (sizeOffset inner j)
  uncheckedSizeOffset (outer,inner) (i,j) =
    combineSizeOffset
      (uncheckedSizeOffset outer i)
      (uncheckedSizeOffset inner j)
-- | Triple of shapes, indexed by a triple of indices.
-- The components are combined from right to left:
-- the second and third component are merged first,
-- then the first component is put in front of the result.
-- 'offset' is not defined here, so the class default applies.
instance (C sh0, C sh1, C sh2) => C (sh0,sh1,sh2) where
  type Index (sh0,sh1,sh2) = (Index sh0, Index sh1, Index sh2)

  indices (shA,shB,shC) =
    Monad.lift3 (,,) (indices shA) (indices shB) (indices shC)

  inBounds (shA,shB,shC) (i,j,k) =
    inBounds shA i && inBounds shB j && inBounds shC k

  size (shA,shB,shC) = size shA * size shB * size shC
  uncheckedSize (shA,shB,shC) =
    uncheckedSize shA * uncheckedSize shB * uncheckedSize shC

  -- The nesting is written out explicitly; it is the same grouping
  -- that the right-associative infix notation produces.
  uncheckedOffset (shA,shB,shC) (i,j,k) =
    combineOffset
      (uncheckedOffset shA i)
      (combineSizeOffset
        (uncheckedSizeOffset shB j)
        (uncheckedSizeOffset shC k))

  sizeOffset (shA,shB,shC) (i,j,k) =
    combineSizeOffset
      (sizeOffset shA i)
      (combineSizeOffset
        (sizeOffset shB j)
        (sizeOffset shC k))

  uncheckedSizeOffset (shA,shB,shC) (i,j,k) =
    combineSizeOffset
      (uncheckedSizeOffset shA i)
      (combineSizeOffset
        (uncheckedSizeOffset shB j)
        (uncheckedSizeOffset shC k))
-- Both combinators associate to the right, so that in a chain
-- the rightmost components are merged first.
infixr 7 `combineOffset`, `combineSizeOffset`

-- | Put an outer offset in front of an inner (size, offset) pair:
-- the outer offset is multiplied by the inner size
-- and the inner offset is added.
{-# INLINE combineOffset #-}
combineOffset :: Num a => a -> (a, a) -> a
combineOffset outerOffset (innerSize,innerOffset) =
  outerOffset * innerSize + innerOffset

-- | Merge two (size, offset) pairs into one:
-- the sizes are multiplied and the offsets are combined
-- as in 'combineOffset'.
{-# INLINE combineSizeOffset #-}
combineSizeOffset :: Num a => (a, a) -> (a, a) -> (a, a)
combineSizeOffset (outerSize,outerOffset) inner@(innerSize,_) =
  (outerSize*innerSize, combineOffset outerOffset inner)
infixr 5 :+:

-- | Two shapes placed side by side.
-- The 'C' instance indexes the left shape with 'Left'
-- and the right shape with 'Right'.
data sh0 :+: sh1 = sh0 :+: sh1
  deriving (Eq, Show)
-- | Concatenation of two shapes.
-- The elements of the left shape come first;
-- the offsets of the right shape are shifted by the size of the left one.
instance (C sh0, C sh1) => C (sh0:+:sh1) where
  type Index (sh0:+:sh1) = Either (Index sh0) (Index sh1)

  indices (left:+:right) =
    map Left (indices left) ++ map Right (indices right)

  inBounds (left:+:_) (Left ix) = inBounds left ix
  inBounds (_:+:right) (Right ix) = inBounds right ix

  size (left:+:right) = size left + size right
  uncheckedSize (left:+:right) = uncheckedSize left + uncheckedSize right

  offset (left:+:_) (Left ix) = offset left ix
  offset (left:+:right) (Right ix) = size left + offset right ix

  uncheckedOffset (left:+:_) (Left ix) = uncheckedOffset left ix
  uncheckedOffset (left:+:right) (Right ix) =
    uncheckedSize left + uncheckedOffset right ix

  -- Left index: only the size needs to be extended by the right shape.
  sizeOffset (left:+:right) (Left ix) =
    mapFst (+ size right) $ sizeOffset left ix
  -- Right index: size and offset are both shifted by the left size.
  sizeOffset (left:+:right) (Right ix) =
    mapPair ((leftSize+), (leftSize+)) $ sizeOffset right ix
    where
      leftSize = size left

  uncheckedSizeOffset (left:+:right) (Left ix) =
    mapFst (+ uncheckedSize right) $ uncheckedSizeOffset left ix
  uncheckedSizeOffset (left:+:right) (Right ix) =
    mapPair ((leftSize+), (leftSize+)) $ uncheckedSizeOffset right ix
    where
      leftSize = uncheckedSize left