{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016-2017, Myrtle Software Ltd,
                    2017-2022, Google Inc.,
                    2021-2022, QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>
  Transformations on case-expressions
-}

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Normalize.Transformations.Case
  ( caseCase
  , caseCon
  , caseElemNonReachable
  , caseFlat
  , caseLet
  , caseOneAlt
  , elimExistentials
  ) where

import qualified Control.Lens as Lens
import Control.Monad.State.Strict (evalState)
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import qualified Data.Either as Either
import qualified Data.List as List
import qualified Data.List.Extra as List
import qualified Data.Maybe as Maybe
import qualified Data.Primitive.ByteArray as BA
import qualified Data.Text.Extra as Text (showt)
import GHC.Stack (HasCallStack)

#if MIN_VERSION_base(4,15,0)
import GHC.Num.Integer (Integer(..))
#else
import GHC.Integer.GMP.Internals (BigNat(..), Integer(..))
#endif

import Clash.Sized.Internal.BitVector as BV (BitVector, eq#)
import Clash.Sized.Internal.Index as I (Index, eq#)
import Clash.Sized.Internal.Signed as S (Signed, eq#)
import Clash.Sized.Internal.Unsigned as U (Unsigned, eq#)

import Clash.Core.DataCon (DataCon(..))
import Clash.Core.EqSolver
import Clash.Core.FreeVars (freeLocalIds, localVarsDoNotOccurIn)
import Clash.Core.HasType
import Clash.Core.Literal (Literal(..))
import Clash.Core.Name (nameOcc)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst
import Clash.Core.Term
  ( Alt, Pat(..), PrimInfo(..), Term(..), collectArgs, collectArgsTicks
  , collectTicks, mkApps, mkTicks, patIds, stripTicks, Bind(..))
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Type (LitTy(..), Type(..), TypeView(..), coreView1, tyView)
import Clash.Core.Util (listToLets, mkInternalVar)
import Clash.Core.VarEnv
  ( InScopeSet, elemVarSet, extendInScopeSet, extendInScopeSetList, mkVarSet
  , unitVarSet, uniqAway)
import Clash.Debug (traceIf)
import Clash.Driver.Types (DebugOpts(dbg_invariants))
import Clash.Netlist.Types (FilteredHWType(..), HWType(..))
import Clash.Netlist.Util (coreTypeToHWType, representableType)
import qualified Clash.Normalize.Primitives as NP (undefined, undefinedX)
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Rewrite.Combinators ((>-!))
import Clash.Rewrite.Types
  ( TransformContext(..), bindings, customReprs, debugOpts, tcCache
  , typeTranslator, workFreeBinders)
import Clash.Rewrite.Util (changed, isFromInt, whnfRW)
import Clash.Rewrite.WorkFree
import Clash.Util (curLoc)

-- | Move a case-decomposition out of the subject position of another
-- case-decomposition by pushing the outer case into every alternative of the
-- inner one:
--
-- @
-- case (case s of {p1 -> b1; ...; pn -> bn}) of alts
--   ==>
-- case s of {p1 -> case b1 of alts; ...; pn -> case bn of alts}
-- @
--
-- The inner case-expression is matched through 'stripTicks', so ticks that
-- wrap it are not carried over. The rewrite only fires when the result type
-- of the inner case-expression is not representable; in every other situation
-- (including terms that are not a case-of-case) the term is returned as-is.
caseCase :: HasCallStack => NormRewrite
caseCase (TransformContext is0 _) e@(Case (stripTicks -> Case scrut alts1Ty alts1) alts2Ty alts2) = do
  translator <- Lens.view typeTranslator
  reprs <- Lens.view customReprs
  tcm <- Lens.view tcCache
  -- Pushing the outer case inwards is only worthwhile when the intermediate
  -- (inner) result type has no representation.
  let innerRepresentable = representableType translator reprs False tcm alts1Ty
  if innerRepresentable
    then return e
    else do
      -- Deshadow to prevent accidental capture of free variables of the
      -- outer alternatives. Imagine:
      --
      --   case (case a of {x -> x}) of {_ -> x}
      --
      -- Here the 'x' bound by the inner 'case' is unrelated to the free 'x'
      -- used in the outer alternative. Pushing the outer case into the inner
      -- alternatives naively would give:
      --
      --   case a of {x -> case x of {_ -> x}}
      --
      -- where the formerly free 'x' is now captured by the pattern binder.
      -- Deshadowing the inner alternatives against the in-scope set yields:
      --
      --   case a of {x1 -> case x1 of {_ -> x}}
      --
      let wrapOuter altE = Case altE alts2Ty alts2
          pushedAlts = [ second wrapOuter (deShadowAlt is0 alt) | alt <- alts1 ]
      changed (Case scrut alts2Ty pushedAlts)
caseCase _ e = return e
{-# SCC caseCase #-}

{-
NOTE: caseOneAlt before caseCon'

When you put a bang on a signal argument:
    f :: Signal d a -> _
    f !x = ...
GHC generates a case like:
    case x of
      _ :- _ -> ...

When this f is inlined in a definition such as:
    g = f (pure False)
and Clash applies its `Signal d a ~ a` simplification, we get:
    g = case False of
      _ :- _ -> ...
Because no pattern matches, caseCon' would transform this into:
    g = undefined

By trying caseOneAlt first, Clash can instead drop the case
and use the body of the single alternative.
-}
-- | Eliminate a case-expression whose alternative can be determined
-- statically.
--
-- 'caseOneAlt' (which ignores the transformation context) is combined with
-- 'caseCon'' via '>-!', so 'caseOneAlt' gets the first attempt. See
-- NOTE: caseOneAlt before caseCon'.
caseCon :: HasCallStack => NormRewrite
caseCon = (\_ctx -> caseOneAlt) >-! caseCon'

-- | Specialize a case-decomposition whose subject is statically known, by
-- replacing it with the right-hand side of the alternative that matches.
--
-- The subject (looked at through 'collectArgsTicks') is handled as follows:
--
-- * An applied data constructor: pick the 'DataPat' alternative with the same
--   constructor tag and bind the used constructor fields. If no data pattern
--   matches, fall back to the leading 'DefaultPat' alternative, or to
--   'NP.undefined' at the type of the case-expression.
--
-- * A literal: pick the first 'LitPat' alternative with an equal literal, and
--   otherwise defer to 'matchLiteralContructor'.
--
-- * An applied primitive: reduce it to WHNF and retry 'caseCon' on the result.
--   Recognised bottoming primitives ('error', 'absentError', 'patError',
--   'undefined', 'errorWithoutStackTrace', 'NP.undefined', 'NP.undefinedX',
--   'errorX') make the whole case-expression bottom, so they are rebuilt at the
--   type of the case-expression.
--
-- * An unapplied variable of type @BitVector 0@, @Signed 0@, @Unsigned 0@
--   or @Index 1@: retry 'caseCon' with the literal @0@ as subject.
--
-- * Anything else: defer to 'caseOneAlt'.
--
-- Note [CaseCon deshadow]
--
-- Consider:
--
-- @
-- case D (f a b) (g x y) of
--   D a b -> h a
-- @
--
-- Naively rewriting this to:
--
-- @
-- let a = f a b
-- in  h a
-- @
--
-- is wrong. The new let-binding captures the free variable 'a' of 'f a b'.
-- Instead, the binder is given a fresh name:
--
-- @
-- let a1 = f a b
-- in  h a1
-- @
--
-- Fields that 'isWorkFree' deems work-free are substituted into the
-- alternative directly instead of being let-bound.
caseCon' :: HasCallStack => NormRewrite
caseCon' ctx@(TransformContext is0 _) e@(Case subj ty alts) = do
  tcm <- Lens.view tcCache
  case collectArgsTicks subj of
    -- Subject is an applied data constructor.
    (Data dc, args, ticks) ->
      case List.find (matchesCon . fst) alts of
        Just (DataPat _ tvs xs, altE) -> do
          let
            -- The type arguments that follow the universal ones instantiate
            -- the existential type variables bound by the pattern.
            exTys = List.zipEqual tvs (drop (length (dcUnivTyVars dc)) (Either.rights args))
            exSubst = extendTvSubstList (mkSubst is0) exTys
            -- Instantiate the existentials in the pattern binders' types as
            -- well, since those binders may end up let-bound below.
            xs1 = map (substTyInVar exSubst) xs
            altFvs = Lens.foldMapOf freeLocalIds unitVarSet altE
            -- Only the fields the alternative actually refers to are kept,
            -- each wrapped in the ticks of the subject.
            usedFields =
              [ (x, arg `mkTicks` ticks)
              | (x, arg) <- List.zipEqual xs1 (Either.lefts args)
              , x `elemVarSet` altFvs
              ]
          altE1 <-
            if null usedFields
              then
                -- Nothing to bind: only the existential types need replacing.
                pure (substTm "caseCon'" exSubst altE)
              else do
                -- See Note [CaseCon deshadow]
                let is1 = extendInScopeSetList (extendInScopeSetList is0 tvs) xs1
                ((is2, idSubst), mbLets) <-
                  List.mapAccumLM newBinder (is1, []) usedFields
                let subst = extendIdSubstList
                              (extendTvSubstList (mkSubst is2) exTys)
                              idSubst
                    body = substTm "caseCon'" subst altE
                -- The fields cannot refer to each other, so a chain of
                -- non-recursive lets suffices (no recursive group needed).
                pure $ case Maybe.catMaybes mbLets of
                  [] -> body
                  letBinds -> listToLets letBinds body
          changed altE1
        _ ->
          -- Core places a default alternative first, so when no data pattern
          -- matched, a default (if present) is at the head of the list.
          case alts of
            (DefaultPat, altE) : _ -> changed altE
            _ -> changed (TyApp (Prim NP.undefined) ty)
     where
      -- Does the pattern match the subject's data constructor?
      matchesCon (DataPat dcPat _ _) = dcTag dc == dcTag dcPat
      matchesCon _ = False

      -- Decide per field whether to substitute the argument into the
      -- alternative (when it is work-free, so duplicating it is harmless) or
      -- to let-bind it under a name made unique with respect to the
      -- accumulated in-scope set.
      newBinder (isAcc, substAcc) (x, arg) = do
        bndrs <- Lens.use bindings
        workFree <- isWorkFree workFreeBinders bndrs arg
        if workFree
          then pure ((isAcc, (x, arg) : substAcc), Nothing)
          else do
            let x' = uniqAway isAcc x
            pure ((extendInScopeSet isAcc x', (x, Var x') : substAcc), Just (x', arg))

    -- Subject is a literal.
    (Literal l, _, _) ->
      case [ altE | (LitPat l', altE) <- alts, l == l' ] of
        altE : _ -> changed altE
        [] -> matchLiteralContructor e l alts

    -- Subject is an applied primitive: evaluate it to WHNF and inspect that.
    (Prim _, _, _) ->
      whnfRW True ctx subj $ \ctx1 subj1 ->
        case collectArgsTicks subj1 of
          (Literal l, _, _) ->
            caseCon ctx1 (Case (Literal l) ty alts)
          (Data _, _, _) ->
            caseCon ctx1 (Case subj1 ty alts)
          (Prim pInfo, pArgs, pTicks)
            | Just bottom <- retypeBottom pInfo pArgs pTicks ->
              changed bottom
          -- Neither a value nor a known bottom: a variable reference, or a
          -- primitive the evaluator has no rules for.
          _ -> do
            translator <- Lens.view typeTranslator
            reprs <- Lens.view customReprs
            let subjTy = inferCoreTypeOf tcm subj
                subjHwTy = coreTypeToHWType translator reprs tcm subjTy `evalState` mempty
            case subjHwTy of
              -- A zero-width Clash number type can only match the literal 0.
              Right (FilteredHWType (Void (Just hty)) _)
                | hty `elem` [BitVector 0, Unsigned 0, Signed 0, Index 1] ->
                  caseCon ctx1 (Case (Literal (IntegerLiteral 0)) ty alts)
              _ -> do
                opts <- Lens.view debugOpts
                -- With invariant checking on, report constants the evaluator
                -- could not reduce to something matchable.
                traceIf (dbg_invariants opts && isConstant subj)
                  ("Unmatchable constant as case subject: " ++ showPpr subj ++
                   "\nWHNF is: " ++ showPpr subj1)
                  (caseOneAlt e)

    -- Subject is a variable of a zero-width Clash number type, which can only
    -- match the literal 0.
    (Var v, [], _)
      | isZeroWidthNum (coreTypeOf v) ->
        caseCon ctx (Case (Literal (IntegerLiteral 0)) ty alts)
     where
      isZeroWidthNum t = case tyView t of
        TyConApp (nameOcc -> tcNm) [arg]
          | tcNm `elem` [Text.showt ''BitVector, Text.showt ''Signed, Text.showt ''Unsigned] ->
            isNumLit 0 arg
          | tcNm == Text.showt ''Index ->
            isNumLit 1 arg
        _ -> maybe False isZeroWidthNum (coreView1 tcm t)

      isNumLit n (LitTy (NumTy m)) = n == m
      isNumLit n t = maybe False (isNumLit n) (coreView1 tcm t)

    -- Otherwise, see whether the case-expression has a single alternative.
    _ -> caseOneAlt e
 where
  -- Rebuild a recognised bottoming primitive application at the type of the
  -- case-expression, keeping its message/call-stack arguments. Returns
  -- 'Nothing' for any other primitive, or when the argument count does not
  -- fit the expected shape.
  retypeBottom pInfo pArgs pTicks =
    case pArgs of
      repTy : _ : stackArg : msgArg : _
        | nm == "GHC.Err.error" ->
          rebuild [repTy, Right ty, stackArg, msgArg]
      _ : msgOrStack : _
        | nm `elem` [ "Control.Exception.Base.absentError"
                    , "GHC.Prim.Panic.absentError" ] ->
          rebuild [Right ty, msgOrStack]
      repTy : _ : msgOrStack : _
        | nm `elem` [ "Control.Exception.Base.patError"
                    , "GHC.Err.undefined"
                    , "GHC.Err.errorWithoutStackTrace" ] ->
          rebuild [repTy, Right ty, msgOrStack]
      [_]
        | nm `elem` [Text.showt 'NP.undefined, Text.showt 'NP.undefinedX] ->
          rebuild [Right ty]
      _ : stackArg : msgArg : _
        | nm == "Clash.XException.errorX" ->
          rebuild [Right ty, stackArg, msgArg]
      _ -> Nothing
   where
    nm = primName pInfo
    rebuild = Just . mkApps (mkTicks (Prim pInfo) pTicks)

caseCon' _ e = return e
{-# SCC caseCon' #-}

{- [Note: Name re-creation]
The names of heap-bound variables are safely generated with mkUniqSystemId in Clash.Core.Evaluator.newLetBinding.
However, only their uniques end up in the heap, not the complete names.
So we use mkUnsafeSystemName to recreate the same Name.
-}

matchLiteralContructor
  :: Term
  -> Literal
  -> [Alt]
  -> NormalizeSession Term
matchLiteralContructor :: Term -> Literal -> [Alt] -> RewriteMonad NormalizeState Term
matchLiteralContructor Term
c (IntegerLiteral Integer
l) [Alt]
alts = [Alt] -> RewriteMonad NormalizeState Term
forall extra. [Alt] -> RewriteMonad extra Term
go ([Alt] -> [Alt]
forall a. [a] -> [a]
reverse [Alt]
alts)
 where
  go :: [Alt] -> RewriteMonad extra Term
go [(Pat
DefaultPat,Term
e)] = Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
e
  go ((DataPat DataCon
dc [] [Id
x],Term
e):[Alt]
alts')
    | DataCon -> Int
dcTag DataCon
dc Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1
    , Integer
l Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
>= ((-Integer
2)Integer -> Int -> Integer
forall a b. (Num a, Integral b) => a -> b -> a
^(Int
63::Int)) Bool -> Bool -> Bool
&&  Integer
l Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
< Integer
2Integer -> Int -> Integer
forall a b. (Num a, Integral b) => a -> b -> a
^(Int
63::Int)
    = let fvs :: UniqMap (Var Any)
fvs = Getting (UniqMap (Var Any)) Term Id
-> (Id -> UniqMap (Var Any)) -> Term -> UniqMap (Var Any)
forall r s a. Getting r s a -> (a -> r) -> s -> r
Lens.foldMapOf Getting (UniqMap (Var Any)) Term Id
Fold Term Id
freeLocalIds Id -> UniqMap (Var Any)
forall a. Var a -> UniqMap (Var Any)
unitVarSet Term
e
          bind :: Bind Term
bind = Id -> Term -> Bind Term
forall a. Id -> a -> Bind a
NonRec Id
x (Literal -> Term
Literal (Integer -> Literal
IntLiteral Integer
l))
       in if Id
x Id -> UniqMap (Var Any) -> Bool
forall a. Var a -> UniqMap (Var Any) -> Bool
`elemVarSet` UniqMap (Var Any)
fvs
            then Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed (Bind Term -> Term -> Term
Let Bind Term
bind Term
e)
            else Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
e
    | DataCon -> Int
dcTag DataCon
dc Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
2
    , Integer
l Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
>= Integer
2Integer -> Int -> Integer
forall a b. (Num a, Integral b) => a -> b -> a
^(Int
63::Int)
#if MIN_VERSION_base(4,15,0)
    = let !(IP ba) = l
#else
    = let !(Jp# !(BN# ByteArray#
ba)) = Integer
l
#endif
          ba' :: ByteArray
ba'       = ByteArray# -> ByteArray
BA.ByteArray ByteArray#
ba
          fvs :: UniqMap (Var Any)
fvs       = Getting (UniqMap (Var Any)) Term Id
-> (Id -> UniqMap (Var Any)) -> Term -> UniqMap (Var Any)
forall r s a. Getting r s a -> (a -> r) -> s -> r
Lens.foldMapOf Getting (UniqMap (Var Any)) Term Id
Fold Term Id
freeLocalIds Id -> UniqMap (Var Any)
forall a. Var a -> UniqMap (Var Any)
unitVarSet Term
e
          bind :: Bind Term
bind      = Id -> Term -> Bind Term
forall a. Id -> a -> Bind a
NonRec Id
x (Literal -> Term
Literal (ByteArray -> Literal
ByteArrayLiteral ByteArray
ba'))
       in if Id
x Id -> UniqMap (Var Any) -> Bool
forall a. Var a -> UniqMap (Var Any) -> Bool
`elemVarSet` UniqMap (Var Any)
fvs
            then Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed (Bind Term -> Term -> Term
Let Bind Term
bind Term
e)
            else Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
e
    | DataCon -> Int
dcTag DataCon
dc Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
3
    , Integer
l Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
< ((-Integer
2)Integer -> Int -> Integer
forall a b. (Num a, Integral b) => a -> b -> a
^(Int
63::Int))
#if MIN_VERSION_base(4,15,0)
    = let !(IN ba) = l
#else
    = let !(Jn# !(BN# ByteArray#
ba)) = Integer
l
#endif
          ba' :: ByteArray
ba'       = ByteArray# -> ByteArray
BA.ByteArray ByteArray#
ba
          fvs :: UniqMap (Var Any)
fvs       = Getting (UniqMap (Var Any)) Term Id
-> (Id -> UniqMap (Var Any)) -> Term -> UniqMap (Var Any)
forall r s a. Getting r s a -> (a -> r) -> s -> r
Lens.foldMapOf Getting (UniqMap (Var Any)) Term Id
Fold Term Id
freeLocalIds Id -> UniqMap (Var Any)
forall a. Var a -> UniqMap (Var Any)
unitVarSet Term
e
          bind :: Bind Term
bind      = Id -> Term -> Bind Term
forall a. Id -> a -> Bind a
NonRec Id
x (Literal -> Term
Literal (ByteArray -> Literal
ByteArrayLiteral ByteArray
ba'))
       in if Id
x Id -> UniqMap (Var Any) -> Bool
forall a. Var a -> UniqMap (Var Any) -> Bool
`elemVarSet` UniqMap (Var Any)
fvs
            then Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed (Bind Term -> Term -> Term
Let Bind Term
bind Term
e)
            else Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
e
    | Bool
otherwise
    = [Alt] -> RewriteMonad extra Term
go [Alt]
alts'
  go ((LitPat Literal
l', Term
e):[Alt]
alts')
    | Integer -> Literal
IntegerLiteral Integer
l Literal -> Literal -> Bool
forall a. Eq a => a -> a -> Bool
== Literal
l'
    = Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
e
    | Bool
otherwise
    = [Alt] -> RewriteMonad extra Term
go [Alt]
alts'
  go [Alt]
_ = String -> RewriteMonad extra Term
forall a. HasCallStack => String -> a
error (String -> RewriteMonad extra Term)
-> String -> RewriteMonad extra Term
forall a b. (a -> b) -> a -> b
$ $(String
curLoc) String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"Report as bug: caseCon error: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ Term -> String
forall p. PrettyPrec p => p -> String
showPpr Term
c

matchLiteralContructor Term
c (NaturalLiteral Integer
l) [Alt]
alts = [Alt] -> RewriteMonad NormalizeState Term
forall extra. [Alt] -> RewriteMonad extra Term
go ([Alt] -> [Alt]
forall a. [a] -> [a]
reverse [Alt]
alts)
 where
  go :: [Alt] -> RewriteMonad extra Term
go [(Pat
DefaultPat,Term
e)] = Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
e
  go ((DataPat DataCon
dc [] [Id
x],Term
e):[Alt]
alts')
    | DataCon -> Int
dcTag DataCon
dc Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1
    , Integer
l Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
>= Integer
0 Bool -> Bool -> Bool
&& Integer
l Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
< Integer
2Integer -> Int -> Integer
forall a b. (Num a, Integral b) => a -> b -> a
^(Int
64::Int)
    = let fvs :: UniqMap (Var Any)
fvs       = Getting (UniqMap (Var Any)) Term Id
-> (Id -> UniqMap (Var Any)) -> Term -> UniqMap (Var Any)
forall r s a. Getting r s a -> (a -> r) -> s -> r
Lens.foldMapOf Getting (UniqMap (Var Any)) Term Id
Fold Term Id
freeLocalIds Id -> UniqMap (Var Any)
forall a. Var a -> UniqMap (Var Any)
unitVarSet Term
e
          bind :: Bind Term
bind      = Id -> Term -> Bind Term
forall a. Id -> a -> Bind a
NonRec Id
x (Literal -> Term
Literal (Integer -> Literal
WordLiteral Integer
l))
       in if Id
x Id -> UniqMap (Var Any) -> Bool
forall a. Var a -> UniqMap (Var Any) -> Bool
`elemVarSet` UniqMap (Var Any)
fvs
            then Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed (Bind Term -> Term -> Term
Let Bind Term
bind Term
e)
            else Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
e
    | DataCon -> Int
dcTag DataCon
dc Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
2
    , Integer
l Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
>= Integer
2Integer -> Int -> Integer
forall a b. (Num a, Integral b) => a -> b -> a
^(Int
64::Int)
#if MIN_VERSION_base(4,15,0)
    = let !(IP ba) = l
#else
    = let !(Jp# !(BN# ByteArray#
ba)) = Integer
l
#endif
          ba' :: ByteArray
ba'       = ByteArray# -> ByteArray
BA.ByteArray ByteArray#
ba
          fvs :: UniqMap (Var Any)
fvs       = Getting (UniqMap (Var Any)) Term Id
-> (Id -> UniqMap (Var Any)) -> Term -> UniqMap (Var Any)
forall r s a. Getting r s a -> (a -> r) -> s -> r
Lens.foldMapOf Getting (UniqMap (Var Any)) Term Id
Fold Term Id
freeLocalIds Id -> UniqMap (Var Any)
forall a. Var a -> UniqMap (Var Any)
unitVarSet Term
e
          bind :: Bind Term
bind      = Id -> Term -> Bind Term
forall a. Id -> a -> Bind a
NonRec Id
x (Literal -> Term
Literal (ByteArray -> Literal
ByteArrayLiteral ByteArray
ba'))
       in if Id
x Id -> UniqMap (Var Any) -> Bool
forall a. Var a -> UniqMap (Var Any) -> Bool
`elemVarSet` UniqMap (Var Any)
fvs
            then Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed (Bind Term -> Term -> Term
Let Bind Term
bind Term
e)
            else Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
e
    | Bool
otherwise
    = [Alt] -> RewriteMonad extra Term
go [Alt]
alts'
  go ((LitPat Literal
l', Term
e):[Alt]
alts')
    | Integer -> Literal
NaturalLiteral Integer
l Literal -> Literal -> Bool
forall a. Eq a => a -> a -> Bool
== Literal
l'
    = Term -> RewriteMonad extra Term
forall a extra. a -> RewriteMonad extra a
changed Term
e
    | Bool
otherwise
    = [Alt] -> RewriteMonad extra Term
go [Alt]
alts'
  go [Alt]
_ = String -> RewriteMonad extra Term
forall a. HasCallStack => String -> a
error (String -> RewriteMonad extra Term)
-> String -> RewriteMonad extra Term
forall a b. (a -> b) -> a -> b
$ $(String
curLoc) String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"Report as bug: caseCon error: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ Term -> String
forall p. PrettyPrec p => p -> String
showPpr Term
c

matchLiteralContructor Term
_ Literal
_ ((Pat
DefaultPat,Term
e):[Alt]
_) = Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed Term
e
matchLiteralContructor Term
c Literal
_ [Alt]
_ =
  String -> RewriteMonad NormalizeState Term
forall a. HasCallStack => String -> a
error (String -> RewriteMonad NormalizeState Term)
-> String -> RewriteMonad NormalizeState Term
forall a b. (a -> b) -> a -> b
$ $(String
curLoc) String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"Report as bug: caseCon error: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ Term -> String
forall p. PrettyPrec p => p -> String
showPpr Term
c
{-# SCC matchLiteralContructor #-}

-- | Drop the alternatives of a case-expression whose patterns can never match.
-- Consider:
--
-- @
-- data STy ty where
--   SInt :: Int -> STy Int
--   SBool :: Bool -> STy Bool
--
-- f :: STy ty -> ty
-- f (SInt b) = b + 1
-- f (SBool True) = False
-- f (SBool False) = True
-- {\-\# NOINLINE f \#-\}
--
-- g :: STy Int -> Int
-- g = f
-- @
--
-- Here @f@ only ever gets specialized on @STy Int@, so both SBool
-- alternatives are unreachable. For more background, see:
-- https://github.com/clash-lang/clash-compiler/pull/465
--
-- An alternative is dropped when 'isAbsurdPat' reports its pattern as absurd.
-- If anything was dropped, the pruned case-expression is passed on to
-- 'caseOneAlt' (which may collapse it further) and the outcome is reported as
-- 'changed'. Case-expressions without absurd alternatives, and all other
-- terms, are returned untouched.
caseElemNonReachable :: HasCallStack => NormRewrite
caseElemNonReachable _ case0@(Case scrut altsTy alts0) = do
  tcm <- Lens.view tcCache
  case List.partition (isAbsurdPat tcm . fst) alts0 of
    -- Nothing is unreachable: leave the term exactly as it was.
    ([], _) ->
      return case0
    -- Keep only the reachable alternatives and try to simplify further.
    (_unreachable, reachable) ->
      caseOneAlt (Case scrut altsTy reachable) >>= changed
caseElemNonReachable _ e = return e
{-# SCC caseElemNonReachable #-}

-- | Flatten the chains of nested equality tests that GHC produces for literal
-- pattern matches into one multi-way case-expression.
--
-- For a Haskell case-statement like:
--
-- @
-- f :: Unsigned 4 -> Unsigned 4
-- f x = case x of
--   0 -> 3
--   1 -> 2
--   2 -> 1
--   3 -> 0
-- @
--
-- GHC emits Core of the following shape:
--
-- @
-- f = \\(x :: Unsigned 4) -> case x == fromInteger 3 of
--                             False -> case x == fromInteger 2 of
--                               False -> case x == fromInteger 1 of
--                                 False -> case x == fromInteger 0 of
--                                   False -> error "incomplete case"
--                                   True  -> fromInteger 3
--                                 True -> fromInteger 2
--                               True -> fromInteger 1
--                             True -> fromInteger 0
-- @
--
-- Translated as-is, that becomes a priority decoder circuit, while a normal
-- decoder circuit was intended.
--
-- This transformation rewrites the Core above into:
--
-- @
-- f = \\(x :: Unsigned 4) -> case x of
--        _ -> error "incomplete case"
--        0 -> fromInteger 3
--        1 -> fromInteger 2
--        2 -> fromInteger 1
--        3 -> fromInteger 0
-- @
--
-- The subject has to be an equality test understood by 'collectEqArgs', and
-- the chain has to be understood by 'collectFlat'; otherwise the term is
-- returned unchanged. When the compared value is built with @GHC.Types.I#@,
-- the subject is first unpacked into a fresh @wild@ binder, so that the
-- literal alternatives match on the underlying @Int#@.
caseFlat :: HasCallStack => NormRewrite
caseFlat (TransformContext is0 _) e@(Case (collectEqArgs -> Just (scrut', val)) ty _) =
  maybe (return e) flatten (collectFlat scrut' e)
 where
  flatten alts' = case collectArgs val of
    -- When we're pattern matching on `Int`, extract the `Int#` first before
    -- we do the Literal matching branches.
    (Data dc, _)
      | nameOcc (dcName dc) == "GHC.Types.I#"
      , [argTy] <- dcArgTys dc
      -> do
        wild <- mkInternalVar is0 "wild" argTy
        changed
          (Case scrut' ty
             [(DataPat dc [] [wild], Case (Var wild) ty defaultFirst)])
    _ ->
      changed (Case scrut' ty defaultFirst)
   where
    -- 'collectFlat' puts the default alternative last; rotate it to the front.
    defaultFirst = last alts' : init alts'
caseFlat _ e = return e
{-# SCC caseFlat #-}

-- | Try to flatten a chain of nested equality tests against one subject into
-- a list of literal alternatives.
--
-- Given the subject @scrut@ and a term shaped like
--
-- @
-- case scrut == lit of
--   False -> elseBranch
--   True  -> thenBranch
-- @
--
-- (in either alternative order), this yields
-- @Just ((LitPat lit, thenBranch) : rest)@. Here @rest@ is the flattening of
-- @elseBranch@ when that is again such a test, and @[(DefaultPat, elseBranch)]@
-- otherwise. A successful result is therefore never empty, and its last
-- element is the default alternative.
--
-- Only compared values built by an @isFromInt@ primitive or by the
-- @GHC.Types.I#@ constructor, whose final argument is a literal, are
-- recognised. Anything else, including a recognised head applied to no
-- arguments at all, yields 'Nothing'.
collectFlat :: Term -> Term -> Maybe [Alt]
collectFlat scrut (Case (collectEqArgs -> Just (scrut', val)) _ty [lAlt, rAlt])
  | scrut' == scrut
  = case collectArgs val of
      (Prim p, args')
        | isFromInt (primName p) -> go (finalArg args')
      (Data dc, args')
        | nameOcc (dcName dc) == "GHC.Types.I#" -> go (finalArg args')
      _ -> Nothing
  where
    -- The literal is the final argument. Read it totally instead of using the
    -- partial 'last', so an argument-less head gives 'Nothing' instead of a
    -- crash.
    finalArg :: [Either Term Type] -> Maybe (Either Term Type)
    finalArg = Maybe.listToMaybe . reverse

    go :: Maybe (Either Term Type) -> Maybe [Alt]
    go (Just (Left (Literal i))) =
      let (onMatch, onMismatch) = orient lAlt rAlt
          -- Keep flattening the "not equal" branch; when it is not another
          -- test of the same subject, it becomes the default alternative.
          rest = Maybe.fromMaybe [(DefaultPat, onMismatch)]
                                 (collectFlat scrut onMismatch)
      in Just ((LitPat i, onMatch) : rest)
    go _ = Nothing

    -- Split the two alternatives of the equality test into the body taken
    -- when the test holds and the body taken when it does not.
    orient :: Alt -> Alt -> (Term, Term)
    orient (pl, el) (pr, er)
      | isDcPat "GHC.Types.False" pl || isDcPat "GHC.Types.True" pr = (er, el)
      | otherwise = (el, er)

    -- Whether a pattern matches on the data constructor with the given name.
    isDcPat nm (DataPat dc _ _) = nameOcc (dcName dc) == nm
    isDcPat _ _ = False

collectFlat _ _ = Nothing
{-# SCC collectFlat #-}

-- | Recognise an application of a known equality primitive. On success,
-- return its two compared operands as @(subject, value)@, with the ticks of
-- the application re-attached to the subject.
--
-- Recognised primitives and the argument lists they are expected to have:
--
--   * 'BV.eq#': two leading arguments followed by the two operands;
--   * 'I.eq#', 'S.eq#' and 'U.eq#': one leading argument followed by the two
--     operands;
--   * @GHC.Classes.eqInt@: only the two operands.
--
-- Any other term yields 'Nothing'. A recognised primitive whose arguments
-- do not have the expected shape is reported through 'error'.
collectEqArgs :: Term -> Maybe (Term,Term)
collectEqArgs f@(collectArgsTicks -> (Prim p, args, ticks))
  | nm == Text.showt 'BV.eq#
  = case args of
      [_, _, Left scrut, Left val] -> found scrut val
      _ -> error ("collectEqArgs: BV.eq expects 4 arguments, but got: " <> showPpr f)
  | nm `elem` [Text.showt 'I.eq#, Text.showt 'S.eq#, Text.showt 'U.eq#]
  = case args of
      [_, Left scrut, Left val] -> found scrut val
      _ -> error (show nm <> " expects 3 arguments, but got: " <> showPpr f)
  | nm == "GHC.Classes.eqInt"
  = case args of
      [Left scrut, Left val] -> found scrut val
      _ -> error ("eqInt expects 2 arguments, but got: " <> showPpr f)
 where
  nm = primName p
  -- Pair the operands, moving the application's ticks onto the subject.
  found scrut val = Just (mkTicks scrut ticks, val)

collectEqArgs _ = Nothing

-- | Lift the let-bindings out of the subject of a case-decomposition:
--
-- @
-- case (let xes in e) of alts   ==>   let xes in case e of alts
-- @
--
-- Ticks wrapped around the let-expression are re-applied both to every bound
-- expression and to the new case subject.
caseLet :: HasCallStack => NormRewrite
caseLet (TransformContext is0 _) (Case subj ty alts)
  | (Let xes e, ticks) <- collectTicks subj
  = let
      -- Note [CaseLet deshadow]
      -- Imagine
      --
      -- @
      -- case (let x = u in e) of {p -> a}
      -- @
      --
      -- where `a` has a free variable named `x`.
      --
      -- Simply transforming the above to:
      --
      -- @
      -- let x = u in case e of {p -> a}
      -- @
      --
      -- would be very bad, because now the let-binding captures the free x
      -- variable in a.
      --
      -- We must therefore rename `x` so that it doesn't capture the free
      -- variables in the alternative:
      --
      -- @
      -- let x1 = u[x:=x1] in case e[x:=x1] of {p -> a}
      -- @
      --
      -- It is safe to over-approximate the free variables in `a` by simply
      -- taking the current InScopeSet.
      (xes1, e1) = deshadowLetExpr is0 xes e
      retick = (`mkTicks` ticks)
    in
      changed (Let (fmap retick xes1) (Case (retick e1) ty alts))
caseLet _ e = return e
{-# SCC caseLet #-}

-- | Collapse a case-expression into the body of its alternative(s) when the
-- choice of alternative cannot influence the result:
--
--   * there is exactly one alternative and its pattern is a default pattern,
--     a literal pattern, or a data pattern none of whose bound (type)
--     variables occur in the body; or
--   * there are two or more alternatives that all have the same body, and
--     the variables bound by the first pattern do not occur in that body.
--
-- A collapse is reported through 'changed'; every other term is returned as
-- it was.
caseOneAlt :: Term -> NormalizeSession Term
caseOneAlt e@(Case _ _ [(pat, altE)])
  -- A data pattern whose binders are used by the body must stay.
  | DataPat _ tvs xs <- pat
  , not ((coerce tvs ++ coerce xs) `localVarsDoNotOccurIn` altE)
  = return e
  -- Default and literal patterns bind nothing, so the body can stand alone.
  | otherwise
  = changed altE

caseOneAlt (Case _ _ ((pat, alt) : alts@(_ : _)))
  | all (\(_, rhs) -> rhs == alt) alts
  , (tvs, xs) <- patIds pat
  , (coerce tvs ++ coerce xs) `localVarsDoNotOccurIn` alt
  = changed alt

caseOneAlt e = return e
{-# SCC caseOneAlt #-}

-- | Try to eliminate existentials by using heuristics to work out what each
-- existential must be. For example, consider Vec:
--
--    data Vec :: Nat -> Type -> Type where
--      Nil  :: Vec 0 a
--      Cons :: a -> Vec n a -> Vec (n + 1) a
--
-- With its existentials made explicit, 'null' could look like:
--
--    null :: forall n . Vec n Bool -> Bool
--    null v =
--      case v of
--        Nil  {n ~ 0}                                     -> True
--        Cons {n1:Nat} {n~n1+1} (x :: a) (xs :: Vec n1 a) -> False
--
-- Applied to a vector of length 5, this turns into:
--
--    null :: Vec 5 Bool -> Bool
--    null v =
--      case v of
--        Nil  {5 ~ 0}                                     -> True
--        Cons {n1:Nat} {5~n1+1} (x :: a) (xs :: Vec n1 a) -> False
--
-- This rewrite solves 'n1' and substitutes its solution everywhere it occurs.
-- Only a very limited set of equations is currently understood: just
-- additions, as in the example above.
--
-- After the alternatives have been processed, the case-expression is handed
-- to 'caseOneAlt', which may collapse it further.
elimExistentials :: HasCallStack => NormRewrite
elimExistentials (TransformContext is0 _) (Case scrut altsTy alts0) = do
  tcm <- Lens.view tcCache
  caseOneAlt . Case scrut altsTy =<< traverse (solveAlt tcm) alts0
 where
  -- Eliminate free type variables of a data-pattern alternative if possible.
  -- Each round substitutes the solved existentials in the remaining
  -- existentials, in the types of the term binders, and in the body, and then
  -- tries again on the simplified alternative.
  solveAlt :: TyConMap -> Alt -> NormalizeSession Alt
  solveAlt tcm alt@(pat@(DataPat dc exts0 xs0), term0) =
    case solveNonAbsurds tcm (mkVarSet exts0) (patEqs tcm pat) of
      -- No equations solved: keep the alternative as it is.
      [] -> return alt
      -- One or more equations solved: substitute, then try again.
      sols ->
        let -- Substitute the solution in the existentials and binder types;
            -- the original existentials are in scope for the binder types.
            isBinders = extendInScopeSetList is0 exts0
            xs1   = fmap (substTyInVar (extendTvSubstList (mkSubst isBinders) sols)) xs0
            exts1 = substInExistentialsList is0 exts0 sols

            -- Substitute the solution in the body, where the rewritten term
            -- binders are in scope as well.
            isBody = extendInScopeSetList isBinders xs1
            term1  = substTm "Replacing tyVar due to solved eq"
                       (extendTvSubstList (mkSubst isBody) sols) term0
        in changed =<< solveAlt tcm (DataPat dc exts1 xs1, term1)
  solveAlt _ alt = return alt

elimExistentials _ e = return e
{-# SCC elimExistentials #-}