{-# LANGUAGE OverloadedStrings #-}
module TensorSafe.Compile.Expr (
DLayer (..),
CNetwork (..),
JavaScript (..),
Python (..),
Generator,
generate,
generateFile
) where
import Data.Map
import Data.Text.Lazy as T
import Formatting
import Text.Casing (camel, quietSnake)
-- | Backend-independent layer kinds understood by the code generators.
-- Each target language turns a constructor into its own layer name via
-- 'LayerGenerator'.
data DLayer
  = DConv2D
  | DDense
  | DDropout
  | DFlatten
  | DLSTM
  | DMaxPooling
  | DRelu
  | DActivation
  deriving (Show)
-- | Compiled network tree consumed by the 'Generator' backends.
data CNetwork
  = CNSequence CNetwork               -- ^ model-creation statement, then the wrapped network
  | CNCons CNetwork CNetwork          -- ^ output of the first network followed by the second
  | CNLayer DLayer (Map String String) -- ^ one layer plus its parameters (key -> rendered value)
  | CNReturn                          -- ^ produces no output in either backend
  | CNNil                             -- ^ produces no output in either backend
  deriving (Show)
-- | Marker value selecting the JavaScript (TensorFlow.js) backend.
data JavaScript
  = JavaScript
  deriving (Show)
-- | Marker value selecting the Python (TensorFlow) backend.
data Python
  = Python
  deriving (Show)
-- | Maps a backend-independent 'DLayer' to the layer name used by one
-- target language. The first argument only selects the instance; the
-- instances in this module never inspect its value.
class LayerGenerator target where
  generateName
    :: target -- ^ backend selector
    -> DLayer -- ^ layer to name
    -> String -- ^ layer name in the target library
-- | TensorFlow.js layer factory names (emitted as @tf.layers.<name>@).
instance LayerGenerator JavaScript where
  generateName _ layer =
    case layer of
      DConv2D -> "conv2d"
      DDense -> "dense"
      DDropout -> "dropout"
      DFlatten -> "flatten"
      DLSTM -> "lstm"
      DMaxPooling -> "maxPooling2d"
      DRelu -> "reLU"
      DActivation -> "activation"
-- | Keras layer class names for each 'DLayer'.
instance LayerGenerator Python where
  generateName _ DConv2D = "Conv2D"
  generateName _ DDense = "Dense"
  generateName _ DDropout = "Dropout"
  generateName _ DFlatten = "Flatten"
  generateName _ DLSTM = "LSTM"
  -- Keras exposes MaxPool2D as an alias of MaxPooling2D.
  generateName _ DMaxPooling = "MaxPool2D"
  -- The Keras class is spelled 'ReLU'. 'ReLu' does not exist, and the
  -- generated script would fail with an AttributeError.
  generateName _ DRelu = "ReLU"
  generateName _ DActivation = "Activation"
-- | A code-generation backend for compiled networks.
class Generator target where
  -- | Model-building statements only, joined with newlines.
  generate :: target -> CNetwork -> Text

  -- | A complete source file: the output of 'generate' wrapped in the
  -- backend's file prologue (and epilogue, when it has one).
  generateFile :: target -> CNetwork -> Text
-- | TensorFlow.js backend.
instance Generator JavaScript where
  -- Every node contributes zero or more lines. The lines are joined with
  -- "\n" and have no trailing newline.
  generate target network = T.intercalate "\n" (statements network)
    where
      statements :: CNetwork -> [Text]
      statements node =
        case node of
          CNSequence rest -> "var model = tf.sequential();" : statements rest
          CNCons lhs rhs -> statements lhs ++ statements rhs
          CNLayer layer params -> [addLayer layer params]
          CNNil -> []
          CNReturn -> []

      -- One `model.add(tf.layers.<name>({ ... }));` statement.
      addLayer :: DLayer -> Map String String -> Text
      addLayer layer params =
        format
          ("model.add(tf.layers." % string % "(" % string % "));")
          (generateName target layer)
          (paramsToJS params)

  -- Wraps the statements in a `model` function whose result is assigned
  -- to module.exports.
  generateFile target network =
    T.concat [prologue, generate target network, epilogue]
    where
      prologue :: Text
      prologue =
        T.intercalate
          "\n"
          [ "// Autogenerated code"
          , "var tf = require(\"@tensorflow/tfjs\");"
          , "function model() {"
          , "\n"
          ]

      epilogue :: Text
      epilogue =
        T.intercalate
          "\n"
          [ "\n"
          , "return model;"
          , "}"
          , "\n"
          , "module.exports = model();"
          ]
-- | Renders layer parameters as a JavaScript object literal. Each entry is
-- @key: value, @, where the key goes through 'camel' and the value is
-- copied verbatim. Entries appear in descending key order, and an empty
-- map renders as @{ }@.
paramsToJS :: Map String String -> String
paramsToJS params = "{ " ++ foldlWithKey render "" params ++ "}"
  where
    -- Prepending while folding left-to-right yields descending key order.
    render :: String -> String -> String -> String
    render acc key value = camel key ++ ": " ++ value ++ ", " ++ acc
-- | Python backend targeting the tf.keras API.
instance Generator Python where
  -- Every node contributes zero or more lines. The lines are joined with
  -- "\n" and have no trailing newline. CNNil and CNReturn emit nothing.
  generate l =
    T.intercalate "\n" . generatePy
    where
      generatePy :: CNetwork -> [Text]
      generatePy (CNSequence cn) =
        "model = tf.keras.models.Sequential()" : generatePy cn
      generatePy (CNCons cn1 cn2) = generatePy cn1 ++ generatePy cn2
      generatePy CNNil = []
      generatePy CNReturn = []
      generatePy (CNLayer layer params) =
        -- The model is a tf.keras Sequential, so layers must come from
        -- tf.keras.layers. The TF1-only tf.layers namespace lacks
        -- LSTM/ReLU/Activation and does not exist in TensorFlow 2.
        [ format
            ("model.add(tf.keras.layers." % string % "(" % string % "))")
            (generateName l layer)
            (paramsToPython params)
        ]

  -- Prepends the file header. The header comment must use Python's '#'
  -- syntax: '//' is the floor-division operator, and a leading '//' line
  -- makes the generated file fail to parse.
  generateFile l cn =
    startCode `append` generate l cn
    where
      startCode :: Text
      startCode =
        T.intercalate
          "\n"
          [ "# Autogenerated code"
          , "import tensorflow as tf"
          , "\n"
          ]
-- | Renders layer parameters as Python keyword arguments. Each argument is
-- @name=value, @ (a trailing comma is legal in a Python call). Names are
-- converted with 'quietSnake', except @inputDim@, which becomes
-- @input_shape@. Values are copied verbatim, and arguments appear in
-- descending key order.
paramsToPython :: Map String String -> String
paramsToPython params = foldlWithKey render "" params
  where
    -- Prepending while folding left-to-right yields descending key order.
    render :: String -> String -> String -> String
    render acc key value = keyword key ++ "=" ++ value ++ ", " ++ acc

    keyword :: String -> String
    keyword "inputDim" = "input_shape"
    keyword other = quietSnake other