{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeApplications #-}
module TreeSitter.GenerateSyntax
( syntaxDatatype
, removeUnderscore
, initUpper
, astDeclarationsForLanguage
) where
import Data.Char
import Language.Haskell.TH as TH
import Language.Haskell.TH.Syntax as TH
import Data.HashSet (HashSet)
import TreeSitter.Deserialize (Datatype (..), DatatypeName (..), Field (..), Children(..), Required (..), Type (..), Named (..), Multiple (..))
import Data.List.NonEmpty (NonEmpty (..))
import Data.Foldable
import Data.Text (Text)
import qualified Data.HashSet as HashSet
import qualified TreeSitter.Unmarshal as TS
import GHC.Generics hiding (Constructor, Datatype)
import Foreign.Ptr
import qualified TreeSitter.Language as TS
import Foreign.C.String
import Data.Aeson hiding (String)
import System.Directory
import System.FilePath.Posix
import TreeSitter.Node
import TreeSitter.Token
import TreeSitter.Symbol (escapeOperatorPunctuation)
-- | Generate AST declarations from a JSON description of a grammar's node types.
--
-- @filePath@ is resolved relative to the directory of the source file in
-- which the splice is being run. The file is decoded as a JSON list of
-- 'Datatype's, each of which is handed to 'syntaxDatatype'; the resulting
-- declarations are concatenated in order.
--
-- The splice fails (via 'fail') when the JSON cannot be decoded. The failure
-- message names the file that was being read so that the offending grammar
-- description can be identified.
astDeclarationsForLanguage :: Ptr TS.Language -> FilePath -> Q [Dec]
astDeclarationsForLanguage language filePath = do
  -- NOTE(review): presumably registers the JSON file as a compilation
  -- dependency so edits trigger a rebuild; confirm in TreeSitter.Unmarshal.
  _ <- TS.addDependentFileRelative filePath
  currentFilename <- loc_filename <$> location
  pwd <- runIO getCurrentDirectory
  let invocationRelativePath = takeDirectory (pwd </> currentFilename) </> filePath
  runIO (eitherDecodeFileStrict' invocationRelativePath) >>= \case
    -- aeson's message only describes the JSON path of the error, not the file.
    Left err ->
      fail ("astDeclarationsForLanguage: could not decode " <> invocationRelativePath <> ": " <> err)
    Right datatypes ->
      concat @[] <$> traverse (syntaxDatatype language) datatypes
-- | Produce the declarations for a single 'Datatype' from the grammar description.
--
-- * A 'SumType' becomes a newtype wrapping the nested sum of its subtypes
--   (see 'fieldTypesToNestedSum'), deriving 'TS.SymbolMatching' via the
--   newtype strategy.
-- * A 'ProductType' becomes a record data type (see 'ctorForProductType')
--   followed by a 'TS.SymbolMatching' instance.
-- * An anonymous 'LeafType' becomes a type synonym for 'Token' applied to the
--   symbol's name and the numeric symbol obtained from
--   'TS.ts_language_symbol_for_name'.
-- * A named 'LeafType' becomes a data type (see 'ctorForLeafType') followed
--   by a 'TS.SymbolMatching' instance.
--
-- Nothing is generated when a type with the same name is already defined in
-- the module that is running the splice.
syntaxDatatype :: Ptr TS.Language -> Datatype -> Q [Dec]
syntaxDatatype language datatype = skipDefined $ do
  typeParameterName <- newName "a"
  let tyVars = [PlainTV typeParameterName]
      -- Build the single constructor first, then the instance, and emit the
      -- data declaration ahead of the instance declarations.
      dataWithInstance symbolName mkCon = do
        con <- mkCon
        instanceDecs <- symbolMatchingInstance language name symbolName
        pure (DataD [] name tyVars Nothing [con] dataDerivings : instanceDecs)
  case datatype of
    SumType (DatatypeName _) _ subtypes -> do
      nestedSum <- fieldTypesToNestedSum subtypes
      con <- normalC name [TH.bangType strictness (pure nestedSum `appT` varT typeParameterName)]
      pure [NewtypeD [] name tyVars Nothing con (newtypeDeriving : dataDerivings)]
    ProductType (DatatypeName datatypeName) _ children fields ->
      dataWithInstance datatypeName (ctorForProductType datatypeName typeParameterName children fields)
    LeafType (DatatypeName datatypeName) Anonymous -> do
      tsSymbol <- runIO (withCString datatypeName (TS.ts_language_symbol_for_name language))
      let tokenType =
            AppT
              (AppT (ConT ''Token) (LitT (StrTyLit datatypeName)))
              (LitT (NumTyLit (fromIntegral tsSymbol)))
      pure [TySynD name [] tokenType]
    LeafType (DatatypeName datatypeName) Named ->
      dataWithInstance datatypeName (ctorForLeafType (DatatypeName datatypeName) typeParameterName)
  where
    -- Haskell-side name of the generated type.
    nameStr = toNameString (datatypeNameStatus datatype) (getDatatypeName (TreeSitter.Deserialize.datatypeName datatype))
    name = mkName nameStr
    -- Skip generation when the type is already defined in the current module.
    skipDefined generate = do
      alreadyLocal <- lookupTypeName nameStr >>= maybe (pure False) isLocalName
      if alreadyLocal then pure [] else generate
    -- Deriving clauses shared by every generated data type and newtype.
    dataDerivings =
      [ DerivClause
          (Just StockStrategy)
          (map ConT [''Eq, ''Ord, ''Show, ''Generic, ''Foldable, ''Functor, ''Traversable, ''Generic1])
      , DerivClause (Just AnyclassStrategy) [ConT ''TS.Unmarshal]
      ]
    -- Additional clause used only by the sum-type newtypes.
    newtypeDeriving = DerivClause (Just NewtypeStrategy) [ConT ''TS.SymbolMatching]
-- | Build a 'TS.SymbolMatching' instance for the generated type @name@.
--
-- The symbol and its symbol type are looked up in the language by @str@ at
-- splice time. The generated code refers to @Grammar.Grammar@ and to a
-- constructor qualified with @Grammar.@, so those names must be in scope
-- where the splice is run.
symbolMatchingInstance :: Ptr TS.Language -> Name -> String -> Q [Dec]
symbolMatchingInstance language name str = do
  symbolType <- runIO $ do
    symbol <- withCString str (TS.ts_language_symbol_for_name language)
    toEnum <$> TS.ts_language_symbol_type language symbol
  let expectedLabel :: Q Exp
      expectedLabel = litE (stringL (show name))
      grammarType :: Q TH.Type
      grammarType = conT (mkName "Grammar.Grammar")
      expectedSymbol :: Q Exp
      expectedSymbol = conE (mkName ("Grammar." <> TS.symbolToName symbolType str))
  [d|instance TS.SymbolMatching $(conT name) where
      showFailure _ node = "Expected " <> $(expectedLabel) <> " but got " <> show (TS.fromTSSymbol (nodeSymbol node) :: $(grammarType))
      symbolMatch _ node = TS.fromTSSymbol (nodeSymbol node) == $(expectedSymbol)|]
-- | Record constructor for a product type.
--
-- The fields are, in order: an @ann@ field of the type parameter, one field
-- per named grammar field, and (when children are present) an
-- @extra_children@ field. Each grammar field's type is the nested sum of its
-- possible types applied to the type parameter, wrapped according to whether
-- it is required and whether it may occur multiple times.
ctorForProductType :: String -> Name -> Maybe Children -> [(String, Field)] -> Q Con
ctorForProductType constructorName typeParameterName children fields =
  ctorForTypes constructorName (annotationField : namedFields ++ extraChildren)
  where
    annotationField = ("ann", varT typeParameterName)
    namedFields = [(fieldName, fieldType field) | (fieldName, field) <- fields]
    extraChildren = case children of
      Nothing -> []
      Just (MkChildren field) -> [("extra_children", fieldType field)]
    fieldType (MkField required fieldTypes mult) = container payload
      where
        payload = fieldTypesToNestedSum fieldTypes `appT` varT typeParameterName
        -- Pick the wrapper from the field's multiplicity.
        container = case (required, mult) of
          (Required, Multiple) -> appT (conT ''NonEmpty)
          (Required, Single) -> id
          (Optional, Multiple) -> appT (conT ''[])
          (Optional, Single) -> appT (conT ''Maybe)
-- | Record constructor for a named leaf type: an @ann@ field of the type
-- parameter followed by a @text@ field of type 'Text'.
ctorForLeafType :: DatatypeName -> Name -> Q Con
ctorForLeafType (DatatypeName name) typeParameterName = ctorForTypes name leafFields
  where
    leafFields =
      [ ("ann", varT typeParameterName)
      , ("text", conT ''Text)
      ]
-- | Build a record constructor from a constructor name and a list of
-- (field name, field type) pairs.
--
-- Field names have underscores removed (see 'removeUnderscore') and receive a
-- trailing tick when they appear in 'clashingNames'.
ctorForTypes :: String -> [(String, Q TH.Type)] -> Q Con
ctorForTypes constructorName types = recC (toName Named constructorName) (map recordField types)
  where
    recordField (label, fieldType) =
      TH.varBangType (selectorName label) (TH.bangType strictness fieldType)
    selectorName = mkName . addTickIfNecessary . removeUnderscore
-- | Combine a non-empty list of grammar types into one type built from
-- nested applications of '(:+:)'.
--
-- A single type is returned as-is. Longer lists are split at their midpoint
-- (@length `div` 2@) and the two halves are combined recursively.
fieldTypesToNestedSum :: NonEmpty TreeSitter.Deserialize.Type -> Q TH.Type
fieldTypesToNestedSum = balanced . toList
  where
    balanced [single] = leaf single
    balanced types = (conT ''(:+:) `appT` balanced front) `appT` balanced back
      where
        (front, back) = splitAt (length types `div` 2) types
    leaf (MkType (DatatypeName n) named) = conT (toName named n)
-- | Field annotation used for every generated field: no unpackedness and no
-- strictness annotation.
strictness :: BangQ
strictness = Bang <$> noSourceUnpackedness <*> noSourceStrictness
-- | Convert a grammar name to a capitalised identifier: underscores are
-- removed first, then 'escapeOperatorPunctuation' is applied, and finally the
-- first character is upper-cased.
toCamelCase :: String -> String
toCamelCase str = initUpper (escapeOperatorPunctuation (removeUnderscore str))
-- | Names that 'addTickIfNecessary' disambiguates with a trailing tick.
clashingNames :: HashSet String
clashingNames =
  HashSet.fromList
    [ "type"
    , "module"
    , "data"
    ]
-- | Append a tick (@'@) to names listed in 'clashingNames'; every other name
-- is returned unchanged.
addTickIfNecessary :: String -> String
addTickIfNecessary s =
  if s `HashSet.member` clashingNames
    then s ++ "'"
    else s
-- | Like 'toNameString', but wraps the result in a Template Haskell 'Name'.
toName :: Named -> String -> Name
toName named = mkName . toNameString named
-- | Haskell identifier for a grammar symbol: the camel-cased symbol name,
-- prefixed with @Anonymous@ for anonymous symbols, then passed through
-- 'addTickIfNecessary'.
toNameString :: Named -> String -> String
toNameString named str = addTickIfNecessary (prefix <> toCamelCase str)
  where
    prefix = case named of
      Anonymous -> "Anonymous"
      Named -> ""
-- | The 'Module' a name belongs to, or 'Nothing' when the name does not carry
-- both a package and a module.
moduleForName :: Name -> Maybe Module
moduleForName n = do
  package <- namePackage n
  modName <- nameModule n
  pure (Module (PkgName package) (ModName modName))
-- | Whether the name's module (as given by 'moduleForName') is the module
-- currently being compiled.
isLocalName :: Name -> Q Bool
isLocalName n = do
  current <- thisModule
  pure (moduleForName n == Just current)
-- | Upper-case the first character of a string; the empty string is returned
-- unchanged.
initUpper :: String -> String
initUpper str = case str of
  [] -> []
  c : cs -> toUpper c : cs
-- | Drop every underscore and upper-case the first character of whatever
-- remains after it, e.g. @"foo_bar"@ becomes @"fooBar"@ and @"_foo"@ becomes
-- @"Foo"@.
removeUnderscore :: String -> String
removeUnderscore [] = []
removeUnderscore ('_' : rest) = initUpper (removeUnderscore rest)
removeUnderscore (c : rest) = c : removeUnderscore rest