{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module TensorSafe.Layers.MaxPooling where
import Data.Kind (Type)
import Data.Map
import Data.Proxy
import GHC.TypeLits
import TensorSafe.Compile.Expr
import TensorSafe.Layer
-- | Type-level description of a max-pooling layer.
--
-- The value carries no runtime data. All four parameters live only in the
-- type, as phantom 'Nat' indices, in this order:
--
--   * @kernelRows@, @kernelColumns@: the pooling window (emitted as
--     @poolSize@ by the 'Layer' instance below);
--   * @strideRows@, @strideColumns@: the step between windows (emitted as
--     @strides@).
data MaxPooling (kernelRows :: Nat)
                (kernelColumns :: Nat)
                (strideRows :: Nat)
                (strideColumns :: Nat)
  = MaxPooling
  deriving Show
-- Reflects the four type-level naturals of a 'MaxPooling' layer into a
-- 'CNLayer' node tagged 'DMaxPooling'. Both arguments to 'compile' are
-- ignored. The layer's configuration is entirely determined by its type.
instance ( KnownNat kernelRows
         , KnownNat kernelColumns
         , KnownNat strideRows
         , KnownNat strideColumns
         ) => Layer (MaxPooling kernelRows kernelColumns strideRows strideColumns) where
  layer = MaxPooling

  -- The parameter map holds exactly two entries. Each value is the 'show' of
  -- a two-element Integer list, e.g. "[2,2]". NOTE(review): the camelCase
  -- keys presumably match the target framework's max-pooling options —
  -- confirm against the code generator.
  compile _ _ = CNLayer DMaxPooling $ fromList
    [ ("poolSize", show poolSize)
    , ("strides", show strides)
    ]
    where
      -- ScopedTypeVariables puts the instance-head variables in scope here.
      poolSize =
        [ natVal (Proxy :: Proxy kernelRows)
        , natVal (Proxy :: Proxy kernelColumns)
        ]
      strides =
        [ natVal (Proxy :: Proxy strideRows)
        , natVal (Proxy :: Proxy strideColumns)
        ]