{-# LANGUAGE CPP, LambdaCase, ViewPatterns #-}
module TypeLevel.Rewrite.Internal.Lookup where

import Control.Arrow ((***), first)
import Data.Char (isAlphaNum, isUpper)
import Data.List (intercalate)
import Data.Tuple (swap)

-- GHC API
import GHC (DataCon, TyCon, dataConTyCon)
#if MIN_VERSION_ghc(9,0,0)
import GHC.Driver.Finder (cannotFindModule)
import GHC (Module, ModuleName, mkModuleName)
import GHC.Plugins (mkDataOcc, mkTcOcc, promoteDataCon)
import GHC.Utils.Panic (panicDoc)
import GHC.Tc.Plugin
  ( FindResult(Found), TcPluginM, findImportedModule, lookupOrig, tcLookupDataCon, tcLookupTyCon
  , unsafeTcPluginTcM
  )
import GHC.Tc.Solver.Monad (getDynFlags)
#else
import DataCon (promoteDataCon)
import Finder (cannotFindModule)
import Module (Module, ModuleName, mkModuleName)
import OccName (mkDataOcc, mkTcOcc)
import Panic (panicDoc)
import TcPluginM
import TcSMonad (getDynFlags)
#endif

-- | Resolve a module by name using 'findImportedModule' with no package
-- qualifier.
--
-- If the lookup does not yield 'Found', compilation is aborted with
-- 'panicDoc'. The panic carries GHC's own 'cannotFindModule' explanation for
-- the failed 'FindResult'.
lookupModule
  :: String  -- ^ module name
  -> TcPluginM Module
lookupModule moduleNameStr = do
  findResult <- findImportedModule moduleName Nothing
  case findResult of
    Found _ module_ ->
      pure module_
    _ -> do
      -- 'cannotFindModule' needs the DynFlags, which only TcM can provide.
      dynFlags <- unsafeTcPluginTcM getDynFlags
      panicDoc ("TypeLevel.Lookup.lookupModule " ++ show moduleNameStr)
               (cannotFindModule dynFlags moduleName findResult)
  where
    moduleName :: ModuleName
    moduleName = mkModuleName moduleNameStr

-- Unfortunately, 'TcPluginM.lookupM' fails with a very unhelpful error
-- message when we look up a name that doesn't exist:
--
--   Can't find interface-file declaration for type constructor or class ModuleName.TypeName
--   Probable cause: bug in .hi-boot file, or inconsistent .hi file
--   Use -ddump-if-trace to get an idea of which file caused the error
--
-- But the true cause isn't a corrupted file; it's simply that the requested
-- name is not in the given module. I don't know how to improve that error
-- message: I can't catch it with 'try' or 'tryM' because we're in the wrong
-- monad ('TcPluginM').

-- | Resolve a type constructor or type family from its module name and its
-- unqualified name, e.g. @lookupTyCon "TypeLevel.Append" "++"@.
--
-- A missing module aborts via 'lookupModule'. A name missing from an existing
-- module surfaces as GHC's own lookup error, whose wording misleadingly
-- blames a corrupted interface file.
lookupTyCon
  :: String  -- ^ module name
  -> String  -- ^ type constructor/family name
  -> TcPluginM TyCon
lookupTyCon moduleNameStr tyConNameStr = do
  module_ <- lookupModule moduleNameStr
  -- 'mkTcOcc' puts the name in the type-constructor namespace.
  lookupOrig module_ (mkTcOcc tyConNameStr) >>= tcLookupTyCon

-- | Resolve a data constructor from its module name and its unqualified
-- name, e.g. @lookupDataCon "GHC.Types" "[]"@.
--
-- A missing module aborts via 'lookupModule'. A name missing from an existing
-- module surfaces as GHC's own lookup error.
lookupDataCon
  :: String  -- ^ module name
  -> String  -- ^ data constructor name
  -> TcPluginM DataCon
lookupDataCon moduleNameStr dataConNameStr = do
  module_ <- lookupModule moduleNameStr
  -- 'mkDataOcc' puts the name in the data-constructor namespace.
  lookupOrig module_ (mkDataOcc dataConNameStr) >>= tcLookupDataCon


-- | Split a string at its first @.@.
--
-- Returns the text before and after that dot; neither part includes the dot
-- itself. Returns 'Nothing' if the string contains no dot.
--
-- > splitFirstDot "GHC.Types.Int" == Just ("GHC", "Types.Int")
splitFirstDot
  :: String -> Maybe (String, String)
splitFirstDot input = case input of
  []          -> Nothing
  '.' : after -> Just ("", after)
  -- Not a dot yet: keep the character on the left half of whatever the rest
  -- of the string splits into.
  c : rest    -> fmap (first (c :)) (splitFirstDot rest)

-- | Split a string at its last @.@.
--
-- Returns the text before and after that dot, or 'Nothing' if there is no
-- dot. Works by reversing the input, splitting at the (now first) dot, then
-- reversing both halves back and swapping them into order.
--
-- > splitLastDot "GHC.Types.Int" == Just ("GHC.Types", "Int")
--
-- Note: a trailing name that itself contains a dot (e.g. the operator
-- @:.:@) is split inside that name.
splitLastDot
  :: String -> Maybe (String, String)
splitLastDot input =
  swap . (reverse *** reverse) <$> splitFirstDot (reverse input)

-- | Look up a fully-qualified name and return its 'TyCon'. Examples:
-- @"GHC.Types.Int"@, @"TypeLevel.Append.++"@, @"GHC.Generics.:.:"@.
--
-- A leading @'@, as in @"'GHC.Types.[]"@, marks a data constructor. It is
-- resolved with 'lookupDataCon' and its promoted (DataKinds) 'TyCon' is
-- returned. So @"'GHC.Types.[]"@ and @"'GHC.Types.:"@ give two distinct type
-- constructors, rather than both giving the parent list type.
--
-- The module part is the longest prefix of capitalised, dot-terminated
-- segments, as in Haskell's own qualified names. This lets operators that
-- contain dots be split correctly. Naively splitting at the last dot would
-- turn @"GHC.Generics.:.:"@ into module @"GHC.Generics.:"@ and name @":"@.
--
-- Calls 'error' if the input is not of the form @Module.Name@.
lookupFQN
  :: String
  -> TcPluginM TyCon
lookupFQN fqn = case fqn of
  '\'' : (splitQualifiedName -> Just (moduleNameStr, dataConNameStr))
    -> promoteDataCon <$> lookupDataCon moduleNameStr dataConNameStr
  (splitQualifiedName -> Just (moduleNameStr, tyConNameStr))
    -> lookupTyCon moduleNameStr tyConNameStr
  _ -> error $ "expected " ++ show "ModuleName.TypeName"
            ++ ", got " ++ show fqn
  where
    -- Split "A.B.name" into ("A.B", "name").
    -- Each module segment must:
    --   * start with an uppercase letter,
    --   * continue with identifier characters,
    --   * be followed by a dot and a non-empty remainder.
    -- Whatever follows the last such segment is the name. Returns Nothing
    -- when there is no module segment or the name would be empty.
    splitQualifiedName :: String -> Maybe (String, String)
    splitQualifiedName = go []
      where
        go :: [String] -> String -> Maybe (String, String)
        go segments input = case input of
          c : _
            | isUpper c
            , (segment, '.' : rest@(_ : _)) <- span isIdentChar input
            -> go (segment : segments) rest
          _ | null segments || null input -> Nothing
            | otherwise -> Just (intercalate "." (reverse segments), input)

        isIdentChar :: Char -> Bool
        isIdentChar ch = isAlphaNum ch || ch == '_' || ch == '\''