{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.SeqMem
  ( SeqMem,

    -- * Simplification
    simplifyProg,
    simpleSeqMem,

    -- * Module re-exports
    module Futhark.IR.Mem,
    module Futhark.IR.Kernels.Kernel,
  )
where

import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Kernels.Kernel
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BinderOps (..), mkLetNamesB', mkLetNamesB'')
import qualified Futhark.TypeCheck as TC

-- | Type-level tag identifying this representation ("lore").
--
-- It has no constructors and is used only as a type index. Its operations
-- are @MemOp ()@: an 'Alloc', plus an 'Inner' operation with no payload.
data SeqMem

-- | Decorations of 'SeqMem'.
--
-- Operations are 'MemOp' with a unit inner payload. Let-bound names,
-- function and lambda parameters, and return and branch types all use the
-- corresponding @*Mem@ decoration types.
instance Decorations SeqMem where
  type Op SeqMem = MemOp ()
  type LetDec SeqMem = LetDecMem
  type FParamInfo SeqMem = FParamMem
  type LParamInfo SeqMem = LParamMem
  type RetType SeqMem = RetTypeMem
  type BranchType SeqMem = BranchTypeMem

-- | AST hooks for 'SeqMem'.
--
-- 'expTypesFromPattern' calls 'bodyReturnsFromPattern' and takes the second
-- list of @(name, branch type)@ pairs. It drops the names and keeps only the
-- branch types.
-- NOTE(review): presumably the second component holds the pattern's value
-- elements (the first its context elements) — confirm against
-- 'bodyReturnsFromPattern'.
instance ASTLore SeqMem where
  expTypesFromPattern = pure . map snd . snd . bodyReturnsFromPattern

-- | Return types of 'SeqMem' operations.
--
-- An 'Alloc' returns a single 'MemMem' result in the allocation's space; its
-- size argument does not affect the result. The 'Inner' operation carries
-- only @()@ and returns nothing.
instance OpReturns SeqMem where
  opReturns (Alloc _ space) = pure [MemMem space]
  opReturns (Inner ()) = pure []

-- | Pretty-printing for 'SeqMem' uses the default methods of 'PrettyLore'.
instance PrettyLore SeqMem

-- | Type checking of 'SeqMem' operations.
--
-- The size operand of an 'Alloc' must have type @i64@ ('Prim' 'int64'); the
-- allocation space is not inspected here. The 'Inner' operation carries only
-- @()@ and always type checks.
instance TC.CheckableOp SeqMem where
  checkOp (Alloc size _) = TC.require [Prim int64] size
  checkOp (Inner ()) = pure ()

-- | Type checker hooks for 'SeqMem'.
--
-- Every hook delegates to the shared memory-aware checks:
--
-- * Parameter and let-bound decorations are checked with 'checkMemInfo'.
-- * Return types are checked by converting each one with 'declExtTypeOf'
--   and passing it to 'TC.checkExtType'.
-- * Pattern, function-return, branch and loop-result matching use the
--   generic @match*@ helpers.
--
-- Primitive function parameters are built with a 'MemPrim' decoration.
instance TC.Checkable SeqMem where
  checkFParamLore = checkMemInfo
  checkLParamLore = checkMemInfo
  checkLetBoundLore = checkMemInfo
  checkRetType = mapM_ (TC.checkExtType . declExtTypeOf)
  primFParam name t = pure $ Param name (MemPrim t)
  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem

-- | Binder operations for plain 'SeqMem'.
--
-- Expression and body decorations are both @()@. Let-bindings are built by
-- @mkLetNamesB'@, which receives a unit expression decoration.
instance BinderOps SeqMem where
  mkExpDecB _ _ = pure ()
  mkBodyB stms res = pure $ Body () stms res
  mkLetNamesB = mkLetNamesB' ()

-- | Binder operations for the simplifier's 'Engine.Wise' wrapper of
-- 'SeqMem'.
--
-- Expression and body decorations are built by wrapping the underlying unit
-- decorations with 'Engine.mkWiseExpDec' and 'Engine.mkWiseBody'.
-- Let-bindings are delegated to @mkLetNamesB''@.
instance BinderOps (Engine.Wise SeqMem) where
  mkExpDecB pat e = pure $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = pure $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''

-- | Simplify a 'SeqMem' program.
--
-- This runs the generic memory-aware simplifier ('simplifyProgGeneric') with
-- the operation-specific simplification hooks in 'simpleSeqMem'.
simplifyProg :: Prog SeqMem -> PassM (Prog SeqMem)
simplifyProg = simplifyProgGeneric simpleSeqMem

-- | Simplification operations for 'SeqMem'.
--
-- The inner operation is @()@, so it reports no usage ('mempty'). Simplifying
-- it returns @()@ together with no extra statements.
simpleSeqMem :: Engine.SimpleOps SeqMem
simpleSeqMem =
  simpleGeneric (const mempty) $ const $ pure ((), mempty)