-- | Provides the plugin's monadic environment,
--   access to the environment and message printing capabilities.
module Control.Super.Plugin.Environment
  ( -- * Supermonad Plugin Monad
    SupermonadPluginM
  , runSupermonadPlugin
  , runSupermonadPluginAndReturn
  , runTcPlugin
    -- * Supermonad Plugin Environment Access
  , getGivenConstraints, getWantedConstraints
  , getInstEnvs
  , getClassDictionary
  , getClass
  , isOptionalClass
  , getCustomState, putCustomState, modifyCustomState
  , getInstanceFor
  , addTypeEqualities, addTypeEquality
  , addTyVarEqualities, addTyVarEquality
  , getTypeEqualities, getTyVarEqualities
  , whenNoResults
  , addWarning, displayWarnings
  , throwPluginError, throwPluginErrorSDoc, catchPluginError
    -- * Debug and Error Output
  , assert, assertM
  , printErr, printMsg, printObj, printWarn
  , printConstraints
  ) where

import Data.List ( groupBy )

import Control.Monad ( when, unless, forM_ )
import Control.Monad.Reader ( ReaderT, runReaderT, asks )
import Control.Monad.State  ( StateT , runStateT , gets, modify )
import Control.Monad.Except ( ExceptT, runExceptT, throwError, catchError )
import Control.Monad.Trans.Class ( lift )

import Class ( Class )
-- import Module ( Module )
import InstEnv ( InstEnvs, ClsInst )
import Type ( TyVar, Type )
import TyCon ( TyCon )
import TcRnTypes ( Ct, TcPluginResult(..) )
import TcPluginM ( TcPluginM, tcPluginIO )
import qualified TcPluginM
import Outputable ( Outputable )
import SrcLoc ( srcSpanFileName_maybe )
import FastString ( unpackFS )
import qualified Outputable as O

import qualified Control.Super.Plugin.Log as L
import Control.Super.Plugin.Names ( PluginClassName )
import Control.Super.Plugin.Constraint
  ( GivenCt, WantedCt
  , constraintSourceLocation
  , mkDerivedTypeEqCt, mkDerivedTypeEqCtOfTypes )
import Control.Super.Plugin.ClassDict
  ( ClassDict
  , Optional
  , emptyClsDict
  , lookupClsDictClass )
import qualified Control.Super.Plugin.ClassDict as ClsD
import Control.Super.Plugin.InstanceDict
  ( InstanceDict, lookupInstDict )

-- -----------------------------------------------------------------------------
-- Plugin Monad
-- -----------------------------------------------------------------------------

-- | The error type used as result if the plugin fails. It is the error
--   carried by the 'ExceptT' layer of 'SupermonadPluginM' and is a
--   pretty-printable 'O.SDoc' message.
type SupermonadError = O.SDoc

-- | The plugin monad. It is a transformer stack on top of 'TcPluginM' that
--   provides, from the outside in:
--
--   * read-only access to the 'SupermonadPluginEnv' ('ReaderT'),
--   * a modifiable 'SupermonadPluginState' with custom state @s@ ('StateT'),
--   * the ability to abort with a 'SupermonadError' ('ExceptT').
type SupermonadPluginM s = ReaderT SupermonadPluginEnv 
                       ( StateT  (SupermonadPluginState s)
                       ( ExceptT SupermonadError TcPluginM
                       ) )

-- | The read-only environment of the plugin.
data SupermonadPluginEnv = SupermonadPluginEnv
  { smEnvGivenConstraints  :: [GivenCt]
  -- ^ The given and derived constraints (all of them).
  , smEnvWantedConstraints :: [WantedCt]
  -- ^ The wanted constraints (all of them).
  , smEnvClassDictionary :: ClassDict
  -- ^ Class dictionary of the environment.
  }

