{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
-- |
-- Module      : Data.Array.Accelerate.Trafo.Environment
-- Copyright   : [2012..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Trafo.Environment
  where

import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.Debug.Stats                            as Stats


-- | An environment of let-bound scalar expressions.
--
-- Every stored expression lives in the scalar environment @scope@. The index
-- @bound@ records the types of the expressions pushed so far, so that an
-- 'Idx' into @bound@ can be used to project out the corresponding expression
-- (see 'prjExp').
--
data Gamma scope bound aenv where
  -- | The empty environment.
  EmptyExp :: Gamma scope bound aenv

  -- | Add one more expression, stored together with its pending weakening.
  PushExp  :: Gamma scope bound aenv
           -> WeakOpenExp scope aenv t
           -> Gamma scope (bound, t) aenv

-- | A scalar expression paired with a pending weakening.
--
-- In @Subst k e e'@, @e@ is the original term and @k@ the accumulated
-- weakening from its environment into @scope@. The third field @e'@ is
-- expected to be @e@ weakened by @k@. It is deliberately a lazy field, so that
-- several successive weakenings can be collapsed into a single traversal of
-- @e@ that only happens when @e'@ is finally demanded.
data WeakOpenExp scope aenv t where
  Subst    :: src :> scope
           -> OpenExp     src   aenv t
           -> OpenExp     scope aenv t {- LAZY -}
           -> WeakOpenExp scope aenv t

-- | Weaken every expression held in a 'Gamma' by one additional binding in
-- the scalar environment, i.e. move the whole environment under one more
-- let binding.
--
-- XXX: The simplifier calls this function every time it moves under a let
-- binding. This means we have a number of calls to 'weakenE' exponential in the
-- depth of nested let bindings, which quickly causes problems.
--
-- We can improve the situation slightly by observing that weakening by a single
-- variable does no less work than weakening by multiple variables at once; both
-- require a deep copy of the AST. By exploiting laziness (or, an IORef) we can
-- queue up multiple weakenings to happen in a single step: each 'Subst' keeps
-- the original term together with the accumulated weakening, and only its lazy
-- third field actually applies that weakening.
--
-- <https://github.com/AccelerateHS/accelerate-llvm/issues/20>
--
incExp
    :: Gamma env     env' aenv
    -> Gamma (env,s) env' aenv
incExp EmptyExp             = EmptyExp
incExp (PushExp gamma wexp) = PushExp (incExp gamma) (subs wexp)
  where
    subs :: forall env1 aenv1 s1 t1. WeakOpenExp env1 aenv1 t1 -> WeakOpenExp (env1,s1) aenv1 t1
    subs (Subst k e _) = Subst k' e (weakenE k' e)
      where
        -- Compose the pending weakening with one more successor. The
        -- (lazy) weakened term is rebuilt from the original 'e' in one
        -- pass rather than by re-weakening the previously weakened copy.
        k' = weakenSucc' k

-- | Project the expression at the given index out of a 'Gamma'.
--
-- The weakened copy stored in the 'Subst' is returned, so any weakenings
-- queued up by 'incExp' are forced here. Indexing past the end of the
-- environment (reaching 'EmptyExp') is reported as an internal error.
--
prjExp :: HasCallStack => Idx env' t -> Gamma env env' aenv -> OpenExp env aenv t
prjExp ZeroIdx      (PushExp _     (Subst _ _ e)) = e
prjExp (SuccIdx ix) (PushExp gamma _)             = prjExp ix gamma
prjExp _            _                             = internalError "inconsistent valuation"

-- | Push a new expression onto a 'Gamma'.
--
-- The expression starts with the identity weakening, so the original and the
-- weakened copies stored in the 'Subst' are the same term.
--
pushExp :: Gamma env env' aenv -> OpenExp env aenv t -> Gamma env (env',t) aenv
pushExp gamma e = PushExp gamma (Subst weakenId e e)

