{-# LANGUAGE FlexibleContexts #-}

-- | Internalising bindings.
module Futhark.Internalise.Bindings
  ( bindingParams,
    bindingLoopParams,
    bindingLambdaParams,
    stmPattern,
  )
where

import Control.Monad.Reader hiding (mapM)
import qualified Data.Map.Strict as M
import qualified Futhark.IR.SOACS as I
import Futhark.Internalise.Monad
import Futhark.Internalise.TypesValues
import Futhark.Util
import Language.Futhark as E hiding (matchDims)

-- | Bind the parameters of a function for the duration of a
-- continuation.
--
-- Every source parameter pattern is flattened into identifiers whose
-- types are internalised with 'internaliseParamTypes'.  Each size type
-- parameter ('E.TypeParamDim') becomes a parameter of type 'I.int64'
-- that is brought into scope together with the value parameters and is
-- substituted by itself via 'substitutingVars'.
--
-- The continuation receives the size parameters and the value
-- parameters, the latter grouped so that there is one group per source
-- parameter pattern.
bindingParams ::
  [E.TypeParam] ->
  [E.Pattern] ->
  ([I.FParam] -> [[I.FParam]] -> InternaliseM a) ->
  InternaliseM a
bindingParams tparams params m = do
  idents_per_pat <- mapM flattenPattern params
  let all_idents = concat idents_per_pat
  ts_per_ident <-
    internaliseParamTypes
      [E.setAliases (E.unInfo (E.identType ident)) () | ident <- all_idents]
  let -- How many internal types each source pattern expanded into.
      ts_per_pat =
        [ sum (map length grp)
          | grp <- chunks (map length idents_per_pat) ts_per_ident
        ]
      shape_names = [v | E.TypeParamDim v _ <- tparams]
      shape_params = map (\v -> I.Param v (I.Prim I.int64)) shape_names
      shape_subst = M.fromList [(v, [I.Var v]) | v <- shape_names]
  bindingFlatPattern all_idents (concat ts_per_ident) $ \groups -> do
    let value_params = concat groups
    I.localScope (I.scopeOfFParams (shape_params ++ value_params)) $
      substitutingVars shape_subst $
        m shape_params (chunks ts_per_pat value_params)

-- | Bind the parameters of a loop for the duration of a continuation.
--
-- The pattern is flattened into identifiers, and its whole structural
-- type is internalised with 'internaliseLoopParamType'.  Each size type
-- parameter ('E.TypeParamDim') becomes a parameter of type 'I.int64'
-- that is brought into scope together with the value parameters and is
-- substituted by itself via 'substitutingVars'.
--
-- The continuation receives the size parameters and the value
-- parameters as one flat list.
bindingLoopParams ::
  [E.TypeParam] ->
  E.Pattern ->
  ([I.FParam] -> [I.FParam] -> InternaliseM a) ->
  InternaliseM a
bindingLoopParams tparams pat m = do
  idents <- flattenPattern pat
  internal_ts <- internaliseLoopParamType $ E.patternStructType pat
  let shape_names = [v | E.TypeParamDim v _ <- tparams]
      shape_params = map (\v -> I.Param v (I.Prim I.int64)) shape_names
      shape_subst = M.fromList [(v, [I.Var v]) | v <- shape_names]
  bindingFlatPattern idents internal_ts $ \groups -> do
    let value_params = concat groups
    I.localScope (I.scopeOfFParams (shape_params ++ value_params)) $
      substitutingVars shape_subst $
        m shape_params value_params

-- | Bind the parameters of a lambda for the duration of a continuation.
--
-- The identifiers of all parameter patterns are assigned the supplied
-- internal types in order (see 'bindingFlatPattern').  They are brought
-- into scope as lambda parameters and passed to the continuation as one
-- flat list.
bindingLambdaParams ::
  [E.Pattern] ->
  [I.Type] ->
  ([I.LParam] -> InternaliseM a) ->
  InternaliseM a
bindingLambdaParams params ts m = do
  idents <- concat <$> traverse flattenPattern params
  bindingFlatPattern idents ts $ \groups -> do
    let lparams = concat groups
    I.localScope (I.scopeOfLParams lparams) (m lparams)

-- | Give internal names and types to a flat list of source identifiers.
--
-- Each identifier gets as many internal names as 'internalisedTypeSize'
-- reports for its type.  The source name is reused when that count is 1;
-- otherwise fresh names derived from its base string are used.  Each
-- internal name consumes the next element of the type list to form an
-- 'I.Param'.
--
-- Returns the parameters created for each identifier, in the same order
-- as the input identifiers.  It also returns a substitution mapping each
-- source name to the variables of its parameters.
--
-- Calls 'error' if the type list runs out before every name has been
-- given a type.  Types left over after the last identifier are ignored.
processFlatPattern ::
  Show t =>
  [E.Ident] ->
  [t] ->
  InternaliseM ([[I.Param t]], VarSubstitutions)
processFlatPattern x y = processFlatPattern' [] x y
  where
    -- The accumulator holds one entry per processed identifier, most
    -- recent first; it is reversed once all identifiers are consumed.
    processFlatPattern' pat [] _ = do
      let (vs, substs) = unzip pat
      return (reverse vs, M.fromList substs)
    processFlatPattern' pat (p : rest) ts = do
      (ps, rest_ts) <- handleMapping ts <$> internaliseBindee p
      let subst = (E.identName p, map (I.Var . I.paramName) ps)
      processFlatPattern' ((ps, subst) : pat) rest rest_ts

    -- Pair each name with the next available type, returning the created
    -- parameters and the types that remain.
    handleMapping ts [] = ([], ts)
    handleMapping ts (r : rs) =
      let (p, ts') = handleMapping' ts r
          (ps, ts'') = handleMapping ts' rs
       in (p : ps, ts'')

    handleMapping' (t : ts) vname = (I.Param vname t, ts)
    handleMapping' [] _ =
      -- The types ran out while names remained: the identifiers need
      -- more types than were supplied.
      error $
        "processFlatPattern: insufficient types for identifiers in pattern: "
          ++ show (x, y)

    internaliseBindee :: E.Ident -> InternaliseM [VName]
    internaliseBindee bindee = do
      let name = E.identName bindee
      n <- internalisedTypeSize $ flip E.setAliases () $ E.unInfo $ E.identType bindee
      case n of
        1 -> return [name]
        _ -> replicateM n $ newVName $ baseString name

-- | Run a continuation with the given identifiers bound to internal
-- parameters of the given types, as computed by 'processFlatPattern'.
--
-- While the continuation runs, the environment's 'envSubsts' maps each
-- source name to the variables of its parameters.  These new entries
-- take precedence over existing ones because 'M.union' is left-biased.
bindingFlatPattern ::
  Show t =>
  [E.Ident] ->
  [t] ->
  ([[I.Param t]] -> InternaliseM a) ->
  InternaliseM a
bindingFlatPattern idents ts m =
  processFlatPattern idents ts >>= \(ps, substs) ->
    let extend env = env {envSubsts = substs `M.union` envSubsts env}
     in local extend (m ps)

-- | Flatten a pattern into the identifiers it binds, from left to right.
--
-- The type of each identifier is carried in its 'E.identType'.  Handling
-- of the other pattern forms:
--
-- * Parentheses and type ascriptions are looked through.
-- * Record fields are taken in the order given by 'sortFields'.
-- * Constructor patterns contribute the identifiers of their payload
--   patterns.
-- * Wildcards and literal patterns get a fresh identifier named from
--   @"nameless"@.
-- * Empty tuple and record patterns become a fresh identifier of type
--   bool.
flattenPattern :: MonadFreshNames m => E.Pattern -> m [E.Ident]
flattenPattern pat =
  case pat of
    E.PatternParens p _ -> flattenPattern p
    E.PatternAscription p _ _ -> flattenPattern p
    E.Id v t loc -> pure [E.Ident v t loc]
    E.Wildcard t loc -> freshIdent t loc
    E.PatternLit _ t loc -> freshIdent t loc
    -- XXX: treat empty tuples and records as bool.
    E.TuplePattern [] loc -> freshIdent boolType loc
    E.RecordPattern [] loc -> freshIdent boolType loc
    E.TuplePattern pats _ -> flattenAll pats
    E.RecordPattern fs _ -> flattenAll $ map snd $ sortFields $ M.fromList fs
    E.PatternConstr _ _ pats _ -> flattenAll pats
  where
    flattenAll subpats = concat <$> traverse flattenPattern subpats
    freshIdent t loc = do
      name <- newVName "nameless"
      pure [E.Ident name t loc]
    boolType = Info $ E.Scalar $ E.Prim E.Bool

-- | Bind the identifiers of a pattern to the given internal types for
-- the duration of a continuation.  The continuation is passed the names
-- of the resulting internal variables as one flat list, in order.
stmPattern ::
  E.Pattern ->
  [I.Type] ->
  ([VName] -> InternaliseM a) ->
  InternaliseM a
stmPattern pat ts m =
  flattenPattern pat >>= \idents ->
    bindingFlatPattern idents ts (m . concatMap (map I.paramName))