{-|
  Copyright   :  (C) 2012-2016, University of Twente,
                     2016     , Myrtle Software Ltd,
                     2017     , Google Inc.
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Turn CoreHW terms into normalized CoreHW Terms
-}

{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns    #-}

module Clash.Normalize where

import           Control.Concurrent.Supply        (Supply)
import           Control.Lens                     ((.=),(^.),_2,_5)
import qualified Control.Lens                     as Lens
import           Data.Either                      (partitionEithers)
import           Data.HashMap.Strict              (HashMap)
import qualified Data.HashMap.Strict              as HashMap
import qualified Data.HashSet                     as HashSet
import qualified Data.IntMap                      as IntMap
import           Data.IntMap.Strict               (IntMap)
import           Data.List
  (groupBy, intersect, mapAccumL, sortBy)
import qualified Data.Map                         as Map
import qualified Data.Maybe                       as Maybe
import qualified Data.Set                         as Set
import qualified Data.Set.Lens                    as Lens
import           Data.Semigroup                   ((<>))
import           Data.Text.Prettyprint.Doc        (vcat)
import           Unbound.Generics.LocallyNameless (unembed, runLFreshM)

import           BasicTypes                       (InlineSpec (..))
import           SrcLoc                           (SrcSpan,noSrcSpan)

import           Clash.Core.Evaluator             (PrimEvaluator)
import           Clash.Core.FreeVars              (termFreeIds)
import           Clash.Core.Name                  (Name (..), NameSort (..))
import           Clash.Core.Pretty                (showDoc, ppr)
import           Clash.Core.Subst                 (substTms)
import           Clash.Core.Term                  (Term (..), TmName, TmOccName)
import           Clash.Core.Type                  (Type, splitCoreFunForallTy)
import           Clash.Core.TyCon
  (TyCon, TyConName, TyConOccName)
import           Clash.Core.Util                  (collectArgs, mkApps, termType)
import           Clash.Core.Var                   (Id,varName)
import           Clash.Driver.Types
  (BindingMap, ClashOpts (..), DebugLevel (..))
import           Clash.Netlist.BlackBox.Types     (BlackBoxTemplate)
import           Clash.Netlist.Types              (HWType (..))
import           Clash.Netlist.Util
  (splitNormalized, unsafeCoreTypeToHWType)
import           Clash.Normalize.Strategy
import           Clash.Normalize.Transformations
  (appProp, bindConstantVar, caseCon, flattenLet, reduceConst, topLet)
import           Clash.Normalize.Types
import           Clash.Normalize.Util
import           Clash.Primitives.Types           (PrimMap)
import           Clash.Rewrite.Combinators        ((>->),(!->))
import           Clash.Rewrite.Types
  (RewriteEnv (..), RewriteState (..), bindings, curFun, dbgLevel, extra,
   tcCache, topEntities, typeTranslator)
import           Clash.Rewrite.Util               (isUntranslatableType,
                                                   runRewrite,
                                                   runRewriteSession)
import Clash.Signal.Internal                      (ResetKind (..))
import           Clash.Util

-- | Run a NormalizeSession in a given environment
runNormalization
  :: ClashOpts
  -- ^ Level of debug messages to print
  -> Supply
  -- ^ UniqueSupply
  -> BindingMap
  -- ^ Global Binders
  -> (HashMap TyConOccName TyCon -> Bool -> Type -> Maybe (Either String HWType))
  -- ^ Hardcoded Type -> HWType translator
  -> HashMap TyConOccName TyCon
  -- ^ TyCon cache
  -> IntMap TyConName
  -- ^ Tuple TyCon cache
  -> PrimEvaluator
  -- ^ Hardcoded evaluator (delta-reduction)
  -> PrimMap BlackBoxTemplate
  -- ^ Primitive Definitions
  -> HashMap TmOccName Bool
  -- ^ Map telling whether a components is part of a recursive group
  -> [TmOccName]
  -- ^ topEntities
  -> NormalizeSession a
  -- ^ NormalizeSession to run
  -> a
runNormalization opts supply globals typeTrans tcm tupTcm eval primMap rcsMap topEnts
  = runRewriteSession rwEnv rwState
  where
    rwEnv     = RewriteEnv
                  (opt_dbgLevel opts)
                  typeTrans
                  tcm
                  tupTcm
                  eval
                  (opt_allowZero opts)
                  (HashSet.fromList topEnts)

    rwState   = RewriteState
                  0
                  globals
                  supply
                  (error $ $(curLoc) ++ "Report as bug: no curFun",noSrcSpan)
                  0
                  (IntMap.empty, 0)
                  normState

    normState = NormalizeState
                  HashMap.empty
                  Map.empty
                  HashMap.empty
                  (opt_specLimit opts)
                  HashMap.empty
                  (opt_inlineLimit opts)
                  (opt_inlineFunctionLimit opts)
                  (opt_inlineConstantLimit opts)
                  primMap
                  rcsMap


normalize
  :: [TmOccName]
  -> NormalizeSession BindingMap
normalize []  = return HashMap.empty
normalize top = do
  (new,topNormalized) <- unzip <$> mapM normalize' top
  newNormalized <- normalize (concat new)
  return (HashMap.union (HashMap.fromList topNormalized) newNormalized)

normalize'
  :: TmOccName
  -> NormalizeSession ([TmOccName],(TmOccName,(TmName,Type,SrcSpan,InlineSpec,Term)))
normalize' nm = do
  exprM <- HashMap.lookup nm <$> Lens.use bindings
  let nmS = showDoc nm
  case exprM of
    Just (nm',ty,sp,inl,tm) -> do
      tcm <- Lens.view tcCache
      let (_,resTy) = splitCoreFunForallTy tcm ty
      resTyRep <- not <$> isUntranslatableType False resTy
      if resTyRep
         then do
            tmNorm <- makeCached nm (extra.normalized) $ do
                        curFun .= (nm',sp)
                        tm' <- rewriteExpr ("normalization",normalization) (nmS,tm)
                        ty' <- termType tcm tm'
                        return (nm',ty',sp,inl,tm')
            let usedBndrs = Lens.toListOf termFreeIds (tmNorm ^. _5)
            traceIf (nm `elem` usedBndrs)
                    (concat [ $(curLoc),"Expr belonging to bndr: ",nmS ," (:: "
                            , showDoc (tmNorm ^. _2)
                            , ") remains recursive after normalization:\n"
                            , showDoc (tmNorm ^. _5) ])
                    (return ())
            tyTrans <- Lens.view typeTranslator
            case clockResetErrors sp tyTrans tcm ty of
              msgs@(_:_) -> traceIf True (concat (nmS:" (:: ":showDoc (tmNorm ^. _2)
                              :")\nhas potentially dangerous meta-stability issues:\n\n"
                              :msgs))
                              (return ())
              _ -> return ()
            prevNorm <- fmap HashMap.keys $ Lens.use (extra.normalized)
            topEnts  <- Lens.view topEntities
            let toNormalize = filter (not . (`HashSet.member` topEnts))
                            $ filter (`notElem` (nm:prevNorm)) usedBndrs
            return (toNormalize,(nm,tmNorm))
         else do
            let usedBndrs = Lens.toListOf termFreeIds tm
            prevNorm <- fmap HashMap.keys $ Lens.use (extra.normalized)
            topEnts  <- Lens.view topEntities
            let toNormalize = filter (not . (`HashSet.member` topEnts))
                            $ filter (`notElem` (nm:prevNorm)) usedBndrs
            lvl <- Lens.view dbgLevel
            traceIf (lvl >= DebugFinal)
                    (concat [$(curLoc), "Expr belonging to bndr: ", nmS, " (:: "
                            , showDoc ty
                            , ") has a non-representable return type."
                            , " Not normalising:\n", showDoc tm] )
                    (return (toNormalize,(nm,(nm',ty,sp,inl,tm))))
    Nothing -> error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found"

-- | Rewrite a term according to the provided transformation
rewriteExpr :: (String,NormRewrite) -- ^ Transformation to apply
            -> (String,Term) -- ^ Term to transform
            -> NormalizeSession Term
rewriteExpr (nrwS,nrw) (bndrS,expr) = do
  lvl <- Lens.view dbgLevel
  let before = showDoc expr
  let expr' = traceIf (lvl >= DebugFinal)
                (bndrS ++ " before " ++ nrwS ++ ":\n\n" ++ before ++ "\n")
                expr
  rewritten <- runRewrite nrwS nrw expr'
  let after = showDoc rewritten
  traceIf (lvl >= DebugFinal)
    (bndrS ++ " after " ++ nrwS ++ ":\n\n" ++ after ++ "\n") $
    return rewritten

-- | Check whether the normalized bindings are non-recursive. Errors when one
-- of the components is recursive.
checkNonRecursive
  :: BindingMap
  -- ^ List of normalized binders
  -> BindingMap
checkNonRecursive norm = case Maybe.mapMaybe go (HashMap.toList norm) of
    []  -> norm
    rcs -> error $ $(curLoc) ++ "Callgraph after normalisation contains following recursive components: "
                   ++ show (vcat $ runLFreshM $ sequence [ do a' <- ppr a
                                                              b' <- ppr b
                                                              return $ a' <> b'
                                                         | (a,b ) <- rcs
                                                         ])
  where
    go (nm,(_,_,_,_,tm)) =
      let used = Lens.toListOf termFreeIds tm
      in  if nm `elem` used
             then Just (nm,tm)
             else Nothing

-- | Perform general \"clean up\" of the normalized (non-recursive) function
-- hierarchy. This includes:
--
--   * Inlining functions that simply \"wrap\" another function
cleanupGraph
  :: TmOccName
  -> BindingMap
  -> NormalizeSession BindingMap
cleanupGraph topEntity norm
  | Just ct <- mkCallTree [] norm topEntity
  = do ctFlat <- flattenCallTree ct
       return (HashMap.fromList $ snd $ callTreeToList [] ctFlat)
cleanupGraph _ norm = return norm

data CallTree = CLeaf   (TmOccName,(TmName,Type,SrcSpan,InlineSpec,Term))
              | CBranch (TmOccName,(TmName,Type,SrcSpan,InlineSpec,Term)) [CallTree]

mkCallTree
  :: [TmOccName]
  -- ^ Visited
  -> BindingMap
  -- ^ Global binders
  -> TmOccName
  -- ^ Root of the call graph
  -> Maybe CallTree
mkCallTree visited bindingMap root
  | Just rootTm <- HashMap.lookup root bindingMap
  = let used   = Set.toList $ Lens.setOf termFreeIds $ (rootTm ^. _5)
        other  = Maybe.mapMaybe (mkCallTree (root:visited) bindingMap) (filter (`notElem` visited) used)
    in  case used of
          [] -> Just (CLeaf   (root,rootTm))
          _  -> Just (CBranch (root,rootTm) other)
mkCallTree _ _ _ = Nothing

stripArgs
  :: [TmOccName]
  -> [Id]
  -> [Either Term Type]
  -> Maybe [Either Term Type]
stripArgs _      (_:_) []   = Nothing
stripArgs allIds []    args = if any mentionsId args
                                then Nothing
                                else Just args
  where
    mentionsId t = not $ null (either (Lens.toListOf termFreeIds) (const []) t
                              `intersect`
                              allIds)

stripArgs allIds (id_:ids) (Left (Var _ nm):args)
      | varName id_ == nm = stripArgs allIds ids args
      | otherwise         = Nothing
stripArgs _ _ _ = Nothing

flattenNode
  :: CallTree
  -> NormalizeSession (Either CallTree ((TmOccName,Term),[CallTree]))
flattenNode (CLeaf (nm,(nameSort -> Internal,_,_,_,e))) =
  return (Right ((nm,e),[]))
flattenNode c@(CLeaf (nm,(_,_,_,_,e))) = do
  tcm  <- Lens.view tcCache
  norm <- splitNormalized tcm e
  case norm of
    Right (ids,[(_,bExpr)],_) -> do
      let (fun,args) = collectArgs (unembed bExpr)
      case stripArgs (map (nameOcc.varName) ids) (reverse ids) (reverse args) of
        Just remainder -> return (Right ((nm,mkApps fun (reverse remainder)),[]))
        Nothing        -> return (Right ((nm,e),[]))
    _ | isCheapFunction e -> return (Right ((nm,e),[]))
      | otherwise         -> return (Left c)
flattenNode (CBranch (nm,(nameSort -> Internal,_,_,_,e)) us) =
  return (Right ((nm,e),us))
flattenNode b@(CBranch (nm,(_,_,_,_,e)) us) = do
  tcm  <- Lens.view tcCache
  norm <- splitNormalized tcm e
  case norm of
    Right (ids,[(_,bExpr)],_) -> do
      let (fun,args) = collectArgs (unembed bExpr)
      case stripArgs (map (nameOcc.varName) ids) (reverse ids) (reverse args) of
        Just remainder -> return (Right ((nm,mkApps fun (reverse remainder)),us))
        Nothing        -> return (Right ((nm,e),us))
    _ | isCheapFunction e -> return (Right ((nm,e),us))
      | otherwise         -> return (Left b)

flattenCallTree
  :: CallTree
  -> NormalizeSession CallTree
flattenCallTree c@(CLeaf _) = return c
flattenCallTree (CBranch (nm,(nm',ty,sp,inl,tm)) used) = do
  flattenedUsed   <- mapM flattenCallTree used
  (newUsed,il_ct) <- partitionEithers <$> mapM flattenNode flattenedUsed
  let (toInline,il_used) = unzip il_ct
  newExpr <- case toInline of
               [] -> return tm
               _  -> rewriteExpr ("flattenExpr",flatten) (showDoc nm, substTms toInline tm)
  let allUsed = newUsed ++ concat il_used
  -- inline all components when the resulting expression after flattening
  -- is still considered "cheap". This happens often at the topEntity which
  -- wraps another functions and has some selectors and data-constructors.
  if isCheapFunction newExpr
     then do
        let (toInline',allUsed') = unzip (map goCheap allUsed)
        newExpr' <- rewriteExpr ("flattenCheap",flatten) (showDoc nm, substTms toInline' newExpr)
        return (CBranch (nm,(nm',ty,sp,inl,newExpr')) (concat allUsed'))
     else return (CBranch (nm,(nm',ty,sp,inl,newExpr)) allUsed)
  where
    flatten =
      innerMost (appProp >-> bindConstantVar >-> caseCon >-> reduceConst >-> flattenLet) !->
      topdownSucR topLet

    goCheap (CLeaf   (nm2,(_,_,_,_,e)))    = ((nm2,e),[])
    goCheap (CBranch (nm2,(_,_,_,_,e)) us) = ((nm2,e),us)

callTreeToList
  :: [TmOccName]
  -> CallTree
  -> ([TmOccName],[(TmOccName,(TmName,Type,SrcSpan,InlineSpec,Term))])
callTreeToList visited (CLeaf (nm,bndr))
  | nm `elem` visited = (visited,[])
  | otherwise         = (nm:visited,[(nm,bndr)])
callTreeToList visited (CBranch (nm,bndr) used)
  | nm `elem` visited = (visited,[])
  | otherwise         = (visited',(nm,bndr):(concat others))
  where
    (visited',others) = mapAccumL callTreeToList (nm:visited) used

-- | Clash's clock and reset domain annotations prevent most accidental
-- meta-stability situations. That is, unless the developer uses the
-- functions marked "unsafe", the type system will prevent things like
-- illegal clock domain crossing, or improper use of asynchronous resets.
--
-- However, this all depends on clock and resets being unique. With explicit
-- clocks and resets, it is possible to have multiple clock and reset arguments
-- that are annotated with the same domain. If these arguments aren't connected
-- to the same source, we can still get metastability due to either illegal
-- clock domain crossing, or improper use of asynchronous resets.
--
-- The following situations are reported:
-- * There are 2 or more clock arguments in scope that have the same clock
--   domain annotation.
-- * There are 2 or more reset arguments in scope that have the same reset
--   domain annotation, and at least one of them is an asynchronous reset.
clockResetErrors
  :: SrcSpan
  -> (HashMap TyConOccName TyCon -> Bool -> Type -> Maybe (Either String HWType))
  -> HashMap TyConOccName TyCon
  -> Type
  -> [String]
clockResetErrors sp tyTran tcm ty =
   (Maybe.mapMaybe reportClock clks ++ Maybe.mapMaybe reportResets rsts)
  where
    (args,_)  = splitCoreFunForallTy tcm ty
    (_,args') = partitionEithers args
    hwArgs    = zip (map (unsafeCoreTypeToHWType sp $(curLoc) tyTran tcm False) args') args'
    clks      = groupBy ((==) `on` fst) . sortBy (compare `on` fst)
              $ [ ((nm,i),ty') | (Clock nm i _,ty') <- hwArgs]
    rsts      = groupBy ((==) `on` (fst.fst)) . sortBy (compare `on` (fst.fst))
              $ [ (((nm,i),s),ty') | (Reset nm i s,ty') <- hwArgs]

    reportClock clks'
      | length clks' >= 2
      = Just
      $ concat ["The following clocks:\n"
               ,concatMap (\c -> "* " ++ showDoc (snd c) ++ "\n") clks'
               ,"belong to the same clock domain and should be connected to "
               ,"the same clock source in order to prevent meta-stability "
               ,"issues."
               ]
      | otherwise
      = Nothing

    reportResets rsts'
      | length rsts' >= 2
      , any (\((_,sync),_) -> sync == Asynchronous) rsts'
      = Just
      $ concat ["The following resets:\n"
               ,concatMap (\c -> "* " ++ showDoc (snd c) ++ "\n") rsts'
               ,"belong to the same reset domain, and one or more of these "
               ,"resets is Asynchronous. Ensure that these resets are "
               ,"synchronized in order to prevent meta-stability issues."
               ]
    reportResets _ = Nothing