{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Comfort.Shape (
   C(..),
   Indexed(..),
   InvIndexed(..),

   ZeroBased(..),
   OneBased(..),

   Range(..),
   Shifted(..),
   (:+:)(..),
   ) where

import qualified Foreign.Storable.Newtype as Store
import Foreign.Storable
         (Storable, sizeOf, alignment, poke, peek, pokeElemOff, peekElemOff)
import Foreign.Ptr (Ptr, castPtr)

import qualified GHC.Arr as Ix

import qualified Control.Monad.Trans.State as MS
import qualified Control.Monad.HT as Monad
import qualified Control.Applicative.Backwards as Back
import Control.Applicative (liftA2, liftA3)

import Text.Printf (printf)

import Data.Tuple.HT (mapFst, mapPair, swap)


{- |
Class of array shapes that know how many elements they describe.
-}
class C sh where
   {- |
   Number of elements of the shape (cf. 'Ix.rangeSize').
   -}
   size :: sh -> Int
   {- |
   Like 'size', but instances may omit consistency checks
   (cf. 'Ix.unsafeRangeSize').
   Defaults to 'size'.
   -}
   uncheckedSize :: sh -> Int
   uncheckedSize sh = size sh

{- |
Shapes whose elements can be addressed individually by an 'Index'.
The methods correspond to the functions of 'Ix.Ix' noted on each of them.
-}
class C sh => Indexed sh where
   {-# MINIMAL indices, (sizeOffset|offset), inBounds #-}
   -- | Type of the indices that address elements of the shape.
   type Index sh :: *
   -- | All valid indices of the shape (cf. 'Ix.range').
   indices :: sh -> [Index sh]
   -- | Linear position of an index within the shape (cf. 'Ix.index').
   -- Defaults to the second component of 'sizeOffset'.
   offset :: sh -> Index sh -> Int
   offset sh = snd . sizeOffset sh
   -- | Like 'offset', but instances may omit the range check
   -- (cf. 'Ix.unsafeIndex'). Defaults to 'offset'.
   uncheckedOffset :: sh -> Index sh -> Int
   uncheckedOffset sh ix = offset sh ix
   -- | Whether the index lies within the shape (cf. 'Ix.inRange').
   inBounds :: sh -> Index sh -> Bool

   -- | The 'size' of the shape paired with the 'offset' of an index.
   -- Instances for composed shapes can compute both in one go.
   sizeOffset :: sh -> Index sh -> (Int,Int)
   sizeOffset sh ix = (size sh, offset sh ix)
   -- | Unchecked counterpart of 'sizeOffset',
   -- built from 'uncheckedSize' and 'uncheckedOffset' by default.
   uncheckedSizeOffset :: sh -> Index sh -> (Int,Int)
   uncheckedSizeOffset sh ix = (uncheckedSize sh, uncheckedOffset sh ix)

{- |
Shapes that can also map a linear offset back to its index.
-}
class Indexed sh => InvIndexed sh where
   {- |
   The index located at the given offset.
   It should hold @indexFromOffset sh k == indices sh !! k@,
   but 'indexFromOffset' should generally be faster.
   -}
   indexFromOffset :: sh -> Int -> Index sh
   {- |
   Like 'indexFromOffset', but instances may skip the range check.
   Defaults to 'indexFromOffset'.
   -}
   uncheckedIndexFromOffset :: sh -> Int -> Index sh
   uncheckedIndexFromOffset sh k = indexFromOffset sh k

{- |
Abort with a uniform message for an offset that is out of range.
The first argument names the shape type for the message.
-}
errorIndexFromOffset :: String -> Int -> a
errorIndexFromOffset name k =
   error (printf "indexFromOffset (%s): index %d out of range" name k)


-- | The unit shape holds exactly one element.
instance C () where
   size () = 1
   uncheckedSize = size

-- | The only index of the unit shape is @()@, located at offset 0.
instance Indexed () where
   type Index () = ()
   indices () = pure ()
   offset () () = 0
   uncheckedOffset = offset
   inBounds () () = True

-- | Offset 0 is the only valid offset of the unit shape.
instance InvIndexed () where
   indexFromOffset () k
      | k == 0 = ()
      | otherwise = errorIndexFromOffset "()" k
   uncheckedIndexFromOffset () _ = ()


{- |
'ZeroBased' denotes a range that starts at zero and has a certain length,
i.e. its indices are @0@ up to @len-1@.
-}
newtype ZeroBased n =
   ZeroBased {
      zeroBasedSize :: n
   }
   deriving (Eq, Show)

-- | Maps the length.
instance Functor ZeroBased where
   fmap f = ZeroBased . f . zeroBasedSize

-- | Delegates to the 'Storable' instance of the wrapped length.
instance (Storable n) => Storable (ZeroBased n) where
   peek ptr = Store.peek ZeroBased ptr
   poke ptr sh = Store.poke zeroBasedSize ptr sh
   sizeOf sh = Store.sizeOf zeroBasedSize sh
   alignment sh = Store.alignment zeroBasedSize sh

-- | The number of elements is the stored length.
instance (Integral n) => C (ZeroBased n) where
   size = fromIntegral . zeroBasedSize
   uncheckedSize = size

{- |
Listing and checked offset computation are delegated to 'Shifted'
with start index 0.
-}
instance (Integral n) => Indexed (ZeroBased n) where
   type Index (ZeroBased n) = n
   indices = indices . Shifted 0 . zeroBasedSize
   offset = offset . Shifted 0 . zeroBasedSize
   -- With start index 0 the offset equals the index itself.
   uncheckedOffset _ = fromIntegral
   inBounds (ZeroBased len) ix =
      let aboveStart = 0 <= ix
          belowEnd = ix < len
      in  aboveStart && belowEnd

instance (Integral n) => InvIndexed (ZeroBased n) where
   {-
   The range check is performed on 'Integer's.
   Converting the offset to @n@ first could wrap around for narrow index types
   (e.g. offsets 260 and -252 both become 4 in 'Word8')
   and thus let an invalid offset pass the check.
   -}
   indexFromOffset (ZeroBased len) k0 =
      let k = toInteger k0
      in  if 0<=k && k < toInteger len
            then fromIntegral k0
            else errorIndexFromOffset "ZeroBased" k0
   -- No check: the offset is converted to the index type directly.
   uncheckedIndexFromOffset _ k = fromIntegral k


{- |
'OneBased' denotes a range that starts at one and has a certain length,
i.e. its indices are @1@ up to @len@.
-}
newtype OneBased n =
   OneBased {
      oneBasedSize :: n
   }
   deriving (Eq, Show)

-- | Maps the length.
instance Functor OneBased where
   fmap f = OneBased . f . oneBasedSize

-- | Delegates to the 'Storable' instance of the wrapped length.
instance (Storable n) => Storable (OneBased n) where
   peek ptr = Store.peek OneBased ptr
   poke ptr sh = Store.poke oneBasedSize ptr sh
   sizeOf sh = Store.sizeOf oneBasedSize sh
   alignment sh = Store.alignment oneBasedSize sh

-- | The number of elements is the stored length.
instance (Integral n) => C (OneBased n) where
   size = fromIntegral . oneBasedSize
   uncheckedSize = size

{- |
Listing and checked offset computation are delegated to 'Shifted'
with start index 1.
-}
instance (Integral n) => Indexed (OneBased n) where
   type Index (OneBased n) = n
   indices = indices . Shifted 1 . oneBasedSize
   offset = offset . Shifted 1 . oneBasedSize
   -- Index 1 lives at offset 0.
   uncheckedOffset _ = subtract 1 . fromIntegral
   inBounds (OneBased len) ix =
      let aboveStart = 0 < ix
          notBeyondEnd = ix <= len
      in  aboveStart && notBeyondEnd

instance (Integral n) => InvIndexed (OneBased n) where
   {-
   The range check is performed on 'Integer's.
   Converting the offset to @n@ first could wrap around for narrow index types
   (e.g. offset 260 becomes 4 in 'Word8')
   and thus let an invalid offset pass the check.
   -}
   indexFromOffset (OneBased len) k0 =
      let k = toInteger k0
      in  if 0<=k && k < toInteger len
            then 1 + fromIntegral k0
            else errorIndexFromOffset "OneBased" k0
   -- No check: offset k is mapped to index k+1 directly.
   uncheckedIndexFromOffset _ k = 1 + fromIntegral k


{- |
'Range' denotes an inclusive range like the bounds
of the Haskell 98 standard @Array@ type from the @array@ package.
E.g. the shape type @(Range Int32, Range Int64)@
is equivalent to the ix type @(Int32, Int64)@ for @Array@s.
-}
data Range n =
   Range {
      rangeFrom :: n,
      rangeTo :: n
   }
   deriving (Eq, Show)

-- | Maps both bounds.
instance Functor Range where
   fmap f (Range lo hi) = Range (f lo) (f hi)

-- | Size of the inclusive bounds, as computed by 'Ix.Ix'.
instance (Ix.Ix n) => C (Range n) where
   size (Range lo hi) =
      Ix.rangeSize (lo, hi)
   uncheckedSize (Range lo hi) =
      Ix.unsafeRangeSize (lo, hi)

-- | Every method applies the corresponding 'Ix.Ix' function to the bounds pair.
instance (Ix.Ix n) => Indexed (Range n) where
   type Index (Range n) = n
   indices (Range lo hi) =
      Ix.range (lo, hi)
   offset (Range lo hi) i =
      Ix.index (lo, hi) i
   uncheckedOffset (Range lo hi) i =
      Ix.unsafeIndex (lo, hi) i
   inBounds (Range lo hi) i =
      Ix.inRange (lo, hi) i

-- Linear in the offset: 'Ix.Ix' offers no inverse of 'Ix.index',
-- so the index is found by walking the list produced by 'Ix.range'.
instance (Ix.Ix n) => InvIndexed (Range n) where
   indexFromOffset (Range lo hi) k
      | 0<=k && k < Ix.rangeSize (lo, hi) = Ix.range (lo, hi) !! k
      | otherwise = errorIndexFromOffset "Range" k
   uncheckedIndexFromOffset (Range lo hi) k =
      Ix.range (lo, hi) !! k

-- Both bounds are stored one after another,
-- like two elements of a C array of @n@ (cf. sample-frame:Stereo).
instance Storable n => Storable (Range n) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   -- Lazy patterns, because 'sizeOf' and 'alignment'
   -- are typically called with an undefined argument.
   sizeOf ~(Range lo hi) =
      let padding = mod (- sizeOf lo) (alignment hi)
      in  sizeOf lo + padding + sizeOf hi
   alignment ~(Range lo _) = alignment lo
   poke ptr (Range lo hi) = do
      let elemPtr = castToElemPtr ptr
      poke elemPtr lo
      pokeElemOff elemPtr 1 hi
   peek ptr = do
      let elemPtr = castToElemPtr ptr
      lo <- peek elemPtr
      hi <- peekElemOff elemPtr 1
      return (Range lo hi)


{- |
'Shifted' denotes a range defined by its start index and its length,
i.e. its indices are @offs@ up to @offs+len-1@.
-}
data Shifted n =
   Shifted {
      shiftedOffset :: n,
      shiftedSize :: n
   }
   deriving (Eq, Show)

-- | Maps both the start index and the length.
instance Functor Shifted where
   fmap f (Shifted offs len) = Shifted (f offs) (f len)

-- | The number of elements is the stored length; the start does not matter.
instance (Integral n) => C (Shifted n) where
   size = fromIntegral . shiftedSize
   uncheckedSize = size

{-
Indices run from @offs@ up to @offs+len-1@.
-}
instance (Integral n) => Indexed (Shifted n) where
   type Index (Shifted n) = n
   -- Pair each index with the number of elements remaining
   -- and stop once that count reaches zero.
   -- This avoids computing @offs+len@.
   indices (Shifted offs len) =
      map snd $
      takeWhile ((>0) . fst) $
      zip
         (iterate (subtract 1) len)
         (iterate (1+) offs)
   offset (Shifted offs len) ix =
      if ix<offs
        then error "Shape.Shifted: array index too small"
        else
          let k = ix-offs
          in  if k<len
                then fromIntegral k
                else error "Shape.Shifted: array index too big"
   uncheckedOffset (Shifted offs _len) ix = fromIntegral $ ix-offs
   -- Indices below @offs@ must be rejected, too.
   -- The conditions mirror the checks in 'offset',
   -- so @inBounds sh ix@ holds exactly when @offset sh ix@ does not fail.
   inBounds (Shifted offs len) ix = offs<=ix && ix-offs<len

instance (Integral n) => InvIndexed (Shifted n) where
   {-
   The range check is performed on 'Integer's.
   Converting the offset to @n@ first could wrap around for narrow index types
   (e.g. offset 260 becomes 4 in 'Word8')
   and thus let an invalid offset pass the check.
   -}
   indexFromOffset (Shifted offs len) k0 =
      let k = toInteger k0
      in  if 0<=k && k < toInteger len
            then offs + fromIntegral k0
            else errorIndexFromOffset "Shifted" k0
   -- No check: offset k is mapped to index @offs+k@ directly.
   uncheckedIndexFromOffset (Shifted offs _len) k = offs + fromIntegral k

-- Start index and length are stored one after another,
-- like two elements of a C array of @n@ (cf. sample-frame:Stereo).
instance Storable n => Storable (Shifted n) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   -- Lazy patterns, because 'sizeOf' and 'alignment'
   -- are typically called with an undefined argument.
   sizeOf ~(Shifted offs len) =
      let padding = mod (- sizeOf offs) (alignment len)
      in  sizeOf offs + padding + sizeOf len
   alignment ~(Shifted offs _) = alignment offs
   poke ptr (Shifted offs len) = do
      let elemPtr = castToElemPtr ptr
      poke elemPtr offs
      pokeElemOff elemPtr 1 len
   peek ptr = do
      let elemPtr = castToElemPtr ptr
      offs <- peek elemPtr
      len <- peekElemOff elemPtr 1
      return (Shifted offs len)


{-# INLINE castToElemPtr #-}
{- |
View a pointer to a container of @a@ as a pointer to its first @a@.
The address is unchanged.
-}
castToElemPtr :: Ptr (f a) -> Ptr a
castToElemPtr ptr = castPtr ptr



{- |
Row-major composition of two dimensions:
the number of elements is the product of both sizes.
-}
instance (C sh0, C sh1) => C (sh0,sh1) where
   size (outer, inner) =
      size outer * size inner
   uncheckedSize (outer, inner) =
      uncheckedSize outer * uncheckedSize inner

{-
Row-major: the second component varies fastest,
and the outer offset is scaled by the size of the inner dimension.
-}
instance (Indexed sh0, Indexed sh1) => Indexed (sh0,sh1) where
   type Index (sh0,sh1) = (Index sh0, Index sh1)
   indices (outer, inner) =
      Monad.lift2 (,) (indices outer) (indices inner)
   offset (outer, inner) (i, j) =
      combineOffset (offset outer i) (sizeOffset inner j)
   uncheckedOffset (outer, inner) (i, j) =
      combineOffset (uncheckedOffset outer i) (uncheckedSizeOffset inner j)
   sizeOffset (outer, inner) (i, j) =
      combineSizeOffset (sizeOffset outer i) (sizeOffset inner j)
   uncheckedSizeOffset (outer, inner) (i, j) =
      combineSizeOffset
         (uncheckedSizeOffset outer i)
         (uncheckedSizeOffset inner j)
   inBounds (outer, inner) (i, j) =
      inBounds outer i && inBounds inner j

{-
'Back.Backwards' runs the pickers from right to left.
'pickIndex' therefore splits off the inner dimension first,
and the first component receives the remaining quotient.
-}
instance (InvIndexed sh0, InvIndexed sh1) => InvIndexed (sh0,sh1) where
   indexFromOffset (sh0,sh1) k = runInvIndex k picker
      where picker = liftA2 (,) (pickLastIndex sh0) (pickIndex sh1)
   uncheckedIndexFromOffset (sh0,sh1) k = runInvIndex k picker
      where picker = liftA2 (,) (uncheckedPickLastIndex sh0) (pickIndex sh1)


-- | The number of elements is the product of all three sizes.
instance (C sh0, C sh1, C sh2) => C (sh0,sh1,sh2) where
   size (sh0,sh1,sh2) =
      product [size sh0, size sh1, size sh2]
   uncheckedSize (sh0,sh1,sh2) =
      product [uncheckedSize sh0, uncheckedSize sh1, uncheckedSize sh2]

{-
Row-major composition of three dimensions; the last component varies fastest.
'offset' is not defined here and falls back to the class default,
i.e. the second component of 'sizeOffset'.
-}
instance (Indexed sh0, Indexed sh1, Indexed sh2) => Indexed (sh0,sh1,sh2) where
   type Index (sh0,sh1,sh2) = (Index sh0, Index sh1, Index sh2)
   indices (sh0,sh1,sh2) =
      Monad.lift3 (,,) (indices sh0) (indices sh1) (indices sh2)
   uncheckedOffset (sh0,sh1,sh2) (ix0,ix1,ix2) =
      combineOffset (uncheckedOffset sh0 ix0) $
         combineSizeOffset
            (uncheckedSizeOffset sh1 ix1)
            (uncheckedSizeOffset sh2 ix2)
   sizeOffset (sh0,sh1,sh2) (ix0,ix1,ix2) =
      combineSizeOffset (sizeOffset sh0 ix0) $
         combineSizeOffset (sizeOffset sh1 ix1) (sizeOffset sh2 ix2)
   uncheckedSizeOffset (sh0,sh1,sh2) (ix0,ix1,ix2) =
      combineSizeOffset (uncheckedSizeOffset sh0 ix0) $
         combineSizeOffset
            (uncheckedSizeOffset sh1 ix1)
            (uncheckedSizeOffset sh2 ix2)
   inBounds (sh0,sh1,sh2) (ix0,ix1,ix2) =
      and [inBounds sh0 ix0, inBounds sh1 ix1, inBounds sh2 ix2]

{-
'Back.Backwards' runs the pickers from right to left.
The two inner dimensions are split off by 'pickIndex' (last one first),
and the first component receives the remaining quotient.
-}
instance
   (InvIndexed sh0, InvIndexed sh1, InvIndexed sh2) =>
      InvIndexed (sh0,sh1,sh2) where
   indexFromOffset (sh0,sh1,sh2) k = runInvIndex k picker
      where
         picker =
            liftA3 (,,) (pickLastIndex sh0) (pickIndex sh1) (pickIndex sh2)
   uncheckedIndexFromOffset (sh0,sh1,sh2) k = runInvIndex k picker
      where
         picker =
            liftA3 (,,)
               (uncheckedPickLastIndex sh0) (pickIndex sh1) (pickIndex sh2)

{- |
Run a picker, using the given value as the initial state,
and return its result; the final state is discarded.
-}
runInvIndex :: s -> Back.Backwards (MS.State s) a -> a
runInvIndex k picker = MS.evalState (Back.forwards picker) k

{- |
Turn the whole remaining offset into an index of the given shape,
range-checked by 'indexFromOffset'. The state is left unchanged.
Used for the first (outermost) component of tuple shapes.
-}
pickLastIndex ::
   (InvIndexed sh) => sh -> Back.Backwards (MS.State Int) (Index sh)
pickLastIndex sh = Back.Backwards (MS.gets (indexFromOffset sh))

{- |
Like 'pickLastIndex', but converts the remaining offset
with 'uncheckedIndexFromOffset'. The state is left unchanged.
-}
uncheckedPickLastIndex ::
   (InvIndexed sh) => sh -> Back.Backwards (MS.State Int) (Index sh)
uncheckedPickLastIndex sh =
   Back.Backwards (MS.gets (uncheckedIndexFromOffset sh))

{- |
Split off one inner dimension.
The current offset is divided by @size sh@:
the remainder is converted to an index of @sh@,
and the quotient becomes the new state for the dimensions further out.
-}
pickIndex :: (InvIndexed sh) => sh -> Back.Backwards (MS.State Int) (Index sh)
pickIndex sh =
   -- 'divMod' yields a remainder in @[0, size sh)@ for positive sizes,
   -- which is why the unchecked conversion is used here.
   fmap (uncheckedIndexFromOffset sh) . Back.Backwards . MS.state $
      \k -> swap (divMod k (size sh))



infixr 7 `combineOffset`, `combineSizeOffset`

{-# INLINE combineOffset #-}
{- |
Offset within a row-major product of an outer and an inner dimension:
the outer offset scaled by the inner size, plus the inner offset.
-}
combineOffset :: Num a => a -> (a, a) -> a
combineOffset outer (innerSize, innerOffset) =
   outer * innerSize + innerOffset

{-# INLINE combineSizeOffset #-}
{- |
Combine (size, offset) pairs of an outer and an inner dimension
into the (size, offset) pair of their row-major product.
-}
combineSizeOffset :: Num a => (a, a) -> (a, a) -> (a, a)
combineSizeOffset (outerSize, outerOffset) (innerSize, innerOffset) =
   let totalSize = outerSize*innerSize
       totalOffset = outerOffset * innerSize + innerOffset
   in  (totalSize, totalOffset)



infixr 5 :+:

{- |
Concatenation of two shapes:
the elements of the first shape are followed by those of the second.
-}
data sh0 :+: sh1 = sh0 :+: sh1
   deriving (Eq, Show)

-- | The number of elements of a concatenation is the sum of both parts.
instance (C sh0, C sh1) => C (sh0:+:sh1) where
   size (front:+:back) =
      size front + size back
   uncheckedSize (front:+:back) =
      uncheckedSize front + uncheckedSize back

{-
'Left' indices address the first part, 'Right' indices the second part.
Offsets of the second part start right after all elements of the first part,
and the size of the concatenation is the sum of both sizes.
-}
instance (Indexed sh0, Indexed sh1) => Indexed (sh0:+:sh1) where
   type Index (sh0:+:sh1) = Either (Index sh0) (Index sh1)
   indices (sh0:+:sh1) =
      [Left ix0 | ix0 <- indices sh0] ++ [Right ix1 | ix1 <- indices sh1]
   offset (sh0:+:sh1) ix =
      either
         (offset sh0)
         (\ix1 -> size sh0 + offset sh1 ix1)
         ix
   uncheckedOffset (sh0:+:sh1) ix =
      either
         (uncheckedOffset sh0)
         (\ix1 -> uncheckedSize sh0 + uncheckedOffset sh1 ix1)
         ix
   sizeOffset (sh0:+:sh1) ix =
      either
         (\ix0 -> mapFst (+ size sh1) $ sizeOffset sh0 ix0)
         (\ix1 -> mapPair (shift, shift) $ sizeOffset sh1 ix1)
         ix
      where shift = (size sh0 +)
   uncheckedSizeOffset (sh0:+:sh1) ix =
      either
         (\ix0 -> mapFst (+ uncheckedSize sh1) $ uncheckedSizeOffset sh0 ix0)
         (\ix1 -> mapPair (shift, shift) $ uncheckedSizeOffset sh1 ix1)
         ix
      where shift = (uncheckedSize sh0 +)
   inBounds (sh0:+:sh1) = \ix ->
      case ix of
         Left ix0 -> inBounds sh0 ix0
         Right ix1 -> inBounds sh1 ix1

{-
Offsets below the size of the first part belong to the first part.
All other offsets are shifted down by that size into the second part.
NOTE(review): both variants split at 'size', whereas 'uncheckedOffset'
of this shape uses 'uncheckedSize'; the two may differ for degenerate parts.
-}
instance (InvIndexed sh0, InvIndexed sh1) => InvIndexed (sh0:+:sh1) where
   indexFromOffset (sh0:+:sh1) k
      | k < pivot = Left (indexFromOffset sh0 k)
      | otherwise = Right (indexFromOffset sh1 (k - pivot))
      where pivot = size sh0
   uncheckedIndexFromOffset (sh0:+:sh1) k
      | k < pivot = Left (uncheckedIndexFromOffset sh0 k)
      | otherwise = Right (uncheckedIndexFromOffset sh1 (k - pivot))
      where pivot = size sh0