{-|
Copyright   : (C) 2020, QBayLogic B.V.
License     : BSD2 (see the file LICENSE)
Maintainer  : QBayLogic B.V. <devops@qbaylogic.com>

This module provides the "quoting" part of the partial evaluator, which
traverses a WHNF value, recursively evaluating sub-terms to remove redexes.
-}

{-# LANGUAGE LambdaCase #-}

module Clash.GHC.PartialEval.Quote
  ( quote
  ) where

import Data.Bitraversable

import Clash.Core.DataCon (DataCon)
import Clash.Core.PartialEval.Monad
import Clash.Core.PartialEval.NormalForm
import Clash.Core.Term (Term, PrimInfo, TickInfo, Pat)
import Clash.Core.Type (Type(VarTy))
import Clash.Core.Var (Id, TyVar)

import Clash.GHC.PartialEval.Eval

-- | Quote a WHNF 'Value' into a 'Normal' term.
--
-- Dispatches on the value's constructor:
--
-- * Literals are rebuilt directly as 'NLiteral'.
-- * Neutral values are quoted with 'quoteNeutral' and wrapped in 'NNeutral'.
-- * Every other constructor goes to a dedicated @quote*@ helper, which
--   recursively quotes the sub-terms.
quote :: Value -> Eval Normal
quote = \case
  VNeutral n        -> NNeutral <$> quoteNeutral n
  VLiteral l        -> pure (NLiteral l)
  VData dc args env -> quoteData dc args env
  VLam i x env      -> quoteLam i x env
  VTyLam i x env    -> quoteTyLam i x env
  VCast x a b       -> quoteCast x a b
  VTick x tick      -> quoteTick x tick
  VThunk x env      -> quoteThunk x env

-- | Quote a 'Neutral' term by quoting every 'Value' stored inside it.
--
-- A variable carries no values. It is rebuilt at the new element type. All
-- other constructors are delegated to their @quoteNe*@ helper.
quoteNeutral :: Neutral Value -> Eval (Neutral Normal)
quoteNeutral = \case
  NeVar i          -> pure (NeVar i)
  NePrim pr args   -> quoteNePrim pr args
  NeApp x y        -> quoteNeApp x y
  NeTyApp x ty     -> quoteNeTyApp x ty
  NeLetrec bs x    -> quoteNeLetrec bs x
  NeCase x ty alts -> quoteNeCase x ty alts

-- | Quote an argument list.
--
-- Term arguments ('Left') are quoted. Type arguments ('Right') are kept
-- unchanged. Arguments are processed left to right.
quoteArgs :: Args Value -> Eval (Args Normal)
quoteArgs = traverse (bitraverse quote pure)

-- | Quote the right-hand side of every case alternative. Patterns are kept
-- unchanged, and alternatives are processed left to right.
--
-- The inner 'traverse' uses the pair 'Traversable' instance, which only
-- visits the second component (the alternative's body).
quoteAlts :: [(Pat, Value)] -> Eval [(Pat, Normal)]
quoteAlts = traverse (traverse quote)

-- | Quote the right-hand side of every let binding. Binder 'Id's are kept
-- unchanged, and bindings are processed left to right.
--
-- The inner 'traverse' uses the pair 'Traversable' instance, which only
-- visits the second component (the bound value).
quoteBinders :: [(Id, Value)] -> Eval [(Id, Normal)]
quoteBinders = traverse (traverse quote)

-- | Quote a data constructor application.
--
-- The arguments are quoted inside the environment captured by the value
-- (installed with 'setLocalEnv'). The resulting 'NData' does not keep that
-- environment.
quoteData :: DataCon -> Args Value -> LocalEnv -> Eval Normal
quoteData dc args env =
  setLocalEnv env (NData dc <$> quoteArgs args)

-- | Quote a term-level lambda.
--
-- Inside the lambda's captured environment, the lambda is applied to its own
-- binder as a neutral variable (@NeVar i@). The result of that application is
-- quoted as the body. The rebuilt 'NLam' keeps the binder and the captured
-- environment.
--
-- NOTE(review): the binder @i@ itself is reused as the free variable rather
-- than a freshly generated one. This relies on @i@ not clashing with other
-- variables in scope; confirm that freshness is guaranteed upstream.
quoteLam :: Id -> Term -> LocalEnv -> Eval Normal
quoteLam i x env =
  setLocalEnv env $ do
    body  <- apply (VLam i x env) (VNeutral (NeVar i))
    qBody <- quote body
    pure (NLam i qBody env)

-- | Quote a type-level lambda.
--
-- Inside the lambda's captured environment, the lambda is applied to its own
-- type variable (@VarTy i@). The result of that application is quoted as the
-- body. The rebuilt 'NTyLam' keeps the binder and the captured environment.
--
-- NOTE(review): as with 'quoteLam', the binder is reused instead of a fresh
-- type variable; confirm that this cannot capture.
quoteTyLam :: TyVar -> Term -> LocalEnv -> Eval Normal
quoteTyLam i x env =
  setLocalEnv env $ do
    body  <- applyTy (VTyLam i x env) (VarTy i)
    qBody <- quote body
    pure (NTyLam i qBody env)

-- | Quote the value underneath a cast. The source and target types are kept
-- unchanged.
quoteCast :: Value -> Type -> Type -> Eval Normal
quoteCast x a b = (\qx -> NCast qx a b) <$> quote x

-- | Quote the value underneath a tick. The 'TickInfo' is kept unchanged.
quoteTick :: Value -> TickInfo -> Eval Normal
quoteTick x tick = (\qx -> NTick qx tick) <$> quote x

-- | Force a thunk and quote the result.
--
-- The thunk's term is evaluated with 'eval' inside the thunk's captured
-- environment. The resulting value is then quoted in that same environment.
quoteThunk :: Term -> LocalEnv -> Eval Normal
quoteThunk x env =
  setLocalEnv env (eval x >>= quote)

-- | Quote the arguments of a neutral primitive application. The 'PrimInfo' is
-- kept unchanged.
quoteNePrim :: PrimInfo -> Args Value -> Eval (Neutral Normal)
quoteNePrim pr args = NePrim pr <$> quoteArgs args

-- | Quote a neutral term application.
--
-- The neutral head is quoted first, then the argument. The effects in 'Eval'
-- run in that order.
quoteNeApp :: Neutral Value -> Value -> Eval (Neutral Normal)
quoteNeApp x y = NeApp <$> quoteNeutral x <*> quote y

-- | Quote the neutral head of a type application. The type argument is kept
-- unchanged.
quoteNeTyApp :: Neutral Value -> Type -> Eval (Neutral Normal)
quoteNeTyApp x ty = (\qx -> NeTyApp qx ty) <$> quoteNeutral x

-- | Quote a neutral @letrec@: the bindings are quoted first, then the body.
--
-- Both are quoted inside @withIds bs@, presumably so the binders are in scope
-- while their right-hand sides and the body are quoted. Check 'withIds' to
-- confirm.
quoteNeLetrec :: [(Id, Value)] -> Value -> Eval (Neutral Normal)
quoteNeLetrec bs x =
  withIds bs $
    NeLetrec <$> quoteBinders bs <*> quote x

-- | Quote a neutral @case@.
--
-- The scrutinee is quoted first, then every alternative's right-hand side.
-- The result type is kept unchanged.
quoteNeCase :: Value -> Type -> [(Pat, Value)] -> Eval (Neutral Normal)
quoteNeCase x ty alts =
  (\qx qAlts -> NeCase qx ty qAlts) <$> quote x <*> quoteAlts alts