-- | The modifiable state of the plugin. The type parameter @s@ is the type
--   of the custom state component.
data SupermonadPluginState s = SupermonadPluginState 
  { smStateTyVarEqualities :: [(Ct, TyVar, Type)]
  -- ^ Equalities between type variables and types that have been derived by the plugin.
  , smStateTypeEqualities :: [(Ct, Type, Type)]
  -- ^ Equalities between types that have been derived by the plugin.
  , smStateWarningQueue :: [(String, O.SDoc)]
  -- ^ A queue of warnings (message and details) that are only displayed if
  --   no progress could be made.
  , smStateCustom :: s
  -- ^ Custom state of the environment.
  }

-- | Runs the given supermonad plugin solver within the type checker plugin
--   monad and turns its outcome into a 'TcPluginResult'.
--
--   The initialization code is always run. The plugin code itself is only
--   run if there is at least one wanted constraint; afterwards every type
--   variable equality and type equality recorded in the plugin state is
--   converted into a constraint (using 'mkDerivedTypeEqCt' and
--   'mkDerivedTypeEqCtOfTypes') and returned as new work. No constraint is
--   ever reported as solved.
--
--   If the plugin aborts with an error, the error is printed and an empty
--   result is returned.
runSupermonadPluginAndReturn
  :: [GivenCt] -- ^ /Given/ and /derived/ constraints.
  -> [WantedCt] -- ^ /Wanted/ constraints.
  -> SupermonadPluginM () (ClassDict, s) -- ^ Initialize the custom state of the plugin.
  -> SupermonadPluginM s a -- ^ Plugin code to run. Result value is ignored.
  -> TcPluginM TcPluginResult -- ^ The plugin result.
runSupermonadPluginAndReturn givenCts wantedCts initStateM pluginM =
  runSupermonadPlugin givenCts wantedCts initStateM solve >>= either reportFailure return
  where
    -- Result that neither solves nor introduces any constraint.
    noResult = TcPluginOk [] []

    -- The computation handed to 'runSupermonadPlugin'.
    solve
      | null wantedCts = return noResult
      | otherwise = do
          _ <- pluginM
          tyVarEqs <- getTyVarEqualities
          tyEqs <- getTypeEqualities
          return $ TcPluginOk [] (map tyVarEqToCt tyVarEqs ++ map tyEqToCt tyEqs)

    tyVarEqToCt (baseCt, tv, ty) = mkDerivedTypeEqCt baseCt tv ty
    tyEqToCt (baseCt, ta, tb) = mkDerivedTypeEqCtOfTypes baseCt ta tb

    -- Errors are reported, but never abort type checking.
    reportFailure err = do
      L.printErr $ L.sDocToStr err
      return noResult
  
-- | Runs the given supermonad plugin solver within the type checker plugin
--   monad.
--
--   Execution happens in two phases. First the initialization code is run
--   with an empty class dictionary, no recorded equalities or warnings and
--   @()@ as custom state. It delivers the class dictionary and the initial
--   custom state for the second phase. Then the actual plugin code is run;
--   equalities and warnings recorded during initialization are carried over.
--
--   An error in either phase is returned as 'Left'. The final plugin state
--   is discarded, only the result value of the plugin code is returned.
runSupermonadPlugin
  :: [GivenCt] -- ^ /Given/ and /derived/ constraints.
  -> [WantedCt] -- ^ /Wanted/ constraints.
  -> SupermonadPluginM () (ClassDict, s) -- ^ Initialize the custom state of the plugin.
  -> SupermonadPluginM s a -- ^ Plugin code to run.
  -> TcPluginM (Either SupermonadError a) -- ^ Either an error message or an actual plugin result.
