{-# LANGUAGE TypeFamilies, ScopedTypeVariables #-}
module Language.Haskell.Names.Recursive
  ( resolve
  , annotate
  ) where

import Data.Foldable (traverse_)
import Data.Graph(stronglyConnComp, flattenSCC)
import Data.Data (Data)
import Control.Monad (forM, forM_, unless)

import qualified Data.Map as Map (insert)
import Control.Monad.State.Strict (State, execState, get, modify)
import Language.Haskell.Exts.Annotated
import Language.Haskell.Exts.Annotated.Simplify (sModuleName)

import Language.Haskell.Names.Types
import Language.Haskell.Names.SyntaxUtils
import Language.Haskell.Names.ScopeUtils
import Language.Haskell.Names.ModuleSymbols
import Language.Haskell.Names.Exports
import Language.Haskell.Names.Imports
import Language.Haskell.Names.Open.Base
import Language.Haskell.Names.Annotated


-- | Extend an environment with the exported symbols of every given module.
--
-- The modules are split into strongly connected components of their import
-- graph, and each component is resolved as a unit (its members may import
-- one another). Components are handled in the order 'groupModules' yields
-- them, threading the environment from one component to the next, so a
-- component sees every entry written while resolving earlier components.
resolve :: (Data l, Eq l) => [Module l] -> Environment -> Environment
resolve modules environment = execState resolveAll environment
  where
    resolveAll = traverse_ findFixPoint (groupModules modules)

-- | Partition a list of modules into the strongly connected components of
-- their import graph; each inner list is one component.
--
-- Per 'Data.Graph.stronglyConnComp', the components come back in reverse
-- topological order: a component is listed after every component it
-- imports from. Imports of modules that are not in the input list do not
-- become graph edges.
groupModules :: [Module l] -> [[Module l]]
groupModules = map flattenSCC . stronglyConnComp . map moduleNode

-- | Describe a module as a node for 'Data.Graph.stronglyConnComp'. The node
-- consists of:
--
-- * the module itself as the payload;
-- * its name, with the annotation dropped, as the node key;
-- * the names of the modules it imports, likewise stripped, as outgoing
--   edges.
--
-- Stripping annotations (to @()@) makes keys compare by name alone rather
-- than by where the name occurs in the source.
moduleNode :: Module l -> (Module l, ModuleName (), [ModuleName ()])
moduleNode modul = (modul, key, edges)
  where
    key = dropAnn (getModuleName modul)
    edges = [dropAnn (importModule decl) | decl <- getImports modul]

-- | Compute the exported symbols of a group of mutually recursive modules
-- and record them in the environment, keyed by module name.
--
-- Every module starts out assumed to export nothing. Each round does two
-- things:
--
-- 1. It writes the current guesses into the environment.
-- 2. It recomputes every module's exports against that environment.
--
-- Iteration stops as soon as a round reproduces exactly the guesses it
-- started from. At that point the environment already holds those
-- (unchanged) guesses.
--
-- NOTE(review): nothing here bounds the number of rounds; termination
-- relies on the recomputation eventually reaching a fixed point.
findFixPoint :: (Data l, Eq l) => [Module l] -> State Environment ()
findFixPoint modules = go (map (const []) modules)
  where
    go guesses = do
      -- Publish the current guess for each module under its name.
      forM_ (zip modules guesses) $ \(modul, symbols) ->
        modify (Map.insert (sModuleName (getModuleName modul)) symbols)
      env <- get
      -- Re-derive each module's exports with those guesses in scope.
      refined <- forM modules $ \modul ->
        return (exportedSymbols (moduleTable (importTable env modul) modul) modul)
      unless (refined == guesses) (go refined)

-- | Annotate a module with scoping information using the given environment.
-- All imports of the given module should be in the environment.
--
-- The parts of the module are annotated as follows:
--
-- * Top-level declarations are annotated in an initial scope built from the
--   module's name and its global table.
-- * Import declarations are annotated against the environment.
-- * The export list (if any) is annotated against the global table.
-- * The module header's own annotation, its name, its warning text and the
--   module pragmas are marked as carrying no scope information.
--
-- Only the plain 'Module' constructor is supported; any other constructor
-- of 'Module' is rejected with 'error'.
annotate :: (Data l, Eq l, SrcInfo l) => Environment -> Module l -> Module (Scoped l)
annotate environment modul@(Module l maybeModuleHead modulePragmas importDecls decls) =
  Module l' maybeModuleHead' modulePragmas' importDecls' decls'
  where
    l' = none l
    maybeModuleHead' = fmap annotateHead maybeModuleHead
    -- Only the export list inside the header is resolved; the other header
    -- parts get no scope. The name is bound as 'headName' so it does not
    -- shadow the where-bound 'moduleName' below.
    annotateHead (ModuleHead lh headName maybeWarning maybeExports) =
      ModuleHead
        (none lh)
        (noScope headName)
        (fmap noScope maybeWarning)
        (fmap (annotateExportSpecList globalTable) maybeExports)
    modulePragmas' = fmap noScope modulePragmas
    importDecls' = annotateImportDecls moduleName environment importDecls
    decls' = map (annotateDecl (initialScope (sModuleName moduleName) globalTable)) decls
    -- Built from the module's import table (taken from the environment)
    -- together with the module itself.
    globalTable = moduleTable (importTable environment modul) modul
    moduleName = getModuleName modul
annotate _ _ = error "annotate: non-standard modules are not supported"