module LambdaCube.SystemFw_.TypeChecker
  ( reduceType

  , infer
  , inferKind
  ) where

import           Data.List                         (uncons)
import           LambdaCube.SystemFw_.Ast
import           LambdaCube.SystemFw_.Substitution

-- | Normalise a type by contracting every type-level beta redex.
--
-- 'LCBase' and type variables are returned unchanged; 'LCArr' and 'LCTTLam'
-- are normalised component-wise. For an application the head is normalised
-- first:
--
-- * if it is an 'LCTTLam', the normalised argument is substituted for
--   index @0@ in the lambda body with 'substituteTypeInType' and the result
--   is normalised again, since substitution can expose new redexes;
-- * if it is neutral (a type variable, or an application that is itself
--   stuck on a variable), there is nothing to contract and the application
--   is rebuilt from its normalised parts. This case is reachable for
--   well-kinded input because normalisation descends under 'LCTTLam', where
--   bound variables of arrow kind may be applied, e.g. the body
--   @LCTTApp (LCTVar 0) LCBase@ of an operator abstracting over a variable
--   of kind @* -> *@;
-- * anything else in head position can never be applied, so 'error' is
--   called; a kind-checked type should not get here.
reduceType :: LCType -> LCType
reduceType = go
  where
    go :: LCType -> LCType
    go LCBase = LCBase
    go e@(LCTVar _) = e
    go (LCArr a b) = go a `LCArr` go b
    go (LCTTLam k b) = LCTTLam k $ go b
    go (LCTTApp f a) =
      case go f of
        -- Beta redex: contract it, then keep normalising the result.
        LCTTLam _ b -> go $ substituteTypeInType (go a) 0 b
        -- Neutral heads: the application is already stuck, keep it.
        f'@(LCTVar _) -> LCTTApp f' (go a)
        f'@(LCTTApp _ _) -> LCTTApp f' (go a)
        -- A base type or an arrow in head position is ill-kinded.
        _ -> error "Did you really kind check this?"

-- | Compute the type of a closed term.
--
-- The worker threads a typing context: a list of the types of the enclosing
-- lambda binders, innermost first, so that a variable index @n@ selects the
-- @n@-th entry of the list.
--
-- Calls 'error' when a variable index has no entry in the context, when a
-- lambda annotation does not have kind 'LCStar' according to 'inferKind',
-- or when an application's function part does not have an arrow type whose
-- parameter type equals the type inferred for the argument.
infer :: LCTerm -> LCType
infer = typeOf []
  where
    typeOf :: [LCType] -> LCTerm -> LCType
    typeOf env term = case term of
      LCVar ix ->
        case uncons (drop ix env) of
          Just (ty, _) -> ty
          Nothing      -> error "Out-of-scope variable"
      LCLam annot body ->
        case inferKind annot of
          LCStar ->
            -- The annotation is stored and returned in reduced form, so
            -- the equality test in the application case compares reduced
            -- types.
            let paramTy = reduceType annot
            in  LCArr paramTy (typeOf (paramTy : env) body)
          _ -> error "Function argument kind mismatch"
      LCApp fun arg ->
        case typeOf env fun of
          LCArr paramTy resultTy
            | paramTy == typeOf env arg -> resultTy
          _ -> error "Function argument type mismatch"

-- | Compute the kind of a closed type.
--
-- The worker threads a kinding context: a list of the kinds of the
-- enclosing type-level lambda binders, innermost first, so that a type
-- variable index @n@ selects the @n@-th entry of the list.
--
-- Calls 'error' when a type variable index has no entry in the context,
-- when either side of an arrow does not have kind 'LCStar', or when an
-- application's operator does not have an arrow kind whose parameter kind
-- equals the kind inferred for the argument.
inferKind :: LCType -> LCKind
inferKind = kindOf []
  where
    kindOf :: [LCKind] -> LCType -> LCKind
    kindOf ctx ty = case ty of
      LCBase -> LCStar
      LCTVar ix ->
        case uncons (drop ix ctx) of
          Just (kind, _) -> kind
          Nothing        -> error "Out-of-scope variable"
      LCArr dom cod ->
        -- Components are matched left to right, so the codomain is only
        -- inspected once the domain is known to have kind 'LCStar'.
        case (kindOf ctx dom, kindOf ctx cod) of
          (LCStar, LCStar) -> LCStar
          _                -> error "Arrow kind mismatch"
      LCTTLam paramKind body ->
        LCKArr paramKind (kindOf (paramKind : ctx) body)
      LCTTApp fun arg ->
        case kindOf ctx fun of
          LCKArr paramKind resultKind
            | paramKind == kindOf ctx arg -> resultKind
          _ -> error "Function argument kind mismatch"