{--
lookupExp
    :: Gamma   env env' aenv
    -> OpenExp env      aenv t
    -> Maybe (Idx env' t)
lookupExp EmptyExp        _ = Nothing
lookupExp (PushExp env e) x
  | Just Refl <- match e x  = Just ZeroIdx
  | otherwise               = SuccIdx `fmap` lookupExp env x

weakenGamma1
    :: Gamma env env' aenv
    -> Gamma env env' (aenv,t)
weakenGamma1 EmptyExp        = EmptyExp
weakenGamma1 (PushExp env e) = PushExp (weakenGamma1 env) (weaken SuccIdx e)

sinkGamma
    :: Kit acc
    => Extend acc aenv aenv'
    -> Gamma env env' aenv
    -> Gamma env env' aenv'
sinkGamma _   EmptyExp        = EmptyExp
sinkGamma ext (PushExp env e) = PushExp (sinkGamma ext env) (sinkA ext e)
--}

-- | A heterogeneous snoc-list of terms, each bound through a 'LeftHandSide',
-- witnessing how the environment @outer@ is extended to @inner@ by binding
-- those terms.
--
-- As part of various transformations we often need to lift out array-valued
-- inputs to be let-bound at a higher point; this type records such bindings.
--
data Extend s f outer inner where
  -- | No bindings: the environment is left unchanged.
  BaseEnv :: Extend s f outer outer

  -- | Bind one more term, living in @mid@, through a left-hand side that
  -- extends @mid@ to @inner@.
  PushEnv :: Extend s f outer mid
          -> LeftHandSide s t mid inner
          -> f mid t
          -> Extend s f outer inner

-- | Push a single array-valued term onto an 'Extend' witness, binding it to
-- one fresh variable ('LeftHandSideSingle' built from the term's own
-- 'arrayR').
--
pushArrayEnv
    :: HasArraysR acc
    => Extend ArrayR acc aenv aenv'
    -> acc aenv' (Array sh e)
    -> Extend ArrayR acc aenv (aenv', Array sh e)
pushArrayEnv env a = PushEnv env (LeftHandSideSingle (arrayR a)) a


-- | Append two environment witnesses: the first extends @env@ to @env'@, the
-- second then extends @env'@ to @env''@.
--
append :: Extend s acc env env' -> Extend s acc env' env'' -> Extend s acc env env''
append x BaseEnv           = x
append x (PushEnv e lhs a) = PushEnv (append x e) lhs a

-- | Bring into scope all of the array terms in the 'Extend' witness. This
-- converts a term in the inner array environment (@aenv'@) into one in the
-- outer environment (@aenv@).
--
-- Every binding becomes an 'Alet'; the earliest binding of the witness ends up
-- as the outermost 'Alet'. The supplied function wraps each intermediate
-- 'PreOpenAcc' into an @acc@ term so that it can serve as an 'Alet' body.
--
bind :: (forall env t. PreOpenAcc acc env t -> acc env t)
     -> Extend ArrayR  acc aenv aenv'
     -> PreOpenAcc acc      aenv' a
     -> PreOpenAcc acc aenv       a
bind _      BaseEnv             = id
bind inject (PushEnv g lhs bnd) = bind inject g . Alet lhs bnd . inject

-- | Sink a term from one environment into another, where additional bindings
-- have come into scope according to the witness and nothing that was already
-- in scope has vanished.
--
sinkA :: Sink f => Extend s acc env env' -> f env t -> f env' t
sinkA env = weaken (sinkWeaken env) -- TODO: Fix Stats sinkA vs sink1

-- | Like 'sinkA', but for a term whose environment has one extra binding on
-- top of @env@. That innermost binding is kept in place while the rest of the
-- environment is sunk according to the witness.
--
sink1 :: Sink f => Extend s acc env env' -> f (env,t') t -> f (env',t') t
sink1 env = weaken (sink (sinkWeaken env))

-- | Compute the weakening from the outer to the inner environment of an
-- 'Extend' witness.
--
-- Only the shape of each 'LeftHandSide' is inspected, never the bound terms:
--
--   * a wildcard binds nothing;
--   * a single binder adds one variable;
--   * a pair contributes the variables of its left component followed by
--     those of its right component.
--
-- Each call records one @"sink"@ event via 'Stats.substitution'.
--
sinkWeaken :: Extend s acc env env' -> env :> env'
sinkWeaken BaseEnv           = Stats.substitution "sink" weakenId
sinkWeaken (PushEnv e lhs _) = throughLHS lhs (sinkWeaken e)
  where
    -- Extend a weakening by the variables bound by a left-hand side.
    -- Threading the weakening directly means no placeholder ('undefined')
    -- terms have to be fabricated for the pair case.
    throughLHS :: LeftHandSide s' t mid inner -> (outer :> mid) -> (outer :> inner)
    throughLHS (LeftHandSideWildcard _) k = k
    throughLHS (LeftHandSideSingle _)   k = weakenSucc' k
    throughLHS (LeftHandSidePair l1 l2) k = throughLHS l2 (throughLHS l1 k)

-- | A wrapper around 'OpenExp' with its scalar-environment and
-- array-environment type arguments swapped. Fixing the array environment
-- first lets @OpenExp' aenv@ be used as the term type of an 'Extend' over
-- scalar environments.
newtype OpenExp' arrEnv expEnv e = OpenExp' (OpenExp expEnv arrEnv e)

-- | Bring into scope all of the scalar expressions in the 'Extend' witness,
-- converting an expression in the inner environment (@env'@) into one in the
-- outer environment (@env@).
--
-- Every binding becomes a 'Let'; the earliest binding of the witness ends up
-- as the outermost 'Let'.
--
bindExps :: Extend ScalarType (OpenExp' aenv) env env'
         -> OpenExp env' aenv e
         -> OpenExp env  aenv e
bindExps BaseEnv                      = id
bindExps (PushEnv g lhs (OpenExp' b)) = bindExps g . Let lhs b