{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveDataTypeable #-}
module CodeGen.Types.Parsed where
import CodeGen.Prelude
import CodeGen.Types.CLI
import Control.Monad
import Data.Data
import Data.Typeable
import qualified Data.HashSet as HS
-- | A type that can appear in a parsed signature (it is the type of an
-- 'Arg' and the return type of a 'Function').
--
-- NOTE: constructor order is kept as-is; the Generic representation (and
-- therefore the anyclass-derived 'Hashable' instance) is built from it.
data Parsable
  = Ptr Parsable
    -- ^ Wraps another 'Parsable'; presumably one level of pointer
    -- indirection (nesting gives @T**@ etc.) — TODO confirm with the parser.
  | TenType TenType
    -- ^ A tensor-library type, see 'TenType'.
  | CType CType
    -- ^ A primitive C type, see 'CType'.
  deriving (Show, Eq, Hashable, Generic)
-- | Primitive C types that can appear in a parsed signature.
--
-- NOTE: constructor order is significant: it fixes the derived 'Enum' and
-- 'Bounded' instances, and thus the order of @[minBound .. maxBound]@.
data CType
  = CBool
  | CVoid
  | CPtrdiff
  -- floating point
  | CFloat
  | CDouble
  -- fixed-width and platform integers
  | CLong
  | CUInt64
  | CUInt32
  | CUInt16
  | CUInt8
  | CInt64
  | CInt32
  | CInt16
  | CInt8
  | CInt
  | CSize
  | CChar
  | CShort
  deriving (Enum, Bounded, Eq, Show, Generic, Hashable)
-- | A tensor-library type: a 'RawTenType' paired with the 'LibType' it is
-- associated with (e.g. 'isConcreteCudaPrefixed' inspects the library half).
newtype TenType = Pair
  { unTenType :: (RawTenType, LibType)
    -- ^ the raw type together with its library
  } deriving (Show, Eq, Hashable, Generic)
-- | Library-independent names of tensor, storage and related types; pair
-- with a 'LibType' via 'TenType' to get a concrete library type.
--
-- NOTE: constructor order is significant: it fixes the derived 'Enum' and
-- 'Bounded' instances (used by 'allTenTypes').
data RawTenType
  -- tensors
  = Tensor
  | ByteTensor
  | CharTensor
  | ShortTensor
  | IntTensor
  | LongTensor
  | FloatTensor
  | DoubleTensor
  | HalfTensor
  -- storages
  | Storage
  | ByteStorage
  | CharStorage
  | ShortStorage
  | IntStorage
  | LongStorage
  | FloatStorage
  | DoubleStorage
  | HalfStorage
  -- other library types
  | DescBuff
  | Generator
  | Allocator
  | File
  | Half
  | State
  | IndexTensor
  | IntegerTensor
  | Real
  | AccReal
  | Stream
  deriving (Enum, Bounded, Eq, Show, Generic, Hashable)
-- | 'True' exactly when the type belongs to the 'THC' or 'THCUNN' library
-- and its raw type is one of the element-specific tensors
-- ('ByteTensor' … 'HalfTensor'). The generic 'Tensor', storages and all
-- other raw types yield 'False'.
isConcreteCudaPrefixed :: TenType -> Bool
isConcreteCudaPrefixed (Pair (raw, lib))
  | lib /= THC && lib /= THCUNN = False
  | otherwise                   = raw `HS.member` concreteTensors
  where
    -- element-specific tensor types; the set has no free variables
    concreteTensors = HS.fromList
      [ ByteTensor, CharTensor, ShortTensor, IntTensor
      , LongTensor, FloatTensor, DoubleTensor, HalfTensor
      ]
-- | Every combination of a 'RawTenType' and a 'LibType', each wrapped in
-- 'Pair'. The raw type varies slowest (outer generator) and the library
-- fastest, both in their derived 'Enum' order.
allTenTypes :: [TenType]
allTenTypes =
  [ Pair (raw, lib)
  | raw <- [minBound .. maxBound]
  , lib <- [minBound .. maxBound]
  ]
-- | One named argument of a parsed 'Function'.
data Arg = Arg
  { argType :: Parsable
    -- ^ the argument's declared type
  , argName :: Text
    -- ^ the argument's name
  } deriving (Show, Eq, Hashable, Generic)
-- | A parsed function signature.
data Function = Function
  { funPrefix :: Maybe (LibType, Text)
    -- ^ optional library and textual prefix associated with the name;
    -- presumably the prefix stripped from the raw symbol — TODO confirm
  , funName   :: Text
    -- ^ the function's name
  , funArgs   :: [Arg]
    -- ^ the arguments (presumably in declaration order — verify in parser)
  , funReturn :: Parsable
    -- ^ the return type
  } deriving (Show, Eq, Hashable, Generic)
-- | The parser monad: 'Parsec' with a 'Void' (i.e. no custom) error
-- component over 'String' input; presumably megaparsec's 'Parsec'.
-- Left unsaturated so it can be used as @Parser a@ or partially applied.
type Parser = Parsec Void String