-- | Alpha-equality for schemes
{-# LANGUAGE FlexibleContexts #-}

module AST.Term.Scheme.AlphaEq
    ( alphaEq
    ) where

import           AST
import           AST.Class.Has (HasChild(..))
import           AST.Class.ZipMatch (zipMatch_)
import           AST.Recurse (wrapM, (#>>))
import           AST.Term.Scheme
import           AST.Unify
import           AST.Unify.New (newTerm)
import           AST.Unify.QuantifiedVar
import           AST.Unify.Term (UTerm(..), uBody)
import qualified Control.Lens as Lens
import           Control.Lens.Operators
import           Data.Constraint
import           Data.Maybe (fromMaybe)
import           Data.Proxy (Proxy(..))

import           Prelude.Compat

-- | Build a skolem unification variable for every quantified variable
-- in the given 'QVars'.
--
-- Each skolem is created with 'newVar' and carries the quantified
-- variable's own constraints combined (own constraints on the left of
-- '<>') with the constraints returned by 'scopeConstraints'.
makeQVarInstancesInScope ::
    Unify m typ =>
    Tree QVars typ -> m (Tree (QVarInstances (UVarOf m)) typ)
makeQVarInstancesInScope (QVars foralls) =
    QVarInstances <$> traverse skolemFor foralls
    where
        skolemFor ownConstraints =
            do
                inScope <- scopeConstraints
                newVar binding (USkolem (ownConstraints <> inScope))

-- | Convert a single node of a scheme body into a unification variable.
--
-- When the node is a quantified variable (per 'quantifiedVar') that has
-- an entry in @foralls@, that entry is returned as-is.
-- In every other case the node is wrapped with 'newTerm'.
schemeBodyToType ::
    (Unify m typ, HasChild varTypes typ, Ord (QVar typ)) =>
    Tree varTypes (QVarInstances (UVarOf m)) -> Tree typ (UVarOf m) -> m (Tree (UVarOf m) typ)
schemeBodyToType foralls x =
    maybe (newTerm x) pure (x ^? quantifiedVar >>= instanceOf)
    where
        -- Look the quantified variable up among the instances for this node type
        instanceOf qvar = foralls ^? getChild . _QVarInstances . Lens.ix qvar

-- | Turn a pure 'Scheme' into a unification-variable tree.
--
-- First the scheme's quantified variables are given skolem instances
-- (see 'makeQVarInstancesInScope'), then the scheme's body is processed
-- with 'wrapM', converting each node using 'schemeBodyToType' and those
-- instances.
schemeToRestrictedType ::
    forall m varTypes typ.
    ( Monad m
    , KTraversable varTypes
    , KNodesConstraint varTypes (Unify m)
    , HasScheme varTypes m typ
    ) =>
    Tree Pure (Scheme varTypes typ) -> m (Tree (UVarOf m) typ)
schemeToRestrictedType (Pure (Scheme vars typ)) =
    traverseK (Proxy @(Unify m) #> makeQVarInstancesInScope) vars
    >>= \skolems ->
        wrapM (Proxy @(HasScheme varTypes m) #>> schemeBodyToType skolems) typ

-- | Compare two unification variables for alpha-equality, given the
-- 'UTerm' state that was already looked up for each of them.
--
-- The arguments come in pairs: a variable followed by its looked-up
-- state, first for the left-hand side and then for the right-hand side.
--
-- Mismatches are reported through 'unifyError'. The equations are
-- order-sensitive: later equations act as fall-backs for combinations
-- that earlier ones did not match.
goUTerm ::
    forall m t.
    Unify m t =>
    Tree (UVarOf m) t -> Tree (UTerm (UVarOf m)) t ->
    Tree (UVarOf m) t -> Tree (UTerm (UVarOf m)) t ->
    m ()
-- Two skolems: record each one as instantiated to the other.
goUTerm xv USkolem{} yv USkolem{} =
    do
        bindVar binding xv (UInstantiated yv)
        bindVar binding yv (UInstantiated xv)
-- Both sides were already paired up: they must be paired with each other.
goUTerm xv (UInstantiated xt) yv (UInstantiated yt)
    | xv == yt && yv == xt = pure ()
    | otherwise = unifyError (SkolemEscape xv)
-- A skolem against an unbound variable: point the unbound one at the skolem.
goUTerm xv USkolem{} yv UUnbound{} = bindVar binding yv (UToVar xv)
goUTerm xv UUnbound{} yv USkolem{} = bindVar binding xv (UToVar yv)
-- Follow variable indirections on either side before comparing.
goUTerm _ (UToVar xv) yv yu =
    do
        xu <- lookupVar binding xv
        goUTerm xv xu yv yu
goUTerm xv xu _ (UToVar yv) =
    do
        yu <- lookupVar binding yv
        goUTerm xv xu yv yu
-- A skolem against anything not handled above is an error
-- (the skolem is passed first to 'SkolemUnified').
goUTerm xv USkolem{} yv _ = unifyError (SkolemUnified xv yv)
goUTerm xv _ yv USkolem{} = unifyError (SkolemUnified yv xv)
-- Two unbound variables have no state to compare. Without this equation
-- the fall-back below would call itself with identical arguments and
-- never terminate, so alias one variable to the other instead
-- (skipping the binding when they are already the same variable).
goUTerm xv UUnbound{} yv UUnbound{}
    | xv == yv = pure ()
    | otherwise = bindVar binding xv (UToVar yv)
-- One side unbound: continue using the other side's state for both.
goUTerm xv UUnbound{} yv yu = goUTerm xv yu yv yu -- Term created in structure mismatch
goUTerm xv xu yv UUnbound{} = goUTerm xv xu yv xu -- Term created in structure mismatch
-- Two terms: compare the bodies child-by-child with 'goUVar'; when
-- 'zipMatch_' yields 'Nothing', defer to 'structureMismatch'.
goUTerm _ (UTerm xt) _ (UTerm yt) =
    withDict (unifyRecursive (Proxy @m) (Proxy @t)) $
    zipMatch_ (Proxy @(Unify m) #> goUVar) (xt ^. uBody) (yt ^. uBody)
    & fromMaybe (structureMismatch (\x y -> x <$ goUVar x y) xt yt)
goUTerm _ _ _ _ = error "unexpected state at alpha-eq"

-- | Compare two unification variables for alpha-equality.
--
-- Looks up the current state of the left variable and then of the right
-- variable via 'lookupVar', and hands both variables together with
-- their states to 'goUTerm'.
goUVar ::
    Unify m t =>
    Tree (UVarOf m) t -> Tree (UVarOf m) t -> m ()
goUVar xv yv =
    lookupVar binding xv >>= \xState ->
    lookupVar binding yv >>= goUTerm xv xState yv

-- | Check two schemes for alpha-equality.
--
-- Both schemes are converted with 'schemeToRestrictedType' (the first
-- scheme first) and the results are compared with 'goUVar'.
-- Mismatches are reported by raising a 'unifyError'.
alphaEq ::
    ( KTraversable varTypes
    , KNodesConstraint varTypes (Unify m)
    , HasScheme varTypes m typ
    ) =>
    Tree Pure (Scheme varTypes typ) ->
    Tree Pure (Scheme varTypes typ) ->
    m ()
alphaEq s0 s1 =
    schemeToRestrictedType s0 >>= \restricted0 ->
    schemeToRestrictedType s1 >>= \restricted1 ->
    goUVar restricted0 restricted1