runSupermonadPlugin givenCts wantedCts initStateM pluginM = do
  initOutcome <- execute bareEnv blankState initStateM
  case initOutcome of
    Left err -> return (Left err)
    Right ((clsDict, custom), afterInit) -> do
      let fullEnv = bareEnv { smEnvClassDictionary = clsDict }
          -- Rebuilt field by field, because the type of the custom state
          -- changes from '()' to the type delivered by the initialization.
          fullState = SupermonadPluginState
            { smStateTyVarEqualities = smStateTyVarEqualities afterInit
            , smStateTypeEqualities  = smStateTypeEqualities  afterInit
            , smStateWarningQueue    = smStateWarningQueue    afterInit
            , smStateCustom          = custom
            }
      -- Keep the result value, drop the final state.
      fmap (fmap fst) (execute fullEnv fullState pluginM)
  where
    -- Unwraps all three transformer layers of the plugin monad.
    execute env st action = runExceptT (runStateT (runReaderT action env) st)

    -- Environment available during initialization.
    bareEnv = SupermonadPluginEnv
      { smEnvGivenConstraints  = givenCts
      , smEnvWantedConstraints = wantedCts
      , smEnvClassDictionary   = emptyClsDict
      }

    -- State the initialization starts with.
    blankState = SupermonadPluginState
      { smStateTyVarEqualities = []
      , smStateTypeEqualities  = []
      , smStateWarningQueue    = []
      , smStateCustom          = ()
      }


-- | Execute the given 'TcPluginM' computation within the plugin monad by
--   lifting it through all three transformer layers.
runTcPlugin :: TcPluginM a -> SupermonadPluginM s a
runTcPlugin action = lift (lift (lift action))

-- -----------------------------------------------------------------------------
-- Plugin Environment Access
-- -----------------------------------------------------------------------------

-- | Returns the type class dictionary stored in the plugin environment.
getClassDictionary :: SupermonadPluginM s ClassDict
getClassDictionary = asks (\env -> smEnvClassDictionary env)

-- | Returns the plugin's custom state.
getCustomState :: SupermonadPluginM s s
getCustomState = gets (\st -> smStateCustom st)

-- | Replaces the plugin's custom state with the given value.
putCustomState :: s -> SupermonadPluginM s ()
putCustomState newS = modify overwrite
  where
    overwrite st = st { smStateCustom = newS }

-- | Applies the given function to the plugin's custom state.
modifyCustomState :: (s -> s) -> SupermonadPluginM s ()
modifyCustomState sf = modify update
  where
    update st = st { smStateCustom = sf (smStateCustom st) }

-- | Looks up a class by its name in the class dictionary of the
--   plugin environment. The result is whatever 'lookupClsDictClass'
--   delivers for that name.
getClass :: PluginClassName -> SupermonadPluginM s (Maybe Class)
getClass clsName = do
  clsDict <- asks smEnvClassDictionary
  return $ lookupClsDictClass clsName clsDict

-- | Checks whether the given class name refers to a class that is optional
--   in the solving process, according to the class dictionary of the
--   plugin environment.
isOptionalClass :: PluginClassName -> SupermonadPluginM s Optional
isOptionalClass clsName = do
  clsDict <- asks smEnvClassDictionary
  return $ ClsD.isOptionalClass clsName clsDict

-- | Returns all of the /given/ and /derived/ constraints this plugin call
--   was started with.
getGivenConstraints :: SupermonadPluginM s [GivenCt]
getGivenConstraints = asks (\env -> smEnvGivenConstraints env)

-- | Returns all of the /wanted/ constraints this plugin call was started
--   with.
getWantedConstraints :: SupermonadPluginM s [WantedCt]
getWantedConstraints = asks (\env -> smEnvWantedConstraints env)

-- | Shortcut to access the instance environments: runs
--   'TcPluginM.getInstEnvs' inside the plugin monad.
getInstEnvs :: SupermonadPluginM s InstEnvs
getInstEnvs = lift . lift . lift $ TcPluginM.getInstEnvs

