{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
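{-
Laws that every instance of 'C' is expected to satisfy
(candidates for property tests),
for every shape @sh@ and every valid index @ix@ of @sh@:
   map (offset sh) (indices sh) = [0 .. size sh - 1]
   length (indices sh) = size sh
   sizeOffset sh ix = (size sh, offset sh ix)
-}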
module Data.Array.Comfort.Shape (
   C(..),

   ZeroBased(..),
   OneBased(..),

   Range(..),
   Shifted(..),
   (:+:)(..),
   ) where

import Foreign.Storable
         (Storable, sizeOf, alignment, poke, peek, pokeElemOff, peekElemOff)
import Foreign.Ptr (Ptr, castPtr)

import qualified GHC.Arr as Ix

import qualified Control.Monad.HT as Monad

import Data.Tuple.HT (mapFst, mapPair)


-- | Class of array shapes.
--
-- A shape enumerates its valid indices ('indices'), counts them ('size'),
-- tests membership ('inBounds') and assigns every index an 'Int' offset
-- ('offset').  The @unchecked*@ methods default to their checked
-- counterparts; instances may override them.
class C sh where
   {-# MINIMAL indices, size, (sizeOffset|offset), inBounds #-}
   -- | Type of the indices of a shape.
   type Index sh :: *

   -- | All indices of the shape, cf. 'Ix.range'.
   indices :: sh -> [Index sh]

   -- | Offset of an index, cf. 'Ix.index'.
   -- Defaults to the second component of 'sizeOffset'.
   offset :: sh -> Index sh -> Int
   offset sh = snd . sizeOffset sh

   -- | Offset without bounds checks, cf. 'Ix.unsafeIndex'.
   uncheckedOffset :: sh -> Index sh -> Int
   uncheckedOffset = offset

   -- | Whether an index belongs to the shape, cf. 'Ix.inRange'.
   inBounds :: sh -> Index sh -> Bool

   -- | Number of indices, cf. 'Ix.rangeSize'.
   size :: sh -> Int

   -- | Number of indices without checks, cf. 'Ix.unsafeRangeSize'.
   uncheckedSize :: sh -> Int
   uncheckedSize = size

   -- | 'size' and 'offset' together; by default simply both of them.
   sizeOffset :: sh -> Index sh -> (Int,Int)
   sizeOffset sh ix = (size sh, offset sh ix)

   -- | Unchecked counterpart of 'sizeOffset'.
   uncheckedSizeOffset :: sh -> Index sh -> (Int,Int)
   uncheckedSizeOffset sh ix = (uncheckedSize sh, uncheckedOffset sh ix)


-- | The unit shape has exactly one index, @()@, located at offset 0.
instance C () where
   type Index () = ()
   size () = 1
   uncheckedSize () = 1
   indices () = pure ()
   inBounds () () = True
   offset () () = 0
   uncheckedOffset () () = 0


{- |
'ZeroBased' denotes the range of indices @0, 1, ..., n-1@,
described only by its length @n@.
-}
newtype ZeroBased n =
   ZeroBased {zeroBasedSize :: n}
   deriving (Eq, Show)

instance (Integral n) => C (ZeroBased n) where
   type Index (ZeroBased n) = n
   -- Checked enumeration and offset computation reuse the equivalent
   -- 'Shifted' shape that starts at index 0.
   indices = indices . Shifted 0 . zeroBasedSize
   offset = offset . Shifted 0 . zeroBasedSize
   -- Index and offset coincide for a zero-based range.
   uncheckedOffset _ = fromIntegral
   inBounds (ZeroBased len) ix = 0 <= ix && ix < len
   size = fromIntegral . zeroBasedSize
   uncheckedSize = fromIntegral . zeroBasedSize

{- |
'OneBased' denotes the range of indices @1, 2, ..., n@,
i.e. a range starting at one with a certain length @n@.
-}
newtype OneBased n = OneBased {oneBasedSize :: n}
   deriving (Eq, Show)

instance (Integral n) => C (OneBased n) where
   type Index (OneBased n) = n
   -- Checked operations reuse the 'Shifted' shape that starts at index 1.
   indices = indices . Shifted 1 . oneBasedSize
   offset = offset . Shifted 1 . oneBasedSize
   -- Index 1 is stored at offset 0.
   uncheckedOffset _ = subtract 1 . fromIntegral
   inBounds (OneBased len) ix = 0 < ix && ix <= len
   size = fromIntegral . oneBasedSize
   uncheckedSize = fromIntegral . oneBasedSize


{- |
'Range' denotes the inclusive range of indices
from 'rangeFrom' to 'rangeTo', like the bounds of
the Haskell 98 standard @Array@ type from the @array@ package.
For example, the shape type @(Range Int32, Range Int64)@
corresponds to the index type @(Int32, Int64)@ of such @Array@s.
-}
data Range n = Range {rangeFrom :: n, rangeTo :: n}
   deriving (Eq, Show)

-- | Every method delegates to the "GHC.Arr" function of the same purpose,
-- applied to the bounds pair @(lo, hi)@.
instance (Ix.Ix n) => C (Range n) where
   type Index (Range n) = n
   size (Range lo hi) = Ix.rangeSize (lo,hi)
   uncheckedSize (Range lo hi) = Ix.unsafeRangeSize (lo,hi)
   indices (Range lo hi) = Ix.range (lo,hi)
   inBounds (Range lo hi) i = Ix.inRange (lo,hi) i
   offset (Range lo hi) i = Ix.index (lo,hi) i
   uncheckedOffset (Range lo hi) i = Ix.unsafeIndex (lo,hi) i

-- | Stores 'rangeFrom' at element offset 0 and 'rangeTo' at element
-- offset 1 (cf. sample-frame:Stereo).
instance Storable n => Storable (Range n) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   -- The lazy patterns keep 'sizeOf' and 'alignment' from evaluating
   -- their argument.
   sizeOf ~(Range from to) =
      let sizeFrom = sizeOf from
          padding = mod (negate sizeFrom) (alignment to)
      in  sizeFrom + padding + sizeOf to
   alignment ~(Range from _) = alignment from
   peek ptr =
      let elemPtr = castToElemPtr ptr
      in  Monad.lift2 Range (peek elemPtr) (peekElemOff elemPtr 1)
   poke ptr (Range from to) =
      let elemPtr = castToElemPtr ptr
      in  poke elemPtr from >> pokeElemOff elemPtr 1 to


{- |
'Shifted' denotes a range of indices given by its first index
('shiftedOffset') and its number of elements ('shiftedSize').
-}
data Shifted n = Shifted {shiftedOffset :: n, shiftedSize :: n}
   deriving (Eq, Show)

-- | @'Shifted' offs len@ has the indices @offs, offs+1, ..., offs+len-1@.
instance (Integral n) => C (Shifted n) where
   type Index (Shifted n) = n
   indices (Shifted offs len) =
      -- Pair each index with the number of elements still to be emitted
      -- and stop as soon as that count drops to zero.
      map snd $
      takeWhile ((>0) . fst) $
      zip
         (iterate (subtract 1) len)
         (iterate (1+) offs)
   offset (Shifted offs len) ix =
      if ix<offs
        then error "Shape.Shifted: array index too small"
        else
          let k = ix-offs
          in  if k<len
                then fromIntegral k
                else error "Shape.Shifted: array index too big"
   uncheckedOffset (Shifted offs _len) ix = fromIntegral $ ix-offs
   -- Check both bounds exactly like 'offset' does,
   -- so that 'inBounds' holds precisely when 'offset' succeeds.
   inBounds (Shifted offs len) ix = offs <= ix && ix-offs < len
   size (Shifted _offs len) = fromIntegral len
   uncheckedSize (Shifted _offs len) = fromIntegral len

-- | Stores 'shiftedOffset' at element offset 0 and 'shiftedSize'
-- at element offset 1 (cf. sample-frame:Stereo).
instance Storable n => Storable (Shifted n) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   -- Lazy patterns: 'sizeOf' and 'alignment' must not inspect their argument.
   sizeOf ~(Shifted start count) =
      sizeOf start + mod (negate (sizeOf start)) (alignment count) + sizeOf count
   alignment ~(Shifted start _) = alignment start
   peek ptr =
      Monad.lift2 Shifted (peek elemPtr) (peekElemOff elemPtr 1)
     where elemPtr = castToElemPtr ptr
   poke ptr (Shifted start count) =
      poke elemPtr start >> pokeElemOff elemPtr 1 count
     where elemPtr = castToElemPtr ptr


{-# INLINE castToElemPtr #-}
-- | View a pointer to a container @f a@ as a pointer to @a@;
-- the address itself is unchanged.
castToElemPtr :: Ptr (f a) -> Ptr a
castToElemPtr ptr = castPtr ptr



{- |
Row-major composition of two dimensions:
indices are enumerated with the second component varying fastest,
and offsets are computed accordingly.
-}
instance (C sh0, C sh1) => C (sh0,sh1) where
   type Index (sh0,sh1) = (Index sh0, Index sh1)
   indices (sh0,sh1) = [(i0,i1) | i0 <- indices sh0, i1 <- indices sh1]
   inBounds (sh0,sh1) (ix0,ix1) = inBounds sh0 ix0 && inBounds sh1 ix1
   size (sh0,sh1) = size sh0 * size sh1
   uncheckedSize (sh0,sh1) = uncheckedSize sh0 * uncheckedSize sh1
   -- The outer dimension only contributes its offset, not its size.
   offset (sh0,sh1) (ix0,ix1) =
      combineOffset (offset sh0 ix0) (sizeOffset sh1 ix1)
   uncheckedOffset (sh0,sh1) (ix0,ix1) =
      combineOffset (uncheckedOffset sh0 ix0) (uncheckedSizeOffset sh1 ix1)
   sizeOffset (sh0,sh1) (ix0,ix1) =
      combineSizeOffset (sizeOffset sh0 ix0) (sizeOffset sh1 ix1)
   uncheckedSizeOffset (sh0,sh1) (ix0,ix1) =
      combineSizeOffset
         (uncheckedSizeOffset sh0 ix0)
         (uncheckedSizeOffset sh1 ix1)

-- | Row-major composition of three dimensions:
-- the last component varies fastest.
instance (C sh0, C sh1, C sh2) => C (sh0,sh1,sh2) where
   type Index (sh0,sh1,sh2) = (Index sh0, Index sh1, Index sh2)
   indices (sh0,sh1,sh2) =
      Monad.lift3 (,,) (indices sh0) (indices sh1) (indices sh2)
   -- Defined explicitly (instead of via the default 'sizeOffset')
   -- so that the outermost dimension only contributes its offset.
   -- Both combinators are infixr 7, so this nests to the right.
   offset (sh0,sh1,sh2) (ix0,ix1,ix2) =
      offset sh0 ix0
      `combineOffset`
      sizeOffset sh1 ix1
      `combineSizeOffset`
      sizeOffset sh2 ix2
   uncheckedOffset (sh0,sh1,sh2) (ix0,ix1,ix2) =
      uncheckedOffset sh0 ix0
      `combineOffset`
      uncheckedSizeOffset sh1 ix1
      `combineSizeOffset`
      uncheckedSizeOffset sh2 ix2
   sizeOffset (sh0,sh1,sh2) (ix0,ix1,ix2) =
      sizeOffset sh0 ix0
      `combineSizeOffset`
      sizeOffset sh1 ix1
      `combineSizeOffset`
      sizeOffset sh2 ix2
   uncheckedSizeOffset (sh0,sh1,sh2) (ix0,ix1,ix2) =
      uncheckedSizeOffset sh0 ix0
      `combineSizeOffset`
      uncheckedSizeOffset sh1 ix1
      `combineSizeOffset`
      uncheckedSizeOffset sh2 ix2
   inBounds (sh0,sh1,sh2) (ix0,ix1,ix2) =
      inBounds sh0 ix0 && inBounds sh1 ix1 && inBounds sh2 ix2
   size (sh0,sh1,sh2) = size sh0 * size sh1 * size sh2
   uncheckedSize (sh0,sh1,sh2) =
      uncheckedSize sh0 * uncheckedSize sh1 * uncheckedSize sh2


-- Both combinators are right-associative with equal precedence, so
-- @a `combineOffset` b `combineSizeOffset` c@ parses as
-- @a `combineOffset` (b `combineSizeOffset` c)@.
infixr 7 `combineOffset`, `combineSizeOffset`

{-# INLINE combineOffset #-}
-- | Row-major offset from the outer offset and the inner (size, offset) pair.
combineOffset :: Num a => a -> (a, a) -> a
combineOffset outerOffset (innerSize, innerOffset) =
   outerOffset * innerSize + innerOffset

{-# INLINE combineSizeOffset #-}
-- | Combine the (size, offset) pairs of an outer and an inner dimension
-- into the (size, offset) pair of their row-major product.
combineSizeOffset :: Num a => (a, a) -> (a, a) -> (a, a)
combineSizeOffset (outerSize, outerOffset) inner@(innerSize, _) =
   (outerSize * innerSize, combineOffset outerOffset inner)



infixr 5 :+:

-- | Concatenation of two shapes: indices of the left shape are
-- wrapped in 'Left', indices of the right shape in 'Right'.
data sh0 :+: sh1 = sh0 :+: sh1
   deriving (Eq, Show)

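{-
Disabled: unlike for 'Range' and 'Shifted', the two components of
@sh0:+:sh1@ generally have different types. 'castToElemPtr' turns the
pointer into a @Ptr sh1@, so @poke q sh0@ below does not type-check unless
@sh0 ~ sh1@. In addition, 'peek' would have to use the constructor '(:+:)'
instead of 'Shifted'.

-- cf. sample-frame:Stereo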
instance (Storable sh0, Storable sh1) => Storable (sh0:+:sh1) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   sizeOf ~(sh0:+:sh1) = sizeOf sh0 + mod (- sizeOf sh0) (alignment sh1) + sizeOf sh1
   alignment ~(sh0:+:sh1) = alignment sh0
   poke p (sh0:+:sh1) =
      let q = castToElemPtr p
      in  poke q sh0 >> pokeElemOff q 1 sh1
   peek p =
      let q = castToElemPtr p
      in  Monad.lift2 Shifted (peek q) (peekElemOff q 1)
-}

{- |
Concatenation of two shapes:
'Left' indices address the first part at offsets starting from 0,
'Right' indices address the second part, placed after the whole first part.
-}
instance (C sh0, C sh1) => C (sh0:+:sh1) where
   type Index (sh0:+:sh1) = Either (Index sh0) (Index sh1)
   indices (sh0:+:sh1) =
      [Left ix0 | ix0 <- indices sh0] ++ [Right ix1 | ix1 <- indices sh1]
   offset (sh0:+:sh1) ix =
      either (offset sh0) ((size sh0 +) . offset sh1) ix
   uncheckedOffset (sh0:+:sh1) ix =
      either
         (uncheckedOffset sh0)
         ((uncheckedSize sh0 +) . uncheckedOffset sh1)
         ix
   sizeOffset (sh0:+:sh1) ix =
      either
         (mapFst (+ size sh1) . sizeOffset sh0)
         (mapPair ((size0+), (size0+)) . sizeOffset sh1)
         ix
     where size0 = size sh0
   uncheckedSizeOffset (sh0:+:sh1) ix =
      either
         (mapFst (+ uncheckedSize sh1) . uncheckedSizeOffset sh0)
         (mapPair ((size0+), (size0+)) . uncheckedSizeOffset sh1)
         ix
     where size0 = uncheckedSize sh0
   inBounds (sh0:+:sh1) ix = either (inBounds sh0) (inBounds sh1) ix
   size (sh0:+:sh1) = size sh0 + size sh1
   uncheckedSize (sh0:+:sh1) = uncheckedSize sh0 + uncheckedSize sh1