{- |
Generate and apply index maps.
This unifies the @replicate@ and @slice@ functions of the @accelerate@ package.
However, the structure of slicing and replicating cannot depend on parameters.
If you need that, you must use 'ShapeDep.backpermute' and friends.
-}
{-
Some notes on the design choice:

Instead of the shallow embedding implemented by the 'T' type,
we could maintain a symbolic representation of the Slice and Replicate pattern,
like the accelerate package does.
We actually used that representation in former versions.
It has however some drawbacks:

* We need additional type functions that map from the pattern
  to the source and the target shape, and we need a proof
  that the images of these type functions are actually shapes.
  This worked already, but was rather cumbersome.

* We need a way to store and pass this pattern through the Parameter handler.
  This yields new problems:
  We need a wrapper type for wrapping Index, Shape, Slice, Replicate, Fold patterns.
  Then the question is whether we use one Wrap type with a phantom parameter
  or whether we define a Wrap type for every pattern type.
  That is, the options are to write either

  > Wrap Shape (Z:.Int:.Int)

  or

  > Shape (Z:.Int:.Int)

  The first one seems to save us many duplicate instances of
  Storable, MultiValue etc.,
  and it easily allows us to reuse the (:.) for all kinds of patterns.
  However, we need a way to restrict the element type of the (:.)-list elements.
  We can define that using variable ConstraintKinds,
  but e.g. we are not able to add a Storable superclass constraint
  to the instance Storable (Wrap constr).
  That is, we were left with the second option
  and had to define a lot of similar Storable, MultiValue instances.
-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Simple.Slice (
   T,
   Linear,
   apply,
   passAny,
   pass,
   pick,
   pickFst,
   pickSnd,
   extrude,
   extrudeFst,
   extrudeSnd,
   transpose,
   (Core.$:.),

   id,
   first,
   second,
   compose,
   ) where

import qualified Data.Array.Knead.Simple.ShapeDependent as ShapeDep
import qualified Data.Array.Knead.Simple.Private as Core

import qualified Data.Array.Knead.Index.Linear as Linear
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Index.Linear ((#:.), (:.)((:.)), )
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom, )

import qualified Prelude as P
import Prelude hiding (id, zipWith, zipWith3, zip, zip3, replicate, )



{-
This data type is almost identical to Core.Array.
The only difference is
that the shape @sh1@ in T can depend on another shape @sh0@.
-}
{- |
An index map that turns arrays of shape @sh0@ into arrays of shape @sh1@.

It bundles two functions:
one that computes the target shape from the source shape,
and one that maps every index of the target array
back to the index of the source array where its element is read from.
Use 'apply' to run such a map on an array.
-}
data T sh0 sh1 =
   forall ix0 ix1.
   (Shape.Index sh0 ~ ix0, Shape.Index sh1 ~ ix1) =>
   Cons
      -- shape transformation: source shape to target shape
      (Exp sh0 -> Exp sh1)
      -- index transformation: target index back to source index
      (Exp ix1 -> Exp ix0)

{- |
Run an index map on an array.

Both the shape function and the index function stored in the 'T' value
are handed unchanged to 'ShapeDep.backpermute',
so this is essentially a 'ShapeDep.backpermute'.
-}
apply ::
   (Core.C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   T sh0 sh1 ->
   array sh0 a ->
   array sh1 a
apply (Cons shapeMap indexMap) source =
   ShapeDep.backpermute shapeMap indexMap source


{- |
Fix the first component of a pair shape to the index @i@.

The resulting shape is the second component of the source shape,
and every target index is paired with @i@ in front
in order to obtain the source index.
-}
pickFst :: Exp (Shape.Index n) -> T (n,sh) sh
pickFst i = Cons Expr.snd (\ix -> Expr.zip i ix)

{- |
Fix the second component of a pair shape to the index @i@.

The resulting shape is the first component of the source shape,
and every target index is paired with @i@ at the back
in order to obtain the source index.
-}
pickSnd :: Exp (Shape.Index n) -> T (sh,n) sh
pickSnd i = Cons Expr.fst (\ix -> Expr.zip ix i)

{- |
Add a new first shape component @n@ in front of the source shape.
The new index component is ignored when reading from the source,
that is, the source is repeated along the new dimension.

Extrusion has the potential to do duplicate work.
Only use it to add dimensions of size 1, e.g. numeric 1 or unit @()@
or to duplicate slices of physical arrays.
-}
extrudeFst :: Exp n -> T sh (n,sh)
extrudeFst n = Cons (\sh -> Expr.zip n sh) Expr.snd

{- |
Add a new second shape component @n@ behind the source shape.
The new index component is ignored when reading from the source,
that is, the source is repeated along the new dimension.

Like 'extrudeFst' this may cause duplicate work.
-}
extrudeSnd :: Exp n -> T sh (sh,n)
extrudeSnd n = Cons (\sh -> Expr.zip sh n) Expr.fst

{- |
Exchange the two components of a pair shape.
Both the shape and the index are swapped with 'Expr.swap'.
-}
transpose :: T (sh0,sh1) (sh1,sh0)
transpose =
   Cons
      (\sh -> Expr.swap sh)
      (\ix -> Expr.swap ix)


-- Arrow combinators

{- |
The identity index map: shape and indices are passed through unchanged.
-}
id :: T sh sh
id = Cons (\sh -> sh) (\ix -> ix)

{- |
Lift an index map to the first component of a pair shape.
Shape and index function are both lifted with 'Expr.mapFst'.
-}
first :: T sh0 sh1 -> T (sh0,sh) (sh1,sh)
first (Cons shapeMap indexMap) =
   Cons
      (Expr.mapFst shapeMap)
      (Expr.mapFst indexMap)

{- |
Lift an index map to the second component of a pair shape.
Shape and index function are both lifted with 'Expr.mapSnd'.
-}
second :: T sh0 sh1 -> T (sh,sh0) (sh,sh1)
second (Cons shapeMap indexMap) =
   Cons
      (Expr.mapSnd shapeMap)
      (Expr.mapSnd indexMap)

infixr 1 `compose`

{- |
Chain two index maps: the left one is applied first, the right one second.

Shapes flow from left to right,
whereas indices are translated in the opposite direction,
i.e. an index of the final array is first mapped by the right operand
and then by the left operand.
-}
compose :: T sh0 sh1 -> T sh1 sh2 -> T sh0 sh2
compose (Cons shapeInner indexInner) (Cons shapeOuter indexOuter) =
   Cons
      (\sh -> shapeOuter (shapeInner sh))
      (\ix -> indexInner (indexOuter ix))


{- |
An index map between 'Linear.Shape' shapes,
i.e. 'T' with both shape arguments wrapped in 'Linear.Shape'.
Values of this type are built with 'passAny', 'pass', 'pick' and 'extrude'.
-}
type Linear sh0 sh1 = T (Linear.Shape sh0) (Linear.Shape sh1)

{- |
Like @Any@ in @accelerate@.

Leaves the whole shape and all indices unchanged;
it is the identity index map 'id' at 'Linear' shapes.
-}
passAny :: Linear sh sh
passAny = id

{- |
Like @All@ in @accelerate@.

Keeps the last dimension as it is
and applies the given index map to the remaining dimensions.
-}
pass ::
   Linear sh0 sh1 ->
   Linear (sh0:.i) (sh1:.i)
pass (Cons shapeMap indexMap) =
   Cons
      (Expr.modify (Linear.shape (atom:.atom))
         (\(restShape:.extent) -> shapeMap restShape :. extent))
      (Expr.modify (Linear.index (atom:.atom))
         (\(restIndex:.pos) -> indexMap restIndex :. pos))

{- |
Like @Int@ in @accelerate/slice@.

Removes the last dimension of the source by fixing its index to @i@.
The given index map is applied to the remaining dimensions.
-}
pick ::
   Exp i ->
   Linear sh0 sh1 ->
   Linear (sh0:.i) sh1
pick i (Cons shapeMap indexMap) =
   Cons
      (\sh -> shapeMap (Linear.tail sh))
      ((#:. i) . indexMap)

{- |
Like @Int@ in @accelerate/replicate@.

Appends a new last dimension of extent @n@ to the target shape.
The index along the new dimension is dropped before the given index map
translates the remaining index to the source.
-}
extrude ::
   Exp i ->
   Linear sh0 sh1 ->
   Linear sh0 (sh1:.i)
extrude n (Cons shapeMap indexMap) =
   Cons
      ((#:. n) . shapeMap)
      (\ix -> indexMap (Linear.tail ix))


instance Core.Process (T sh0 sh1) where