-- | Retrieves the instance associated with the given type constructor and
--   class by looking it up in the 'InstanceDict' that is kept as the
--   custom state.
getInstanceFor :: TyCon -> Class -> SupermonadPluginM InstanceDict (Maybe ClsInst)
getInstanceFor tc cls = do
  instDict <- getCustomState
  return $ lookupInstDict tc cls instDict

-- | Add another type variable equality to the derived equalities.
--   The new equality is put in front of those recorded so far.
addTyVarEquality :: Ct -> TyVar -> Type -> SupermonadPluginM s ()
addTyVarEquality ct tv ty = modify prepend
  where
    prepend st = st { smStateTyVarEqualities = (ct, tv, ty) : smStateTyVarEqualities st }

-- | Add a list of type variable equalities to the derived equalities,
--   one after another by means of 'addTyVarEquality'.
addTyVarEqualities :: [(Ct, TyVar, Type)] -> SupermonadPluginM s ()
addTyVarEqualities eqs = forM_ eqs $ \(ct, tv, ty) -> addTyVarEquality ct tv ty

-- | Add another type equality to the derived equalities.
--   The new equality is put in front of those recorded so far.
addTypeEquality :: Ct -> Type -> Type -> SupermonadPluginM s ()
addTypeEquality ct ta tb = modify prepend
  where
    prepend st = st { smStateTypeEqualities = (ct, ta, tb) : smStateTypeEqualities st }

-- | Add a list of type equalities to the derived equalities,
--   one after another by means of 'addTypeEquality'.
addTypeEqualities :: [(Ct, Type, Type)] -> SupermonadPluginM s ()
addTypeEqualities eqs = forM_ eqs $ \(ct, ta, tb) -> addTypeEquality ct ta tb

-- | Returns all derived type variable equalities that were added to the
--   results thus far.
getTyVarEqualities :: SupermonadPluginM s [(Ct, TyVar, Type)]
getTyVarEqualities = gets (\st -> smStateTyVarEqualities st)

-- | Returns all derived type equalities (equalities between two types)
--   that were added to the results thus far.
getTypeEqualities :: SupermonadPluginM s [(Ct, Type, Type)]
getTypeEqualities = gets (\st -> smStateTypeEqualities st)

-- | Add a warning, consisting of a message and further details, to the
--   queue of warnings that will be displayed when no progress could be made.
addWarning :: String -> O.SDoc -> SupermonadPluginM s ()
addWarning msg details = modify enqueue
  where
    enqueue st = st { smStateWarningQueue = (msg, details) : smStateWarningQueue st }

-- | Execute the given plugin code only if no plugin results were produced
--   so far, i.e. if neither a type variable equality nor a type equality
--   has been recorded.
whenNoResults :: SupermonadPluginM s () -> SupermonadPluginM s ()
whenNoResults m = do
  noTyVarEqs <- null <$> getTyVarEqualities
  noTyEqs <- null <$> getTypeEqualities
  when (noTyVarEqs && noTyEqs) m

-- | Displays the queued warning messages if no progress has been made,
--   i.e. if no equalities have been recorded so far (see 'whenNoResults').
--
--   Warnings are displayed in the order in which they were added:
--   'addWarning' puts new warnings at the front of the queue, therefore the
--   queue is reversed before printing. The queue itself is not modified.
displayWarnings :: SupermonadPluginM s ()
displayWarnings = whenNoResults $ do
  warns <- gets smStateWarningQueue
  -- Oldest warning first.
  forM_ (reverse warns) $ \(msg, details) -> do
    printWarn msg
    internalPrint $ L.smObjMsg $ L.sDocToStr details

-- -----------------------------------------------------------------------------
-- Plugin debug and error printing
-- -----------------------------------------------------------------------------

-- | Wraps a plain message into the plugin's error type.
stringToSupermonadError :: String -> SupermonadError
stringToSupermonadError msg = O.text msg

