{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveDataTypeable #-}
module CodeGen.Types.CLI
( LibType(..)
, prefix
, describe'
, supported
, supportedLibraries
, outDir
, outModule
, srcDir
, CodeGenType(..)
, generatable
, TemplateType(..)
, generatedTypes
) where
import Data.Data
import Data.Typeable
import CodeGen.Prelude
import qualified Data.Text as T
import qualified Data.HashSet as HS
-- | Torch C libraries known to the code generator (their sources are looked
-- up under @./deps/aten/src/\<Show name\>/@, see 'srcDir'). Only the subset
-- accepted by 'supported' is currently generated.
--
-- Constructor order is significant: the derived 'Ord', 'Enum' and 'Bounded'
-- instances follow it, and 'supportedLibraries' enumerates
-- @[minBound .. maxBound]@. Constructor names are significant too: the
-- derived 'Show' output is used by 'prefix', 'srcDir', 'outDir' and
-- 'outModule' to build symbol prefixes, paths and module names.
data LibType
  = ATen    -- ^ A TENsor library exposing Torch/PyTorch tensor ops in C++11
  | THCUNN  -- ^ Cuda-based THNN
  | THCS    -- ^ Cuda-based sparse tensor support
  | THC     -- ^ Cuda-based Torch7
  | THNN    -- ^ THNN
  | THS     -- ^ TH sparse tensor support
  | TH      -- ^ Torch7
  deriving (Eq, Ord, Show, Enum, Bounded, Read, Generic, Hashable, Data, Typeable)
-- | Name prefix for a library.
--
-- 'THC' and 'THCUNN' share the CUDA prefix: @"THCuda"@ when @long@ is
-- 'True', otherwise @"THC"@. Every other library uses its 'Show' name.
prefix :: LibType -> Bool -> Text
prefix lt long = case lt of
    THC    -> cudaPrefix
    THCUNN -> cudaPrefix
    _      -> tshow lt
  where
    -- Shared by both CUDA libraries; only inspected for those two cases.
    cudaPrefix = if long then "THCuda" else "THC"
-- | One-line, human-readable description of each library.
--
-- The ATen entry is split across two literals; the trailing space on the
-- first one is required so the words do not run together ("Torchand").
describe' :: LibType -> String
describe' = \case
  ATen   -> "A simple TENsor library that exposes the Tensor operations in Torch "
         ++ "and PyTorch directly in C++11."
  TH     -> "Torch7"
  THC    -> "Cuda-based Torch7"
  THCS   -> "Cuda-based Sparse Tensor support with TH"
  THCUNN -> "Cuda-based THNN"
  THNN   -> "THNN"
  THS    -> "TH Sparse tensor support (ATen library)"
-- | Whether code generation is currently enabled for a library
-- ('TH', 'THC', 'THNN' and 'THCUNN').
supported :: LibType -> Bool
supported = (`HS.member` enabledLibs)
  where
    enabledLibs = HS.fromList [TH, THC, THNN, THCUNN]
-- | Every 'LibType' accepted by 'supported', in constructor order.
supportedLibraries :: [LibType]
supportedLibraries = [lib | lib <- [minBound .. maxBound], supported lib]
-- | Directory that generated bindings for a library are written to.
--
-- Built from the lower-cased library name and its slash-separated module
-- path, e.g. 'THNN' yields @output/raw/thnn/src/Torch/FFI/TH/NN@.
outDir :: LibType -> FilePath
outDir lt = "output/raw/" ++ toLowers lt ++ "/src/" ++ T.unpack (out "/" lt)
-- | The 'show' rendering of a value, lower-cased character by character.
toLowers :: Show a => a -> String
toLowers x = [toLower c | c <- show x]
-- | Dotted Haskell module name for a library's generated bindings,
-- e.g. @Torch.FFI.TH.NN@ for 'THNN'.
outModule :: LibType -> Text
outModule lt = out "." lt
-- | Module path for a library, joined with the given separator.
--
-- The path is @Torch \<sep\> FFI \<sep\> \<lib\>@. The NN libraries are nested
-- under their base library with a trailing @NN@ segment: 'THNN' becomes
-- @Torch.FFI.TH.NN@ and 'THCUNN' becomes @Torch.FFI.THC.NN@ (with @"."@).
out :: Text -> LibType -> Text
out sep lt = T.intercalate sep ("Torch" : "FFI" : libSegments)
  where
    libSegments = case lt of
      THCUNN -> [tshow THC, "NN"]
      THNN   -> [tshow TH, "NN"]
      other  -> [tshow other]
-- | Directory holding a library's C sources.
--
-- Generic sources sit in an extra @generic/@ sub-directory, e.g.
-- @srcDir TH GenericFiles == "./deps/aten/src/TH/generic/"@.
srcDir :: LibType -> CodeGenType -> FilePath
srcDir lt cgt = "./deps/aten/src/" ++ show lt ++ "/" ++ genericSuffix
  where
    genericSuffix
      | cgt == GenericFiles = "generic/"
      | otherwise           = ""
-- | Which family of sources to generate bindings from.
--
-- 'GenericFiles' live in a @generic/@ sub-directory (see 'srcDir') and are
-- specialised per element type (see 'generatedTypes'); 'ConcreteFiles' are
-- not.
data CodeGenType
  = GenericFiles
  | ConcreteFiles
  deriving (Eq, Ord, Enum, Bounded)

-- | Parses exactly the tokens produced by the 'Show' instance:
-- @"generic"@ and @"concrete"@.
--
-- Input is tokenised with 'lex', so leading whitespace is skipped and any
-- unconsumed input is returned to the caller, as the 'ReadS' contract
-- requires. This lets the instance compose with other parsers, e.g.
-- @read "[generic,concrete]" :: [CodeGenType]@.
instance Read CodeGenType where
  readsPrec _ s =
    [ (cgt, rest)
    | (tok, rest) <- lex s
    , Just cgt <- [fromToken tok]
    ]
    where
      fromToken :: String -> Maybe CodeGenType
      fromToken "generic"  = Just GenericFiles
      fromToken "concrete" = Just ConcreteFiles
      fromToken _          = Nothing

-- | Renders as lower-case names; 'Read' accepts the same tokens.
instance Show CodeGenType where
  show GenericFiles  = "generic"
  show ConcreteFiles = "concrete"
-- | Whether a source family can be generated; currently every
-- 'CodeGenType' is.
generatable :: CodeGenType -> Bool
generatable _ = True
-- | Element type a generated binding is specialised to.
--
-- 'GenNothing' is what 'generatedTypes' returns for 'ConcreteFiles', i.e.
-- concrete sources are not specialised per element type. The other
-- constructors presumably mirror the element types the generic C sources
-- are instantiated at (byte, char, double, ...) — TODO confirm.
--
-- Constructor order is significant for the derived 'Ord' and 'Bounded'
-- instances.
--
-- NOTE(review): derives 'Bounded' but not 'Enum', so @[minBound .. maxBound]@
-- is not available for this type.
data TemplateType
  = GenByte
  | GenChar
  | GenDouble
  | GenFloat
  | GenHalf
  | GenInt
  | GenLong
  | GenShort
  | GenNothing  -- ^ no element-type specialisation (used for concrete files)
  deriving (Eq, Ord, Bounded, Show, Generic, Hashable)
-- | Element types to generate for a library and source family.
--
-- Concrete files always yield @[GenNothing]@. For generic files, the NN
-- libraries ('THNN', 'THCUNN') get only 'GenDouble' and 'GenFloat'; every
-- other library gets the full list from 'GenByte' to 'GenShort'.
generatedTypes :: LibType -> CodeGenType -> [TemplateType]
generatedTypes lt =
  case lt of
    THNN   -> perFamily nnTypes
    THCUNN -> perFamily nnTypes
    _      -> perFamily allTypes
  where
    nnTypes = [GenDouble, GenFloat]
    allTypes =
      [ GenByte, GenChar, GenDouble, GenFloat
      , GenHalf, GenInt, GenLong, GenShort
      ]
    -- Concrete sources are never specialised; generic ones get the list given.
    perFamily :: [TemplateType] -> CodeGenType -> [TemplateType]
    perFamily _        ConcreteFiles = [GenNothing]
    perFamily generics GenericFiles  = generics