{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Comfort.Shape (
C(..),
Indexed(..),
InvIndexed(..),
ZeroBased(..),
OneBased(..),
Range(..),
Shifted(..),
(:+:)(..),
) where
import Foreign.Storable
(Storable, sizeOf, alignment, poke, peek, pokeElemOff, peekElemOff)
import Foreign.Ptr (Ptr, castPtr)
import qualified GHC.Arr as Ix
import qualified Control.Monad.Trans.State as MS
import qualified Control.Monad.HT as Monad
import qualified Control.Applicative.Backwards as Back
import Control.Applicative (liftA2, liftA3)
import Text.Printf (printf)
import Data.Tuple.HT (mapFst, mapPair, swap)
-- | Shapes that know how many array elements they describe.
class C sh where
  -- | Number of elements described by the shape.
  size :: sh -> Int
  -- | Number of elements, where an instance may skip validation work.
  -- Falls back to 'size' when an instance does not override it.
  uncheckedSize :: sh -> Int
  uncheckedSize sh = size sh
-- | Shapes whose elements are addressed by an 'Index' and laid out at
-- consecutive integer offsets.
class C sh => Indexed sh where
  {-# MINIMAL indices, (sizeOffset|offset), inBounds #-}
  type Index sh :: *
  -- | Every index of the shape.  The instances in this module list them
  -- in increasing offset order.
  indices :: sh -> [Index sh]
  -- | Offset of an index within the flattened shape.  The instances in
  -- this module report indices outside the shape with 'error'.
  offset :: sh -> Index sh -> Int
  offset sh ix = snd (sizeOffset sh ix)
  -- | Like 'offset', but an instance may skip validation.
  -- Falls back to 'offset'.
  uncheckedOffset :: sh -> Index sh -> Int
  uncheckedOffset sh ix = offset sh ix
  -- | Whether the index lies inside the shape.
  inBounds :: sh -> Index sh -> Bool
  -- | The shape's 'size' paired with the index's 'offset'.
  sizeOffset :: sh -> Index sh -> (Int,Int)
  sizeOffset sh ix = (size sh, offset sh ix)
  -- | Unchecked counterpart of 'sizeOffset'.
  uncheckedSizeOffset :: sh -> Index sh -> (Int,Int)
  uncheckedSizeOffset sh ix = (uncheckedSize sh, uncheckedOffset sh ix)
-- | Shapes that can also map an offset back to the index stored there.
class Indexed sh => InvIndexed sh where
  -- | Index found at a given offset, i.e. the inverse of 'offset'.
  -- The instances in this module raise an error for offsets outside
  -- the shape.
  indexFromOffset :: sh -> Int -> Index sh
  -- | Like 'indexFromOffset', but an instance may skip validation.
  -- Falls back to 'indexFromOffset'.
  uncheckedIndexFromOffset :: sh -> Int -> Index sh
  uncheckedIndexFromOffset sh k = indexFromOffset sh k
-- | Abort with the shared message for an offset that 'indexFromOffset'
-- cannot map.  The first argument names the shape type in the message.
errorIndexFromOffset :: String -> Int -> a
errorIndexFromOffset shapeName badOffset =
  error (printf "indexFromOffset (%s): index %d out of range" shapeName badOffset)
-- | The unit shape describes exactly one element.
instance C () where
  size () = 1
  uncheckedSize () = size ()
-- | The unit shape has the single index @()@, stored at offset 0.
instance Indexed () where
  type Index () = ()
  indices () = pure ()
  offset () () = 0
  uncheckedOffset sh ix = offset sh ix
  inBounds () () = True
-- | Only offset 0 is valid for the unit shape.
instance InvIndexed () where
  indexFromOffset () k
    | k == 0 = ()
    | otherwise = errorIndexFromOffset "()" k
  uncheckedIndexFromOffset () _ = ()
-- | Shape of 'zeroBasedSize' elements, indexed from zero.
newtype ZeroBased n =
  ZeroBased {zeroBasedSize :: n}
  deriving (Eq, Show)
-- | The element count is the stored bound, converted to 'Int'.
instance (Integral n) => C (ZeroBased n) where
  size = fromIntegral . zeroBasedSize
  uncheckedSize = size
-- | Valid indices are @0 .. len-1@.  Enumeration and checked offsets
-- reuse the 'Shifted' shape with offset 0.
instance (Integral n) => Indexed (ZeroBased n) where
  type Index (ZeroBased n) = n
  indices (ZeroBased bound) = indices (Shifted 0 bound)
  offset (ZeroBased bound) i = offset (Shifted 0 bound) i
  uncheckedOffset _sh i = fromIntegral i
  inBounds (ZeroBased bound) i = 0 <= i && i < bound
-- | Offset @k@ maps to index @k@ when @0 <= k < len@.
instance (Integral n) => InvIndexed (ZeroBased n) where
  -- The offset is validated before it is trusted as an @n@.
  -- 'fromIntegral' silently wraps on narrow index types (e.g. an offset
  -- of 256 becomes 0 in Word8), so the conversion must round-trip.
  -- Testing @0<=k0@ on the Int first also avoids converting a negative
  -- offset into an unsigned type that rejects it (e.g. Natural).
  -- 'asTypeOf' pins @k@ to the index type even without the
  -- monomorphism restriction.
  indexFromOffset (ZeroBased len) k0 =
    let k = fromIntegral k0 `asTypeOf` len
    in  if 0<=k0 && fromIntegral k == k0 && k<len
          then k
          else errorIndexFromOffset "ZeroBased" k0
  uncheckedIndexFromOffset _ k = fromIntegral k
-- | Shape of 'oneBasedSize' elements, indexed from one.
newtype OneBased n =
  OneBased {oneBasedSize :: n}
  deriving (Eq, Show)
-- | The element count is the stored bound, converted to 'Int'.
instance (Integral n) => C (OneBased n) where
  size = fromIntegral . oneBasedSize
  uncheckedSize = size
-- | Valid indices are @1 .. len@.  Enumeration and checked offsets reuse
-- the 'Shifted' shape with offset 1.
instance (Integral n) => Indexed (OneBased n) where
  type Index (OneBased n) = n
  indices (OneBased bound) = indices (Shifted 1 bound)
  offset (OneBased bound) i = offset (Shifted 1 bound) i
  -- Subtract in Int so that unsigned index types do not underflow.
  uncheckedOffset _sh i = subtract 1 (fromIntegral i)
  inBounds (OneBased bound) i = 0 < i && i <= bound
-- | Offset @k@ maps to index @k+1@ when @0 <= k < len@.
instance (Integral n) => InvIndexed (OneBased n) where
  -- The offset is validated before it is trusted as an @n@.
  -- 'fromIntegral' silently wraps on narrow index types, so the
  -- conversion must round-trip.  Testing @0<=k0@ on the Int first also
  -- avoids converting a negative offset into an unsigned type that
  -- rejects it (e.g. Natural).
  -- 'asTypeOf' pins @k@ to the index type even without the
  -- monomorphism restriction.
  indexFromOffset (OneBased len) k0 =
    let k = fromIntegral k0 `asTypeOf` len
    in  if 0<=k0 && fromIntegral k == k0 && k<len
          then 1+k
          else errorIndexFromOffset "OneBased" k0
  uncheckedIndexFromOffset _ k = 1 + fromIntegral k
-- | Inclusive index range from 'rangeFrom' to 'rangeTo', interpreted
-- through the 'Ix.Ix' class.
data Range n =
  Range {rangeFrom, rangeTo :: n}
  deriving (Eq, Show)
-- | Element count of the range as reported by 'Ix.rangeSize'; the
-- unchecked variant uses 'Ix.unsafeRangeSize' instead.
instance (Ix.Ix n) => C (Range n) where
  size (Range lo hi) = Ix.rangeSize (lo, hi)
  uncheckedSize (Range lo hi) = Ix.unsafeRangeSize (lo, hi)
-- | Every operation delegates to the 'Ix.Ix' method of the same purpose
-- on the bounds pair.
instance (Ix.Ix n) => Indexed (Range n) where
  type Index (Range n) = n
  indices (Range lo hi) = Ix.range (lo, hi)
  offset (Range lo hi) i = Ix.index (lo, hi) i
  uncheckedOffset (Range lo hi) i = Ix.unsafeIndex (lo, hi) i
  inBounds (Range lo hi) i = Ix.inRange (lo, hi) i
-- | Offset @k@ selects the @k@-th element of 'Ix.range'.
instance (Ix.Ix n) => InvIndexed (Range n) where
  indexFromOffset (Range lo hi) k
    | 0 <= k && k < Ix.rangeSize bounds = Ix.range bounds !! k
    | otherwise = errorIndexFromOffset "Range" k
    where bounds = (lo, hi)
  uncheckedIndexFromOffset (Range lo hi) k = Ix.range (lo, hi) !! k
-- | Both bounds are stored as two consecutive @n@ elements starting at
-- the record's address.  'sizeOf' pads the first field up to the
-- alignment of the second.
instance Storable n => Storable (Range n) where
  {-# INLINE sizeOf #-}
  {-# INLINE alignment #-}
  {-# INLINE peek #-}
  {-# INLINE poke #-}
  -- Lazy patterns keep 'sizeOf' and 'alignment' usable on 'undefined'.
  sizeOf ~(Range lo hi) =
    let loSize = sizeOf lo
        padding = mod (negate loSize) (alignment hi)
    in  loSize + padding + sizeOf hi
  alignment ~(Range lo _) = alignment lo
  poke ptr (Range lo hi) = do
    let elemPtr = castToElemPtr ptr
    poke elemPtr lo
    pokeElemOff elemPtr 1 hi
  peek ptr =
    let elemPtr = castToElemPtr ptr
    in  liftA2 Range (peek elemPtr) (peekElemOff elemPtr 1)
-- | Shape of 'shiftedSize' elements whose indices start at
-- 'shiftedOffset'.
data Shifted n =
  Shifted {shiftedOffset, shiftedSize :: n}
  deriving (Eq, Show)
-- | The element count is the stored length; the offset is irrelevant.
instance (Integral n) => C (Shifted n) where
  size (Shifted _ count) = fromIntegral count
  uncheckedSize (Shifted _ count) = fromIntegral count
-- | Valid indices are @offs .. offs+len-1@; index @offs@ has offset 0.
instance (Integral n) => Indexed (Shifted n) where
  type Index (Shifted n) = n
  -- Count down the remaining length alongside the rising index, so the
  -- enumeration stops after @len@ elements.
  indices (Shifted offs len) =
    map snd $
    takeWhile ((>0) . fst) $
    zip
      (iterate (subtract 1) len)
      (iterate (1+) offs)
  offset (Shifted offs len) ix =
    if ix<offs
      then error "Shape.Shifted: array index too small"
      else
        let k = ix-offs
        in  if k<len
              then fromIntegral k
              else error "Shape.Shifted: array index too big"
  uncheckedOffset (Shifted offs _len) ix = fromIntegral $ ix-offs
  -- Both bounds must be checked.  Previously only the upper bound was
  -- tested, so indices below @offs@ were wrongly reported as in bounds.
  -- The test mirrors 'offset': an index is in bounds exactly when
  -- 'offset' would not fail.
  inBounds (Shifted offs len) ix = offs<=ix && ix-offs < len
-- | Offset @k@ maps to index @offs+k@ when @0 <= k < len@.
instance (Integral n) => InvIndexed (Shifted n) where
  -- The offset is validated before it is trusted as an @n@.
  -- 'fromIntegral' silently wraps on narrow index types, so the
  -- conversion must round-trip.  Testing @0<=k0@ on the Int first also
  -- avoids converting a negative offset into an unsigned type that
  -- rejects it (e.g. Natural).
  -- 'asTypeOf' pins @k@ to the index type even without the
  -- monomorphism restriction.
  indexFromOffset (Shifted offs len) k0 =
    let k = fromIntegral k0 `asTypeOf` len
    in  if 0<=k0 && fromIntegral k == k0 && k<len
          then offs+k
          else errorIndexFromOffset "Shifted" k0
  uncheckedIndexFromOffset (Shifted offs _len) k = offs + fromIntegral k
-- | Offset and length are stored as two consecutive @n@ elements
-- starting at the record's address.  'sizeOf' pads the first field up
-- to the alignment of the second.
instance Storable n => Storable (Shifted n) where
  {-# INLINE sizeOf #-}
  {-# INLINE alignment #-}
  {-# INLINE peek #-}
  {-# INLINE poke #-}
  -- Lazy patterns keep 'sizeOf' and 'alignment' usable on 'undefined'.
  sizeOf ~(Shifted start count) =
    let startSize = sizeOf start
        padding = mod (negate startSize) (alignment count)
    in  startSize + padding + sizeOf count
  alignment ~(Shifted start _) = alignment start
  poke ptr (Shifted start count) = do
    let elemPtr = castToElemPtr ptr
    poke elemPtr start
    pokeElemOff elemPtr 1 count
  peek ptr =
    let elemPtr = castToElemPtr ptr
    in  liftA2 Shifted (peek elemPtr) (peekElemOff elemPtr 1)
{-# INLINE castToElemPtr #-}
-- | Reinterpret a pointer to a container @f a@ as a pointer to @a@
-- without changing the address.  The 'Storable' instances use it to
-- access fields as consecutive @a@ elements.
castToElemPtr :: Ptr (f a) -> Ptr a
castToElemPtr containerPtr = castPtr containerPtr
-- | A pair of shapes describes the product of their element counts.
instance (C sh0, C sh1) => C (sh0,sh1) where
  size (outer,inner) = size outer * size inner
  uncheckedSize (outer,inner) = uncheckedSize outer * uncheckedSize inner
-- | Two-dimensional shapes.  The second component varies fastest: the
-- first component's offset is scaled by the second component's size
-- before the second offset is added.
instance (Indexed sh0, Indexed sh1) => Indexed (sh0,sh1) where
  type Index (sh0,sh1) = (Index sh0, Index sh1)
  indices (outer,inner) =
    [(i, j) | i <- indices outer, j <- indices inner]
  offset (outer,inner) (i,j) =
    combineOffset (offset outer i) (sizeOffset inner j)
  uncheckedOffset (outer,inner) (i,j) =
    combineOffset (uncheckedOffset outer i) (uncheckedSizeOffset inner j)
  sizeOffset (outer,inner) (i,j) =
    combineSizeOffset (sizeOffset outer i) (sizeOffset inner j)
  uncheckedSizeOffset (outer,inner) (i,j) =
    combineSizeOffset
      (uncheckedSizeOffset outer i)
      (uncheckedSizeOffset inner j)
  inBounds (outer,inner) (i,j) = inBounds outer i && inBounds inner j
-- | Decompose an offset right to left.  The second component takes the
-- remainder modulo its size; the quotient goes to the first component.
-- Only the first component's conversion is checked in 'indexFromOffset'.
instance (InvIndexed sh0, InvIndexed sh1) => InvIndexed (sh0,sh1) where
  indexFromOffset (outer,inner) k =
    runInvIndex k $ (,) <$> pickLastIndex outer <*> pickIndex inner
  uncheckedIndexFromOffset (outer,inner) k =
    runInvIndex k $ (,) <$> uncheckedPickLastIndex outer <*> pickIndex inner
-- | A triple of shapes describes the product of their element counts.
instance (C sh0, C sh1, C sh2) => C (sh0,sh1,sh2) where
  size (outer,middle,inner) = size outer * size middle * size inner
  uncheckedSize (outer,middle,inner) =
    uncheckedSize outer * uncheckedSize middle * uncheckedSize inner
-- | Three-dimensional shapes.  The last component varies fastest.
-- Offsets are folded from the inside out: the two inner (size, offset)
-- pairs are combined first, then the outer offset is scaled by their
-- combined size.
instance (Indexed sh0, Indexed sh1, Indexed sh2) => Indexed (sh0,sh1,sh2) where
  type Index (sh0,sh1,sh2) = (Index sh0, Index sh1, Index sh2)
  indices (sh0,sh1,sh2) =
    Monad.lift3 (,,) (indices sh0) (indices sh1) (indices sh2)
  -- Explicit 'offset', mirroring the pair instance, instead of the class
  -- default @snd . sizeOffset@, which would also build the outer size.
  offset (sh0,sh1,sh2) (ix0,ix1,ix2) =
    offset sh0 ix0
    `combineOffset`
    sizeOffset sh1 ix1
    `combineSizeOffset`
    sizeOffset sh2 ix2
  uncheckedOffset (sh0,sh1,sh2) (ix0,ix1,ix2) =
    uncheckedOffset sh0 ix0
    `combineOffset`
    uncheckedSizeOffset sh1 ix1
    `combineSizeOffset`
    uncheckedSizeOffset sh2 ix2
  sizeOffset (sh0,sh1,sh2) (ix0,ix1,ix2) =
    sizeOffset sh0 ix0
    `combineSizeOffset`
    sizeOffset sh1 ix1
    `combineSizeOffset`
    sizeOffset sh2 ix2
  uncheckedSizeOffset (sh0,sh1,sh2) (ix0,ix1,ix2) =
    uncheckedSizeOffset sh0 ix0
    `combineSizeOffset`
    uncheckedSizeOffset sh1 ix1
    `combineSizeOffset`
    uncheckedSizeOffset sh2 ix2
  inBounds (sh0,sh1,sh2) (ix0,ix1,ix2) =
    inBounds sh0 ix0 && inBounds sh1 ix1 && inBounds sh2 ix2
-- | Decompose an offset right to left.  The last and middle components
-- each take a remainder; the final quotient goes to the first component,
-- the only one checked in 'indexFromOffset'.
instance
  (InvIndexed sh0, InvIndexed sh1, InvIndexed sh2) =>
    InvIndexed (sh0,sh1,sh2) where
  indexFromOffset (outer,middle,inner) k =
    runInvIndex k $
      (,,) <$> pickLastIndex outer <*> pickIndex middle <*> pickIndex inner
  uncheckedIndexFromOffset (outer,middle,inner) k =
    runInvIndex k $
      (,,)
        <$> uncheckedPickLastIndex outer
        <*> pickIndex middle
        <*> pickIndex inner
-- | Run an offset-decomposition action from the starting offset and
-- return its result.  The action is built in 'Back.Backwards', whose
-- applicative combination runs effects right to left.
runInvIndex :: s -> Back.Backwards (MS.State s) a -> a
runInvIndex start action = MS.evalState (Back.forwards action) start
-- | Turn the whole remaining offset into an index of the given shape via
-- the checked 'indexFromOffset', leaving the state untouched.  Written
-- leftmost in a 'Back.Backwards' chain, it runs last.
pickLastIndex ::
  (InvIndexed sh) => sh -> Back.Backwards (MS.State Int) (Index sh)
pickLastIndex sh = Back.Backwards (MS.gets (indexFromOffset sh))
-- | Like 'pickLastIndex', but converts with 'uncheckedIndexFromOffset'.
uncheckedPickLastIndex ::
  (InvIndexed sh) => sh -> Back.Backwards (MS.State Int) (Index sh)
uncheckedPickLastIndex sh =
  Back.Backwards (MS.gets (uncheckedIndexFromOffset sh))
-- | Split the remaining offset by the shape's 'size'.  The remainder is
-- converted (unchecked) into this component's index; the quotient
-- becomes the state seen by the components further left.
pickIndex :: (InvIndexed sh) => sh -> Back.Backwards (MS.State Int) (Index sh)
pickIndex sh =
  fmap (uncheckedIndexFromOffset sh) . Back.Backwards . MS.state $
    \k -> swap (divMod k (size sh))
-- Both combinators share precedence 7 and associate to the right, so
--   o0 `combineOffset` so1 `combineSizeOffset` so2
-- groups as
--   o0 `combineOffset` (so1 `combineSizeOffset` so2)
-- i.e. inner (size, offset) pairs are merged before the outer offset
-- is folded in, as the tuple instances rely on.
infixr 7 `combineOffset`, `combineSizeOffset`
{-# INLINE combineOffset #-}
-- | Fold an outer offset into the (size, offset) of the dimensions
-- nested inside it, so that the inner dimensions vary fastest.
combineOffset :: Num a => a -> (a, a) -> a
combineOffset outerOffset (innerSize, innerOffset) =
  outerOffset * innerSize + innerOffset
{-# INLINE combineSizeOffset #-}
-- | Merge an outer (size, offset) with an inner one: sizes multiply, and
-- the outer offset is scaled by the inner size before adding the inner
-- offset.
combineSizeOffset :: Num a => (a, a) -> (a, a) -> (a, a)
combineSizeOffset (outerSize, outerOffset) (innerSize, innerOffset) =
  let totalSize = outerSize * innerSize
      totalOffset = outerOffset * innerSize + innerOffset
  in  (totalSize, totalOffset)
-- | Concatenation of two shapes: the left shape's elements come first,
-- followed by the right shape's, and indices are tagged with 'Either'.
infixr 5 :+:
data sh0 :+: sh1 = sh0 :+: sh1
  deriving (Eq, Show)
-- | A concatenated shape holds the elements of both parts.
instance (C sh0, C sh1) => C (sh0:+:sh1) where
  size (left:+:right) = size left + size right
  uncheckedSize (left:+:right) = uncheckedSize left + uncheckedSize right
-- | 'Left' indices keep their offsets in the left shape; 'Right' indices
-- are shifted past all elements of the left shape.
instance (Indexed sh0, Indexed sh1) => Indexed (sh0:+:sh1) where
  type Index (sh0:+:sh1) = Either (Index sh0) (Index sh1)
  indices (sh0:+:sh1) = fmap Left (indices sh0) ++ fmap Right (indices sh1)
  offset (sh0:+:sh1) ix =
    either
      (offset sh0)
      (\ix1 -> size sh0 + offset sh1 ix1)
      ix
  uncheckedOffset (sh0:+:sh1) ix =
    either
      (uncheckedOffset sh0)
      (\ix1 -> uncheckedSize sh0 + uncheckedOffset sh1 ix1)
      ix
  sizeOffset (sh0:+:sh1) ix =
    either
      (mapFst (+ size sh1) . sizeOffset sh0)
      (shiftBoth (size sh0) . sizeOffset sh1)
      ix
    where shiftBoth s = mapPair ((s+), (s+))
  uncheckedSizeOffset (sh0:+:sh1) ix =
    either
      (mapFst (+ uncheckedSize sh1) . uncheckedSizeOffset sh0)
      (shiftBoth (uncheckedSize sh0) . uncheckedSizeOffset sh1)
      ix
    where shiftBoth s = mapPair ((s+), (s+))
  inBounds (sh0:+:sh1) = either (inBounds sh0) (inBounds sh1)
-- | Offsets below the left shape's size map to 'Left'; larger offsets
-- map to 'Right' after subtracting that size.
instance (InvIndexed sh0, InvIndexed sh1) => InvIndexed (sh0:+:sh1) where
  indexFromOffset (sh0:+:sh1) k =
    let pivot = size sh0
    in  if k < pivot
          then Left $ indexFromOffset sh0 k
          else Right $ indexFromOffset sh1 $ k-pivot
  -- The unchecked path uses 'uncheckedSize' for the pivot, matching
  -- 'uncheckedOffset' of the 'Indexed' instance, instead of the
  -- checked 'size'.
  uncheckedIndexFromOffset (sh0:+:sh1) k =
    let pivot = uncheckedSize sh0
    in  if k < pivot
          then Left $ uncheckedIndexFromOffset sh0 k
          else Right $ uncheckedIndexFromOffset sh1 $ k-pivot