-- | Assert the given condition. If the condition does not
--   evaluate to 'True', an error with the given message is
--   thrown (see 'throwPluginError') and the plugin aborts.
assert :: Bool -> String -> SupermonadPluginM s ()
assert cond msg
  | cond      = return ()
  | otherwise = throwPluginError msg

-- | Assert the given condition. Same as 'assert' but the condition
--   is the result of a computation in the plugin monad.
assertM :: SupermonadPluginM s Bool -> String -> SupermonadPluginM s ()
assertM condM msg = condM >>= \cond -> assert cond msg

-- | Throw an error with the given plain text message in the plugin.
--   This will abort all further actions, unless the error is caught
--   with 'catchPluginError'.
throwPluginError :: String -> SupermonadPluginM s a
throwPluginError msg = throwError (stringToSupermonadError msg)

-- | Throw an error with the given pretty-printable message in the plugin.
--   This will abort all further actions, unless the error is caught
--   with 'catchPluginError'.
throwPluginErrorSDoc :: O.SDoc -> SupermonadPluginM s a
throwPluginErrorSDoc doc = throwError doc

-- | Catch an error that was thrown by the plugin: runs the given action and,
--   if it throws, runs the handler with the thrown error instead.
--
--   Since the 'ExceptT' layer sits below the 'StateT' layer of the plugin
--   monad, state changes made by the failed action are not visible to the
--   handler.
catchPluginError :: SupermonadPluginM s a -> (SupermonadError -> SupermonadPluginM s a) -> SupermonadPluginM s a
catchPluginError action handler = action `catchError` handler

-- | Print some generic outputable object from the plugin (Unsafe).
--   The object is rendered with 'L.pprToStr' and formatted as an object
--   message.
printObj :: Outputable o => o -> SupermonadPluginM s ()
printObj obj = internalPrint $ L.smObjMsg $ L.pprToStr obj

-- | Print a message from the plugin, formatted as a debug message.
printMsg :: String -> SupermonadPluginM s ()
printMsg msg = internalPrint (L.smDebugMsg msg)

-- | Print an error message from the plugin. This only prints, it does
--   not abort the plugin.
printErr :: String -> SupermonadPluginM s ()
printErr msg = internalPrint (L.smErrMsg msg)

-- | Print a warning message from the plugin immediately
--   (use 'addWarning' to queue a warning instead).
printWarn :: String -> SupermonadPluginM s ()
printWarn msg = internalPrint (L.smWarnMsg msg)

-- | Internal function for printing from within the monad. The string is
--   written to standard output with 'putStr', so no newline is appended
--   here.
internalPrint :: String -> SupermonadPluginM s ()
internalPrint str = runTcPlugin (tcPluginIO (putStr str))

-- | Print the given string as if it was an object. This allows custom
--   formatting of objects before they are printed.
printFormattedObj :: String -> SupermonadPluginM s ()
printFormattedObj str = internalPrint (L.smObjMsg str)

-- | Print the given constraints in the plugin's custom format.
--
--   Consecutive constraints that originate from the same source file are
--   printed together below a header line naming that file (or stating that
--   the file is unknown). Only adjacent constraints are grouped, so a file
--   may appear more than once if the list is not ordered by file.
printConstraints :: [Ct] -> SupermonadPluginM s ()
printConstraints cts =
  forM_ groupedCts $ \(file, ctGroup) -> do
    printFormattedObj $ maybe "From unknown file:" (("From " ++) . (++":") . unpackFS) file
    mapM_ (printFormattedObj . L.formatConstraint) ctGroup
  where
    -- Pair every group with the file of its first constraint. 'groupBy'
    -- never produces empty groups, so the pattern matches every group and
    -- no partial 'head' is required.
    groupedCts =
      [ (getCtFile firstCt, ctGroup)
      | ctGroup@(firstCt : _) <- groupBy eqFileName cts
      ]
    eqFileName ct1 ct2 = getCtFile ct1 == getCtFile ct2
    getCtFile = srcSpanFileName_maybe . constraintSourceLocation