{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Language.Beta
-- Description : does beta-reduction. / beta 簡約を行います。
-- Copyright   : (c) Kimiyuki Onaka, 2020
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Language.Beta
  ( substitute,
    substituteToplevelExpr,
  )
where

import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Util

-- | `substitute` replaces every free occurrence of the given variable with the given expression.
--
-- Binders are respected:
--
-- * a binder with the same name as the variable shadows it, so the substitution stops there;
-- * any other binder is passed to 'resolveConflict' first, which renames it if it occurs free in the
--   replacement expression, so that the replacement cannot be captured;
-- * the bound expression of a @Let@ lies outside its binder's scope and is always substituted.
--
-- >>> flip evalAlphaT 0 $ substitute (VarName "x") (Lit (LitInt 0)) (Lam (VarName "y") IntTy (Var (VarName "x")))
-- Lam (VarName "y") IntTy (Lit (LitInt 0))
--
-- >>> flip evalAlphaT 0 $ substitute (VarName "x") (Lit (LitInt 0)) (Lam (VarName "x") IntTy (Var (VarName "x")))
-- Lam (VarName "x") IntTy (Var (VarName "x"))
substitute :: MonadAlpha m => VarName -> Expr -> Expr -> m Expr
substitute x e = go
  where
    -- Worker with the target variable and replacement fixed.
    go expr = case expr of
      Var y
        | y == x -> return e
        | otherwise -> return (Var y)
      Lit lit -> return (Lit lit)
      App e1 e2 -> do
        -- Left operand first, then right (same effect order as an applicative chain).
        e1' <- go e1
        e2' <- go e2
        return (App e1' e2')
      Lam y t body
        | y == x -> return (Lam y t body)
        | otherwise -> do
            (y', body') <- resolveConflict e (y, body)
            Lam y' t <$> go body'
      Let y t e1 e2 -> do
        -- The bound expression is not in the scope of y, so it is substituted unconditionally.
        e1' <- go e1
        if y == x
          then return (Let y t e1' e2)
          else do
            (y', e2') <- resolveConflict e (y, e2)
            Let y' t e1' <$> go e2'

-- | `substituteToplevelExpr` replaces the free occurrences of a variable with an expression
-- throughout a chain of toplevel definitions.
--
-- Expressions are rewritten with 'substitute', which is capture-avoiding for local binders.
-- Toplevel names are never renamed. Suppose the substitution has to pass under a toplevel
-- @let@/@let rec@ name that occurs free in the replacement expression. In that case an internal
-- error is thrown instead.
substituteToplevelExpr :: (MonadAlpha m, MonadError Error m) => VarName -> Expr -> ToplevelExpr -> m ToplevelExpr
substituteToplevelExpr x e = \case
  ResultExpr e' -> ResultExpr <$> substitute x e e'
  ToplevelLet y t e' cont -> do
    -- The bound expression is outside the scope of y, so it is always substituted.
    e'' <- substitute x e e'
    -- When y == x, y shadows x in the continuation, which is therefore left as is.
    if y == x
      then return $ ToplevelLet y t e'' cont
      else do
        when (y `isFreeVar` e) $ do
          throwInternalError $ "Jikka.Core.Language.Beta.substituteToplevelExpr: toplevel name conflicts: " ++ unVarName y
        ToplevelLet y t e'' <$> substituteToplevelExpr x e cont
  ToplevelLetRec f args ret body cont ->
    -- f is in scope in both the body and the continuation, so f == x shadows x everywhere.
    if f == x
      then return $ ToplevelLetRec f args ret body cont
      else do
        when (f `isFreeVar` e) $ do
          throwInternalError $ "Jikka.Core.Language.Beta.substituteToplevelExpr: toplevel name conflicts: " ++ unVarName f
        -- If one of the parameters is named x, it shadows x inside the body.
        (args', body') <-
          if x `elem` map fst args
            then return (args, body)
            else do
              -- Rename (left to right) every parameter that would capture a free variable of e ...
              let rename (acc, b) (y, ty) = do
                    (y', b') <- resolveConflict e (y, b)
                    return (acc ++ [(y', ty)], b')
              (renamedArgs, renamedBody) <- foldM rename ([], body) args
              -- ... then substitute into the body, where x is free.
              substitutedBody <- substitute x e renamedBody
              return (renamedArgs, substitutedBody)
        ToplevelLetRec f args' ret body' <$> substituteToplevelExpr x e cont

-- | `resolveConflict` prepares a binder so that an expression can be substituted underneath it.
--
-- Given the replacement expression @e@ and a pair of a binder @x@ and the expression @e'@ in its
-- scope:
--
-- * If @x@ occurs free in @e@, a new name is obtained from 'genVarName' (presumably fresh —
--   depends on the 'MonadAlpha' instance). That name is substituted for @x@ in @e'@, and the
--   renamed pair is returned.
-- * Otherwise the pair is returned unchanged.
resolveConflict :: MonadAlpha m => Expr -> (VarName, Expr) -> m (VarName, Expr)
resolveConflict e (x, e')
  | not (x `isFreeVar` e) = return (x, e')
  | otherwise = do
      fresh <- genVarName x
      renamed <- substitute x (Var fresh) e'
      return (fresh, renamed)