{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.TypeInfer
-- Description : does type inference. / 型推論を行います。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.TypeInfer
  ( run,
    runExpr,
    runRule,

    -- * internal types and functions
    Equation (..),
    formularizeProgram,
    sortEquations,
    mergeAssertions,
    Subst (..),
    subst,
    substDefault,
    solveEquations,
    substProgram,
  )
where

import Control.Arrow (second)
import Control.Monad.State.Strict
import Control.Monad.Writer.Strict (MonadWriter, censor, execWriterT, tell)
import qualified Data.Map.Strict as M
import Data.Monoid (Dual (..))
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Format (formatExpr, formatToplevelExpr, formatType)
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.NameCheck (namecheckExpr)
import Jikka.Core.Language.TypeCheck (literalToType, typecheckExpr)
import Jikka.Core.Language.Util

-- | Provenance attached to a 'TypeEquation'. Within this module hints are
-- consumed only by 'wrapErrorFromHint', which turns each one into a line of
-- context on a unification error.
data Hint
  = -- | produced by 'mergeAssertions' when a variable is asserted at two types
    VarHint VarName
  | -- | the equation was emitted while visiting this expression
    ExprHint Expr
  | -- | the equation was emitted while visiting this toplevel expression
    ToplevelExprHint ToplevelExpr
  deriving (Eq, Ord, Show, Read)

-- | A constraint collected by the @formularize*@ functions.
data Equation
  = -- | the two types must be unified; the hints record where the constraint came from
    TypeEquation Type Type [Hint]
  | -- | the variable is used at (or declared with) the given type
    TypeAssertion VarName Type
  deriving (Eq, Ord, Show, Read)

-- | The writer log of collected equations. Because '<>' on 'Dual' concatenates
-- in reverse, equations from later 'tell's end up before earlier ones.
type Eqns = Dual [Equation]

-- | Prepend a hint to a 'TypeEquation'. A 'TypeAssertion' carries no hints
-- and is passed through unchanged.
consHint :: Hint -> Equation -> Equation
consHint hint eqn = case eqn of
  TypeEquation t1 t2 hints -> TypeEquation t1 t2 (hint : hints)
  TypeAssertion {} -> eqn

-- | Run an action and tag every equation it writes with the given hint
-- (see 'consHint'). The action's result is returned unchanged.
wrapHint :: MonadWriter Eqns m => Hint -> m a -> m a
wrapHint hint = censor (\(Dual eqns) -> Dual (map (consHint hint) eqns))

-- | Wrap errors raised by an action with one line of context derived from the
-- hint. A formatted expression that spans several lines is shortened to its
-- first line followed by @" ..."@.
wrapErrorFromHint :: MonadError Error m => Hint -> m a -> m a
wrapErrorFromHint hint = wrapError' (describe hint)
  where
    describe :: Hint -> String
    describe (VarHint x) = "around variable " ++ formatVarName x
    describe (ExprHint e) = "around expr " ++ firstLine (formatExpr e)
    describe (ToplevelExprHint e) = "around toplevel expr " ++ firstLine (formatToplevelExpr e)

    firstLine :: String -> String
    firstLine s = case lines s of
      (l : _ : _) -> l ++ " ..."
      _ -> s

-- | Wrap errors from an action with one context line per hint. The first hint
-- in the list becomes the outermost wrapper.
wrapErrorFromHints :: MonadError Error m => [Hint] -> m a -> m a
wrapErrorFromHints hints action = foldr wrapErrorFromHint action hints

-- | Emit a single equation @t1 = t2@ with no hints attached yet.
formularizeType :: MonadWriter Eqns m => Type -> Type -> m ()
formularizeType t1 t2 = tell . Dual $ [TypeEquation t1 t2 []]

-- | Emit a 'TypeAssertion' recording that variable @x@ has type @t@.
formularizeVarName :: MonadWriter Eqns m => VarName -> Type -> m ()
formularizeVarName x t = tell . Dual $ [TypeAssertion x t]

-- | Emit the type equations for an expression and return a type standing for it.
--
-- Every equation written while visiting @e@, including those from
-- sub-expressions, is tagged with @ExprHint e@.
--
-- * A variable gets a fresh type variable, linked to it by a 'TypeAssertion'.
-- * A @Proj@ builtin with an empty type-argument list gets a fresh type
--   variable; @fixProj@ completes such nodes after solving.
-- * An application @f arg@ requires @f@ to have type @argTy -> retTy@ for a
--   fresh @retTy@.
formularizeExpr :: (MonadWriter Eqns m, MonadAlpha m, MonadError Error m) => Expr -> m Type
formularizeExpr e = wrapHint (ExprHint e) $ case e of
  Var x -> do
    tyVar <- genType
    tyVar <$ formularizeVarName x tyVar
  -- a Proj whose type arguments are still unknown
  Lit (LitBuiltin (Proj _) []) -> genType
  Lit lit -> literalToType lit
  App f arg -> do
    retTy <- genType
    argTy <- formularizeExpr arg
    formularizeExpr' f (FunTy argTy retTy)
    return retTy
  Lam x t body -> do
    formularizeVarName x t
    FunTy t <$> formularizeExpr body
  Let x t bound body -> do
    formularizeVarName x t
    formularizeExpr' bound t
    formularizeExpr body
  Assert cond body -> do
    formularizeExpr' cond BoolTy
    formularizeExpr body

-- | Emit the equations for @e@ and additionally require its type to equal @t@.
-- That extra equation is tagged with @ExprHint e@.
formularizeExpr' :: (MonadWriter Eqns m, MonadAlpha m, MonadError Error m) => Expr -> Type -> m ()
formularizeExpr' e t = formularizeExpr e >>= wrapHint (ExprHint e) . formularizeType t

-- | Emit the type equations for a toplevel expression and return the type of
-- its final result expression. Every equation written while visiting @e@ is
-- tagged with @ToplevelExprHint e@.
--
-- For a recursive definition, the function name is asserted to have the
-- curried type built from its argument types and return type, and each
-- argument is asserted to have its declared type.
formularizeToplevelExpr :: (MonadWriter Eqns m, MonadAlpha m, MonadError Error m) => ToplevelExpr -> m Type
formularizeToplevelExpr e = wrapHint (ToplevelExprHint e) $ case e of
  ResultExpr result -> formularizeExpr result
  ToplevelLet x t bound cont -> do
    formularizeVarName x t
    formularizeExpr' bound t
    formularizeToplevelExpr cont
  ToplevelLetRec f args ret body cont -> do
    formularizeVarName f (curryFunTy [argTy | (_, argTy) <- args] ret)
    sequence_ [formularizeVarName argName argTy | (argName, argTy) <- args]
    formularizeExpr' body ret
    formularizeToplevelExpr cont
  ToplevelAssert cond cont -> do
    formularizeExpr' cond BoolTy
    formularizeToplevelExpr cont

-- | Collect every equation for a whole program by running
-- 'formularizeToplevelExpr' in a fresh writer. The order of the returned list
-- is the order accumulated by the 'Dual' log.
formularizeProgram :: (MonadAlpha m, MonadError Error m) => Program -> m [Equation]
formularizeProgram prog = fmap getDual (execWriterT (formularizeToplevelExpr prog))

-- | Split equations into plain type equations and variable assertions.
-- Both returned lists are in reverse input order.
sortEquations :: [Equation] -> ([(Type, Type, [Hint])], [(VarName, Type)])
sortEquations = loop [] []
  where
    loop ::
      [(Type, Type, [Hint])] ->
      [(VarName, Type)] ->
      [Equation] ->
      ([(Type, Type, [Hint])], [(VarName, Type)])
    loop accEqns accAsserts = \case
      [] -> (accEqns, accAsserts)
      TypeEquation t1 t2 hints : rest -> loop ((t1, t2, hints) : accEqns) accAsserts rest
      TypeAssertion x t : rest -> loop accEqns ((x, t) : accAsserts) rest

-- | Turn variable assertions into type equations.
--
-- The first assertion seen for each variable (in list order) is remembered.
-- Every later assertion @(x, t)@ for the same variable yields the equation
-- @(t, remembered, [VarHint x])@. The equations are returned in reverse
-- order of discovery.
mergeAssertions :: [(VarName, Type)] -> [(Type, Type, [Hint])]
mergeAssertions = loop M.empty []
  where
    loop :: M.Map VarName Type -> [(Type, Type, [Hint])] -> [(VarName, Type)] -> [(Type, Type, [Hint])]
    loop _ acc [] = acc
    loop seen acc ((x, t) : rest) =
      maybe
        (loop (M.insert x t seen) acc rest)
        (\earlier -> loop seen ((t, earlier, [VarHint x]) : acc) rest)
        (M.lookup x seen)

-- | `Subst` is a type substitution: a mapping from type variables to the types they are bound to.
newtype Subst = Subst
  { -- | bindings from type-variable names to types; a bound type may itself mention other bound variables
    unSubst :: M.Map TypeName Type
  }
  deriving (Eq, Ord, Show, Read)

-- | Apply a substitution to a type. A bound type variable is replaced by its
-- binding, which is then substituted again, so chains of bindings are followed
-- to the end. Unbound variables and data-structure types are left as they are.
--
-- NOTE(review): termination relies on the substitution being acyclic, which
-- the occurs check in 'unifyTyVar' is meant to ensure.
subst :: Subst -> Type -> Type
subst sigma = go
  where
    go :: Type -> Type
    go ty = case ty of
      VarTy x -> maybe ty go (M.lookup x (unSubst sigma))
      IntTy -> IntTy
      BoolTy -> BoolTy
      ListTy t -> ListTy (go t)
      TupleTy ts -> TupleTy (map go ts)
      FunTy t ret -> FunTy (go t) (go ret)
      DataStructureTy ds -> DataStructureTy ds

-- | Bind the type variable @x@ to @t@ in the substitution state.
--
-- Occurs check: if @x@ is free in @t@, the binding would make the substitution
-- loop, so an internal error is thrown and the state is left untouched.
unifyTyVar :: (MonadState Subst m, MonadError Error m) => TypeName -> Type -> m ()
unifyTyVar x t
  | x `elem` freeTyVars t =
    throwInternalError $ "looped type equation " ++ formatTypeName x ++ " = " ++ formatType t
  | otherwise =
    -- safe: the occurs check above rules out a self-referential binding
    modify' (\(Subst sigma) -> Subst (M.insert x t sigma))

-- | Unify two types under the current substitution, extending it in place.
--
-- Both types are first resolved with the current substitution. Identical
-- types unify trivially. A type variable on either side is bound to the other
-- type. Lists, equal-length tuples and functions are unified component-wise.
-- Anything else is a constructor mismatch.
--
-- The outer error context shows the types as passed in (before substitution);
-- a mismatch error shows the substituted types.
unifyType :: (MonadState Subst m, MonadError Error m) => Type -> Type -> m ()
unifyType t1 t2 = wrapError' ("failed to unify " ++ formatType t1 ++ " and " ++ formatType t2) $ do
  sigma <- get
  let t1' = subst sigma t1
      t2' = subst sigma t2
      mismatch = throwInternalError $ "different type ctors " ++ formatType t1' ++ " and " ++ formatType t2'
  case (t1', t2') of
    _ | t1' == t2' -> return ()
    (VarTy x1, _) -> unifyTyVar x1 t2'
    (_, VarTy x2) -> unifyTyVar x2 t1'
    (ListTy u1, ListTy u2) -> unifyType u1 u2
    (TupleTy us1, TupleTy us2)
      | length us1 == length us2 -> mapM_ (uncurry unifyType) (zip us1 us2)
    (FunTy arg1 ret1, FunTy arg2 ret2) -> do
      unifyType arg1 arg2
      unifyType ret1 ret2
    _ -> mismatch

-- | Solve type equations in list order, starting from the empty substitution,
-- and return the resulting substitution. A failure on an equation is wrapped
-- with that equation's hints (see 'wrapErrorFromHints').
solveEquations :: MonadError Error m => [(Type, Type, [Hint])] -> m Subst
solveEquations eqns =
  wrapError' "failed to solve type equations" $
    execStateT (mapM_ solveOne eqns) (Subst M.empty)
  where
    solveOne (t1, t2, hints) = wrapErrorFromHints hints (unifyType t1 t2)

-- | `substDefault` replaces all undetermined type variables with the given default type.
--
-- Every 'VarTy', at any depth, is replaced by @t0@ itself (@t0@ is not
-- traversed further). Other type constructors are rebuilt structurally, and
-- 'DataStructureTy' is kept as is.
substDefault :: Type -> Type -> Type
substDefault t0 = \case
  VarTy _ -> t0 -- use the caller's default, not a hard-coded unit type
  IntTy -> IntTy
  BoolTy -> BoolTy
  ListTy t -> ListTy (substDefault t0 t)
  TupleTy ts -> TupleTy (map (substDefault t0) ts)
  FunTy t ret -> FunTy (substDefault t0 t) (substDefault t0 ret)
  DataStructureTy ds -> DataStructureTy ds

-- | Apply @sigma@ to a type. If a default type is given, also replace any type
-- variables still unresolved with it (see 'substDefault').
subst' :: Maybe Type -> Subst -> Type -> Type
subst' t0 sigma t =
  let resolved = subst sigma t
   in case t0 of
        Nothing -> resolved
        Just def -> substDefault def resolved

-- | Fill in the tuple component types of a @Proj'@ node whose type-argument
-- list is empty. The types are recovered by typechecking the projected
-- expression in @env@. Every other expression is returned unchanged.
--
-- Only the node itself is inspected; recursion is left to the caller
-- (presumably 'mapSubExprM'). Throws an internal error if the projected
-- expression is not a tuple.
fixProj :: MonadError Error m => [(VarName, Type)] -> Expr -> m Expr
fixProj env expr = case expr of
  Proj' [] i e ->
    typecheckExpr env e >>= \case
      TupleTy ts -> pure (Proj' ts i e)
      t -> throwInternalError $ "type of argument of proj must be a tuple: " ++ formatType t
  _ -> pure expr

-- | Apply a solved substitution to every type in a program (see 'subst''),
-- then complete any @Proj@ nodes that still lack type arguments via 'fixProj'.
substProgram :: MonadError Error m => Maybe Type -> Subst -> Program -> m Program
substProgram t0 sigma prog =
  let typed = mapTypeProgram (subst' t0 sigma) prog
   in mapExprProgramM (mapSubExprM fixProj) typed

-- | Expression-level counterpart of 'substProgram'. It rewrites the types
-- inside @e@ with @subst' t0 sigma@ and then runs 'fixProj' over its
-- sub-expressions, starting from the environment @env@.
substExpr :: MonadError Error m => Maybe Type -> Subst -> [(VarName, Type)] -> Expr -> m Expr
substExpr t0 sigma env e =
  mapSubExprM fixProj env (mapTypeExpr (subst' t0 sigma) e)

-- | `run` does type inference.
--
-- * This assumes that the program has no name conflicts.
--
-- Before:
--
-- > let f = fun y -> y
-- > in let x = 1
-- > in f(x + x)
--
-- After:
--
-- > let f: int -> int = fun y: int -> y
-- > in let x: int = 1
-- > in f(x + x)
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.TypeInfer.run" $ do
  precondition (ensureAlphaConverted prog)
  -- collect the type equations, split off the per-variable assertions, and
  -- turn those assertions into further equations before solving everything
  eqns <- formularizeProgram prog
  let (eqns', assertions) = sortEquations eqns
  sigma <- solveEquations (eqns' ++ mergeAssertions assertions)
  -- 'UnitTy' is the default type handed to 'substDefault' through subst'
  -- (presumably used for type variables the equations left undetermined)
  typed <- substProgram (Just UnitTy) sigma prog
  postcondition $ do
    ensureAlphaConverted typed
    ensureWellTyped typed
  return typed

-- | Type inference for a single expression under the environment @env@.
--
-- The types in @env@ take part in the equations as assertions. No default
-- type is supplied ('Nothing'), so 'substDefault' is never applied. The
-- returned expression is checked with 'namecheckExpr' and 'typecheckExpr'
-- against the substituted environment.
runExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> m Expr
runExpr env e = wrapError' "Jikka.Core.Convert.TypeInfer.runExpr" $ do
  precondition (namecheckExpr env e)
  Dual eqns <- execWriterT (formularizeExpr e)
  let (eqns', assertions) = sortEquations eqns
      -- the caller's environment is treated as extra type assertions
      assertionEqns = mergeAssertions (env ++ assertions)
  sigma <- solveEquations (eqns' ++ assertionEqns)
  -- don't use substDefault
  let env' = map (second (subst' Nothing sigma)) env
  e' <- substExpr Nothing sigma env' e
  postcondition $ do
    namecheckExpr env' e'
    typecheckExpr env' e'
  return e'

-- | Type inference for a rewrite rule @e1 => e2@ with parameters @args@.
--
-- Both sides are constrained to have the same type. The declared types of
-- the parameters are merged with the assertions collected from the two sides,
-- matching what 'runExpr' does with its environment. This means the returned
-- parameter types agree with how the parameters are used in the rule.
-- 'substDefault' is not applied, so type variables may remain in the result.
--
-- Returns the substituted parameters and both substituted sides.
runRule :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> Expr -> m ([(VarName, Type)], Expr, Expr)
runRule args e1 e2 = wrapError' "Jikka.Core.Convert.TypeInfer.runRule" $ do
  precondition $ do
    -- Underscores are allowed for names, so we don't use namecheckExpr here.
    return ()
  eqns <- fmap getDual . execWriterT $ do
    t <- formularizeExpr e1
    formularizeExpr' e2 t
  let (eqns', assertions) = sortEquations eqns
      -- Previously only `assertions` were merged, so the declared parameter
      -- types never entered the equation system and were left unconstrained.
      assertionEqns = mergeAssertions (args ++ assertions)
  sigma <- solveEquations (eqns' ++ assertionEqns)
  -- don't use substDefault: type variables can remain in rules
  let sub = subst sigma
      args' = map (second sub) args
      e1' = mapTypeExpr sub e1
      e2' = mapTypeExpr sub e2
  postcondition $ do
    -- Underscores are allowed for names, so we don't use namecheckExpr here.
    -- Type variables can remain, so we don't use typecheckExpr here.
    return ()
  return (args', e1', e2')