{-# LANGUAGE CPP
           , ScopedTypeVariables
           , GADTs
           , DataKinds
           , KindSignatures
           , GeneralizedNewtypeDeriving
           , TypeOperators
           , FlexibleContexts
           , FlexibleInstances
           , OverloadedStrings
           , PatternGuards
           , Rank2Types
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2016.05.28
-- |
-- Module      :  Language.Hakaru.Syntax.TypeCheck
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- Bidirectional type checking for our AST.
----------------------------------------------------------------
module Language.Hakaru.Syntax.TypeCheck
    (
    -- * The type checking monad
      TypeCheckError
    , TypeCheckMonad(), runTCM, unTCM
    , TypeCheckMode(..)
    -- * Type checking itself
    , inferable
    , mustCheck
    , TypedAST(..)
    , inferType
    , checkType
    ) where

import           Prelude hiding (id, (.))
import           Control.Category
import           Data.Proxy            (KProxy(..))
import           Data.Text             (pack, Text())
import qualified Data.IntMap           as IM
import qualified Data.Traversable      as T
import qualified Data.List.NonEmpty    as L
import qualified Data.Foldable         as F
import qualified Data.Sequence         as S
import qualified Data.Vector           as V
#if __GLASGOW_HASKELL__ < 710
import           Control.Applicative   (Applicative(..), (<$>))
import           Data.Monoid           (Monoid(..))
#endif
import qualified Language.Hakaru.Parser.AST as U

import Data.Number.Nat                (fromNat)
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.DataKind (Hakaru(..), HData', HBool)
import Language.Hakaru.Types.Sing
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.HClasses
    ( HEq, hEq_Sing, HOrd, hOrd_Sing, HSemiring, hSemiring_Sing
    , hRing_Sing, sing_HRing, hFractional_Sing, sing_HFractional
    , sing_NonNegative, hDiscrete_Sing
    , HIntegrable(..)
    , HRadical(..), HContinuous(..))
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.AST.Sing
    (sing_Literal, sing_MeasureOp)

----------------------------------------------------------------
----------------------------------------------------------------

-- | Those terms from which we can synthesize a unique type. We are
-- also allowed to check them, via the change-of-direction rule.
inferable :: U.AST -> Bool
inferable = not . mustCheck


-- | Those terms whose types must be checked analytically. We cannot
-- synthesize (unambiguous) types for these terms.
--
-- N.B., this function assumes we're in 'StrictMode'. If we're
-- actually in 'LaxMode' then a handful of AST nodes behave
-- differently: in particular, 'U.NaryOp_', 'U.Superpose', and
-- 'U.Case_'. In strict mode those cases can just infer one of their
-- arguments and then check the rest against the inferred type.
-- (For case-expressions, we must also check the scrutinee since
-- it's type cannot be unambiguously inferred from the patterns.)
-- Whereas in lax mode we must infer all arguments and then take
-- the lub of their types in order to know which coercions to
-- introduce.
mustCheck :: U.AST -> Bool
mustCheck e = caseVarSyn e (const False) go
    where
    go :: U.MetaTerm -> Bool
    go (U.Lam_ _  e2)     = mustCheck' e2

    -- In general, applications don't require checking; we infer
    -- the first applicand to get the type of the second and of the
    -- result, then we check the second and return the result type.
    -- Thus, applications will only yield \"must check\" errors if
    -- the function does; but that's the responsability of the
    -- function term, not of the application term it's embedded
    -- within.
    --
    -- However, do note that the above only applies to lambda-defined
    -- functions, not to all \"function-like\" things. In particular,
    -- data constructors require checking (see the note below).
    go (U.App_ _  _)      = False

    -- We follow Dunfield & Pientka and \Pi\Sigma in inferring or
    -- checking depending on what the body requires. This is as
    -- opposed to the TLDI'05 paper, which always infers @e2@ but
    -- will check or infer the @e1@ depending on whether it has a
    -- type annotation or not.
    go (U.Let_ _ e2)      = mustCheck' e2

    go (U.Ann_ _ _)       = False
    go (U.CoerceTo_ _ _)  = False
    go (U.UnsafeTo_ _ _)  = False

    -- In general (according to Dunfield & Pientka), we should be
    -- able to infer the result of a fully saturated primop by
    -- looking up it's type and then checking all the arguments.
    go (U.PrimOp_  _ _)   = False
    go (U.ArrayOp_ _ es)  = F.all mustCheck es

    -- In strict mode: if we can infer any of the arguments, then
    -- we can check all the rest at the same type.
    --
    -- BUG: in lax mode we must be able to infer all of them;
    -- otherwise we may not be able to take the lub of the types
    go (U.NaryOp_   _ es) = F.all mustCheck es
    go (U.Superpose_ pes) = F.all (mustCheck . snd) pes

    -- Our numeric literals aren't polymorphic, so we can infer
    -- them just fine. Or rather, according to our AST they aren't;
    -- in truth, they are in the surface language. Which is part
    -- of the reason for needing 'LaxMode'
    --
    -- TODO: correctly capture our surface-language semantics by
    -- always treating literals as if we're in 'LaxMode'.
    go (U.Literal_ _) = False

    -- I return true because most folks (neelk, Pfenning, Dunfield
    -- & Pientka) say all data constructors mustCheck. The main
    -- issue here is dealing with (polymorphic) sum types and phantom
    -- types, since these mean the term doesn't contain enough
    -- information for all the type indices. Even for record types,
    -- there's the additional issue of the term (perhaps) not giving
    -- enough information about the nominal type even if it does
    -- give enough info for the structural type.
    --
    -- Still, given those limitations, we should be able to infer
    -- a subset of data constructors which happen to avoid the
    -- problem areas. In particular, given that our surface syntax
    -- doesn't use the sum-of-products representation, we should
    -- be able to rely on symbol resolution to avoid the nominal
    -- typing issue. Thus, for non-empty arrays and non-phantom
    -- record types, we should be able to infer the whole type
    -- provided we can infer the various subterms.
    go U.Empty_          = True
    go (U.Pair_ e1 e2)   = mustCheck  e1 && mustCheck e2
    go (U.Array_ _ e1)   = mustCheck' e1
    go (U.Datum_ _)      = True

    -- TODO: everyone says this, but it seems to me that if we can
    -- infer any of the branches (and check the rest to agree) then
    -- we should be able to infer the whole thing... Or maybe the
    -- problem is that the change-of-direction rule might send us
    -- down the wrong path?
    go (U.Case_ _ _)     = True

    go (U.Dirac_  e1)        = mustCheck  e1
    go (U.MBind_  _   e2)    = mustCheck' e2
    go (U.Plate_  _   e2)    = mustCheck' e2
    go (U.Chain_  _   e2 e3) = mustCheck  e2 && mustCheck' e3
    go (U.MeasureOp_ _ _)    = False
    go (U.Integrate_  _ _ _) = False
    go (U.Summate_    _ _ _) = False
    go (U.Product_    _ _ _) = False
    go U.Reject_             = True
    go (U.Expect_ _ e2)      = mustCheck' e2
    go (U.Observe_  e1  _)   = mustCheck  e1

mustCheck'
    :: MetaABT U.SourceSpan U.Term '[ 'U.U ] 'U.U
    -> Bool
mustCheck' e = caseBind e $ \_ e' -> mustCheck e'

----------------------------------------------------------------
----------------------------------------------------------------

type Input = Maybe (V.Vector Text)

type Ctx = VarSet ('KProxy :: KProxy Hakaru)

data TypeCheckMode = StrictMode | LaxMode | UnsafeMode
    deriving (Read, Show)

type TypeCheckError = Text

newtype TypeCheckMonad a =
    TCM { unTCM :: Ctx
                -> Input
                -> TypeCheckMode
                -> Either TypeCheckError a }

runTCM :: TypeCheckMonad a -> Input -> TypeCheckMode -> Either TypeCheckError a
runTCM m = unTCM m emptyVarSet

instance Functor TypeCheckMonad where
    fmap f m = TCM $ \ctx input mode -> fmap f (unTCM m ctx input mode)

instance Applicative TypeCheckMonad where
    pure x    = TCM $ \_ _ _ -> Right x
    mf <*> mx = mf >>= \f -> fmap f mx

-- TODO: ensure this instance has the appropriate strictness
instance Monad TypeCheckMonad where
    return   = pure
    mx >>= k =
        TCM $ \ctx input mode ->
        unTCM mx ctx input mode >>= \x ->
        unTCM (k x) ctx input mode

{-
-- We could provide this instance, but there's no decent error
-- message to give for the 'empty' case that works in all circumstances.
-- Because we only would need this to define 'inferOneCheckOthers',
-- we inline the definition there instead.
instance Alternative TypeCheckMonad where
    empty   = failwith "Alternative.empty"
    x <|> y = TCM $ \ctx mode ->
        case unTCM x ctx mode of
        Left  _ -> unTCM y ctx mode
        Right e -> Right e
-}

showT :: Show a => a -> Text
showT = pack . show

show1T :: Show1 a => a (i :: Hakaru) -> Text
show1T = pack . show1


-- | Return the mode in which we're checking\/inferring types.
getInput :: TypeCheckMonad Input
getInput = TCM $ \_ input _ -> Right input

-- | Return the mode in which we're checking\/inferring types.
getMode :: TypeCheckMonad TypeCheckMode
getMode = TCM $ \_ _ mode -> Right mode

-- | Extend the typing context, but only locally.
pushCtx
    :: Variable (a :: Hakaru)
    -> TypeCheckMonad b
    -> TypeCheckMonad b
pushCtx x (TCM m) = TCM (m . insertVarSet x)

getCtx :: TypeCheckMonad Ctx
getCtx = TCM $ \ctx _ _ -> Right ctx

failwith :: TypeCheckError -> TypeCheckMonad r
failwith e = TCM $ \_ _ _ -> Left e

failwith_ :: TypeCheckError -> TypeCheckMonad r
failwith_ = failwith

makeErrMsg
    :: Text
    -> Maybe U.SourceSpan
    -> Text
    -> TypeCheckMonad TypeCheckError
makeErrMsg header sourceSpan footer = do
  input_ <- getInput
  case (sourceSpan, input_) of
    (Just s, Just input) ->
          return $ mconcat [ header
                           , U.printSourceSpan s input
                           , footer
                           ]
    _                    ->
          return $ mconcat [ header, "\n", footer ]

-- | Fail with a type-mismatch error.
typeMismatch
    :: Maybe U.SourceSpan
    -> Either Text (Sing (a :: Hakaru))
    -> Either Text (Sing (b :: Hakaru))
    -> TypeCheckMonad r
typeMismatch s typ1 typ2 = failwith =<<
    makeErrMsg
     "Type Mismatch:\n\n"
     s
     (mconcat [ "expected "
              , msg1
              , ", found "
              , msg2
              ])
    where
    msg1 = case typ1 of { Left msg -> msg; Right typ -> show1T typ }
    msg2 = case typ2 of { Left msg -> msg; Right typ -> show1T typ }

missingInstance
    :: Text
    -> Sing (a :: Hakaru)
    -> Maybe U.SourceSpan
    -> TypeCheckMonad r
missingInstance clas typ s = failwith =<<
   makeErrMsg
    "Missing Instance: "
    s
    (mconcat $ ["No ", clas, " instance for type ", showT typ])

missingLub
    :: Sing (a :: Hakaru)
    -> Sing (b :: Hakaru)
    -> Maybe U.SourceSpan
    -> TypeCheckMonad r
missingLub typ1 typ2 s = failwith =<<
    makeErrMsg
     "Missing common type:\n\n"
     s
     (mconcat ["No lub of types ", showT typ1, " and ", showT typ2])

-- we can't have free variables, so it must be a typo
ambiguousFreeVariable
    :: Text
    -> Maybe U.SourceSpan
    -> TypeCheckMonad r
ambiguousFreeVariable x s = failwith =<<
    makeErrMsg
     (mconcat $ ["Name not in scope: ", x])
     s
     " perhaps it is a typo?"

ambiguousNullCoercion
    :: Maybe U.SourceSpan
    -> TypeCheckMonad r
ambiguousNullCoercion s = failwith =<<
    makeErrMsg
     "Cannot infer type for null-coercion over a checking term."
     s
     "Please add a type annotation to either the term being coerced or the result of the coercion."

ambiguousEmptyNary
    :: Maybe U.SourceSpan
    -> TypeCheckMonad r
ambiguousEmptyNary s = failwith =<<
    makeErrMsg
     "Cannot infer unambiguous type for empty n-ary operator."
     s
     "Try adding an annotation on the result of the operator."

ambiguousMustCheckNary
    :: Maybe U.SourceSpan
    -> TypeCheckMonad r
ambiguousMustCheckNary s = failwith =<<
    makeErrMsg
     "Could not infer any of the arguments."
     s
     "Try adding a type annotation to at least one of them."

ambiguousMustCheck
    :: Maybe U.SourceSpan
    -> TypeCheckMonad r
ambiguousMustCheck s = failwith =<<
    makeErrMsg
     "Cannot infer types for checking terms."
     s
     "Please add a type annotation."

argumentNumberError
     :: TypeCheckMonad r
argumentNumberError = failwith =<<
    makeErrMsg "Argument error:" Nothing "Passed wrong number of arguments"

----------------------------------------------------------------
----------------------------------------------------------------
-- BUG: haddock doesn't like annotations on GADT constructors. So
-- here we'll avoid using the GADT syntax, even though it'd make
-- the data type declaration prettier\/cleaner.
-- <https://github.com/hakaru-dev/hakaru/issues/6>
--
-- | The @e' ∈ τ@ portion of the inference judgement.
data TypedAST (abt :: [Hakaru] -> Hakaru -> *)
    = forall b. TypedAST !(Sing b) !(abt '[] b)

instance Show2 abt => Show (TypedAST abt) where
    showsPrec p (TypedAST typ e) =
        showParen_12 p "TypedAST" typ e


makeVar :: forall (a :: Hakaru). Variable 'U.U -> Sing a -> Variable a
makeVar (Variable hintID nameID _) typ =
    Variable hintID nameID typ


inferBinder
    :: (ABT Term abt)
    => Sing a
    -> MetaABT U.SourceSpan U.Term '[ 'U.U ] 'U.U
    -> (forall b. Sing b -> abt '[ a ] b -> TypeCheckMonad r)
    -> TypeCheckMonad r
inferBinder typ e k =
    caseBind e $ \x e1 -> do
    let x' = x {varType = typ}
    TypedAST typ1 e1' <- pushCtx x' (inferType e1)
    k typ1 (bind x' e1')

inferBinders
    :: (ABT Term abt)
    => List1 Variable xs
    -> U.AST
    -> (forall a. Sing a -> abt xs a -> TypeCheckMonad r)
    -> TypeCheckMonad r
inferBinders = \xs e k -> do
    TypedAST typ e' <- pushesCtx xs (inferType e)
    k typ (binds_ xs e')
    where
    -- TODO: make sure the 'TCM'\/'unTCM' stuff doesn't do stupid asymptotic things
    pushesCtx
        :: List1 Variable (xs :: [Hakaru])
        -> TypeCheckMonad b
        -> TypeCheckMonad b
    pushesCtx Nil1         m = m
    pushesCtx (Cons1 x xs) m = pushesCtx xs (TCM (unTCM m . insertVarSet x))


checkBinder
    :: (ABT Term abt)
    => Sing a
    -> Sing b
    -> MetaABT U.SourceSpan U.Term '[ 'U.U ] 'U.U
    -> TypeCheckMonad (abt '[ a ] b)
checkBinder typ eTyp e =
    caseBind e $ \x e1 -> do
    let x' = x {varType = typ}
    pushCtx x' (bind x' <$> checkType eTyp e1)


checkBinders
    :: (ABT Term abt)
    => List1 Variable xs
    -> Sing a
    -> U.AST
    -> TypeCheckMonad (abt xs a)
checkBinders xs eTyp e =
    case xs of
    Nil1        -> checkType eTyp e
    Cons1 x xs' -> pushCtx x (bind x <$> checkBinders xs' eTyp e)


----------------------------------------------------------------
-- | Given a typing environment and a term, synthesize the term's
-- type (and produce an elaborated term):
--
-- > Γ ⊢ e ⇒ e' ∈ τ
inferType
    :: forall abt
    .  (ABT Term abt)
    => U.AST
    -> TypeCheckMonad (TypedAST abt)
inferType = inferType_
  where
  -- HACK: we need to give these local definitions to avoid
  -- \"ambiguity\" in the choice of ABT instance...
  checkType_ :: forall b. Sing b -> U.AST -> TypeCheckMonad (abt '[] b)
  checkType_ = checkType

  inferOneCheckOthers_ ::
      [U.AST] -> TypeCheckMonad (TypedASTs abt)
  inferOneCheckOthers_ = inferOneCheckOthers


  inferVariable
      :: Maybe U.SourceSpan
      -> Variable 'U.U
      -> TypeCheckMonad (TypedAST abt)
  inferVariable sourceSpan (Variable hintID nameID _) = do
      ctx <- getCtx
      case IM.lookup (fromNat nameID) (unVarSet ctx) of
        Just (SomeVariable x') ->
            return $ TypedAST (varType x') (var x')
        Nothing -> ambiguousFreeVariable hintID sourceSpan


  -- HACK: We need this monomorphic binding so that GHC doesn't get
  -- confused about which @(ABT AST abt)@ instance to use in recursive
  -- calls.
  inferType_ :: U.AST -> TypeCheckMonad (TypedAST abt)
  inferType_ e0 =
    let s = getMetadata e0 in
    caseVarSyn e0 (inferVariable s) (go s)
    where
    go :: Maybe U.SourceSpan -> U.MetaTerm -> TypeCheckMonad (TypedAST abt)
    go sourceSpan t =
      case t of
       U.Lam_ (U.SSing typ) e -> do
           inferBinder typ e $ \typ2 e2 ->
               return . TypedAST (SFun typ typ2) $ syn (Lam_ :$ e2 :* End)

       U.App_ e1 e2 -> do
           TypedAST typ1 e1' <- inferType_ e1
           case typ1 of
               SFun typ2 typ3 -> do
                   e2' <- checkType_ typ2 e2
                   return . TypedAST typ3 $ syn (App_ :$ e1' :* e2' :* End)
               _ -> typeMismatch sourceSpan (Left "function type") (Right typ1)
           -- The above is the standard rule that everyone uses.
           -- However, if the @e1@ is a lambda (rather than a primop
           -- or a variable), then it will require a type annotation.
           -- Couldn't we just as well add an additional rule that
           -- says to infer @e2@ and then infer @e1@ under the assumption
           -- that the variable has the same type as the argument? (or
           -- generalize that idea to keep track of a bunch of arguments
           -- being passed in; sort of like a dual to our typing
           -- environments?) Is this at all related to what Dunfield
           -- & Neelk are doing in their ICFP'13 paper with that
           -- \"=>=>\" judgment? (prolly not, but...)

       U.Let_ e1 e2 -> do
           TypedAST typ1 e1' <- inferType_ e1
           inferBinder typ1 e2 $ \typ2 e2' ->
               return . TypedAST typ2 $ syn (Let_ :$ e1' :* e2' :* End)

       U.Ann_ (U.SSing typ1) e1 -> do
           -- N.B., this requires that @typ1@ is a 'Sing' not a 'Proxy',
           -- since we can't generate a 'Sing' from a 'Proxy'.
           TypedAST typ1 <$> checkType_ typ1 e1

       U.PrimOp_  op es -> inferPrimOp  op es
       U.ArrayOp_ op es -> inferArrayOp op es
       U.NaryOp_  op es -> do
           mode <- getMode
           TypedASTs typ es' <-
               case mode of
               StrictMode -> inferOneCheckOthers_ es
               LaxMode    -> inferLubType sourceSpan es
               UnsafeMode -> inferLubType sourceSpan es
           op' <- make_NaryOp typ op
           return . TypedAST typ $ syn (NaryOp_ op' $ S.fromList es')

       U.Literal_ (Some1 v) ->
           -- TODO: in truth, we can infer this to be any supertype
           -- (adjusting the concrete @v@ as necessary). That is, the
           -- surface language treats numeric literals as polymorphic,
           -- so we should capture that somehow--- even if we're not
           -- in 'LaxMode'. We'll prolly need to handle this
           -- subtype-polymorphism the same way as we do for for
           -- everything when in 'UnsafeMode'.
           return . TypedAST (sing_Literal v) $ syn (Literal_ v)

       -- TODO: we can try to do 'U.Case_' by using branch-based
       -- variants of 'inferOneCheckOthers' and 'inferLubType' depending
       -- on the mode; provided we can in fact infer the type of the
       -- scrutinee. N.B., if we add this case, then we need to update
       -- 'mustCheck' to return the right thing.

       U.CoerceTo_ (Some2 c) e1 ->
           case singCoerceDomCod c of
           Nothing
               | inferable e1 -> inferType_ e1
               | otherwise    -> ambiguousNullCoercion sourceSpan
           Just (dom,cod) -> do
               e1' <- checkType_ dom e1
               return . TypedAST cod $ syn (CoerceTo_ c :$ e1' :* End)

       U.UnsafeTo_ (Some2 c) e1 ->
           case singCoerceDomCod c of
           Nothing
               | inferable e1 -> inferType_ e1
               | otherwise    -> ambiguousNullCoercion sourceSpan
           Just (dom,cod) -> do
               e1' <- checkType_ cod e1
               return . TypedAST dom $ syn (UnsafeFrom_ c :$ e1' :* End)

       U.MeasureOp_ (U.SomeOp op) es -> do
           let (typs, typ1) = sing_MeasureOp op
           es' <- checkSArgs typs es
           return . TypedAST (SMeasure typ1) $ syn (MeasureOp_ op :$ es')

       U.Pair_ e1 e2 -> do
           TypedAST typ1 e1' <- inferType_ e1
           TypedAST typ2 e2' <- inferType_ e2
           return . TypedAST (sPair typ1 typ2) $
                  syn (Datum_ $ dPair_ typ1 typ2 e1' e2')

       U.Array_ e1 e2 -> do
           e1' <- checkType_ SNat e1
           inferBinder SNat e2 $ \typ2 e2' ->
               return . TypedAST (SArray typ2) $ syn (Array_ e1' e2')

       U.Case_ e1 branches -> do
           TypedAST typ1 e1' <- inferType_ e1
           mode <- getMode
           case mode of
               StrictMode -> inferCaseStrict typ1 e1' branches
               LaxMode    -> inferCaseLax sourceSpan typ1 e1' branches
               UnsafeMode -> inferCaseLax sourceSpan typ1 e1' branches

       U.Dirac_ e1 -> do
           TypedAST typ1 e1' <- inferType_ e1
           return . TypedAST (SMeasure typ1) $ syn (Dirac :$ e1' :* End)

       U.MBind_ e1 e2 ->
           caseBind e2 $ \x e2' -> do
           TypedAST typ1 e1' <- inferType_ e1
           case typ1 of
               SMeasure typ2 ->
                   let x' = makeVar x typ2 in
                   pushCtx x' $ do
                       TypedAST typ3 e3' <- inferType_ e2'
                       case typ3 of
                           SMeasure _ ->
                               return . TypedAST typ3 $
                                   syn (MBind :$ e1' :* bind x' e3' :* End)
                           _ -> typeMismatch sourceSpan (Left "HMeasure") (Right typ3)
                   {-
                   -- BUG: the \"ambiguous\" @abt@ issue again...
                   inferBinder typ2 e2 $ \typ3 e2' ->
                       case typ3 of
                       SMeasure _ -> return . TypedAST typ3 $
                           syn (MBind :$ e1' :* e2' :* End)
                       _ -> typeMismatch (Left "HMeasure") (Right typ3)
                   -}
               _ -> typeMismatch sourceSpan (Left "HMeasure") (Right typ1)

       U.Plate_ e1 e2 ->
           caseBind e2 $ \x e2' -> do
           e1' <- checkType_ SNat e1
           let x' = makeVar x SNat
           pushCtx x' $ do
               TypedAST typ2 e3' <- inferType_ e2'
               case typ2 of
                   SMeasure typ3 ->
                       return . TypedAST (SMeasure . SArray $ typ3) $
                              syn (Plate :$ e1' :* bind x' e3' :* End)
                   _ -> typeMismatch sourceSpan (Left "HMeasure") (Right typ2)

       U.Chain_ e1 e2 e3 ->
           caseBind e3 $ \x e3' -> do
           e1' <- checkType_ SNat e1
           TypedAST typ2 e2' <- inferType_ e2
           let x' = makeVar x typ2
           pushCtx x' $ do
               TypedAST typ3 e4' <- inferType_ e3'
               case typ3 of
                   SMeasure (SData (STyCon sym `STyApp` a `STyApp` b) _) ->
                       case (jmEq1 sym sSymbol_Pair, jmEq1 b typ2) of
                       (Just Refl, Just Refl) ->
                           return . TypedAST (SMeasure $ sPair (SArray a) typ2) $
                                  syn (Chain :$ e1' :* e2' :* bind x' e4' :* End)
                       _ -> typeMismatch sourceSpan (Left "HMeasure(HPair)") (Right typ3)
                   _     -> typeMismatch sourceSpan (Left "HMeasure(HPair)") (Right typ3)

       U.Integrate_ e1 e2 e3 -> do
           e1' <- checkType_ SReal e1
           e2' <- checkType_ SReal e2
           e3' <- checkBinder SReal SProb e3
           return . TypedAST SProb $
                  syn (Integrate :$ e1' :* e2' :* e3' :* End)

       U.Summate_ e1 e2 e3 -> do
           TypedAST typ1 e1' <- inferType e1
           e2' <- checkType_ typ1 e2
           inferBinder typ1 e3 $ \typ2 ee' ->
               case (hDiscrete_Sing typ1, hSemiring_Sing typ2) of
                 (Just h1, Just h2) ->
                     return . TypedAST typ2 $
                            syn (Summate h1 h2 :$ e1' :* e2' :* ee' :* End)
                 _                  -> failwith_ "Summate given bounds which are not discrete"

       U.Product_ e1 e2 e3 -> do
           TypedAST typ1 e1' <- inferType e1
           e2' <- checkType_ typ1 e2
           inferBinder typ1 e3 $ \typ2 e3' ->
               case (hDiscrete_Sing typ1, hSemiring_Sing typ2) of
                 (Just h1, Just h2) ->
                     return . TypedAST typ2 $
                            syn (Product h1 h2 :$ e1' :* e2' :* e3' :* End)
                 _                  -> failwith_ "Product given bounds which are not discrete"

       U.Expect_ e1 e2 -> do
           TypedAST typ1 e1' <- inferType_ e1
           case typ1 of
               SMeasure typ2 -> do
                   e2' <- checkBinder typ2 SProb e2
                   return . TypedAST SProb $ syn (Expect :$ e1' :* e2' :* End)
               _ -> typeMismatch sourceSpan (Left "HMeasure") (Right typ1)

       U.Observe_ e1 e2 -> do
           TypedAST typ1 e1' <- inferType_ e1
           case typ1 of
               SMeasure typ2 -> do
                   e2' <- checkType_ typ2 e2
                   return . TypedAST typ1 $ syn (Observe :$ e1' :* e2' :* End)
               _ -> typeMismatch sourceSpan (Left "HMeasure") (Right typ1)

       U.Superpose_ pes -> do
           -- TODO: clean up all this @map fst@, @map snd@, @zip@ stuff
           mode <- getMode
           TypedASTs typ es' <-
               case mode of
               StrictMode -> inferOneCheckOthers_    (L.toList $ fmap snd pes)
               LaxMode    -> inferLubType sourceSpan (L.toList $ fmap snd pes)
               UnsafeMode -> inferLubType sourceSpan (L.toList $ fmap snd pes)
           case typ of
               SMeasure _ -> do
                   ps' <- T.traverse (checkType SProb) (fmap fst pes)
                   return $ TypedAST typ (syn (Superpose_ (L.zip ps' (L.fromList es'))))
               _ -> typeMismatch sourceSpan (Left "HMeasure") (Right typ)

       _   | mustCheck e0 -> ambiguousMustCheck sourceSpan
           | otherwise    -> error "inferType: missing an inferable branch!"

  inferPrimOp
      :: U.PrimOp
      -> [U.AST]
      -> TypeCheckMonad (TypedAST abt)
  inferPrimOp U.Not es =
      case es of
        [e] -> do e' <- checkType_ sBool e
                  return . TypedAST sBool $ syn (PrimOp_ Not :$ e' :* End)
        _   -> argumentNumberError

  inferPrimOp U.Pi es =
      case es of
        [] -> return . TypedAST SProb $ syn (PrimOp_ Pi :$ End)
        _  -> argumentNumberError

  inferPrimOp U.Cos es =
      case es of
        [e] -> do e' <- checkType_ SReal e
                  return . TypedAST SReal $ syn (PrimOp_ Cos :$ e' :* End)
        _   -> argumentNumberError

  inferPrimOp U.RealPow es =
      case es of
        [e1, e2] -> do e1' <- checkType_ SProb e1
                       e2' <- checkType_ SReal e2
                       return . TypedAST SProb $
                              syn (PrimOp_ RealPow :$ e1' :* e2' :* End)
        _        -> argumentNumberError

  inferPrimOp U.Exp es =
      case es of
        [e] -> do e' <- checkType_ SReal e
                  return . TypedAST SProb $ syn (PrimOp_ Exp :$ e' :* End)
        _   -> argumentNumberError

  inferPrimOp U.Log es =
      case es of
        [e] -> do e' <- checkType_ SProb e
                  return . TypedAST SReal $ syn (PrimOp_ Log :$ e' :* End)
        _   -> argumentNumberError

  inferPrimOp U.Infinity es =
      case es of
        [] -> return . TypedAST SProb $
                     syn (PrimOp_ (Infinity HIntegrable_Prob) :$ End)
        _  -> argumentNumberError

  inferPrimOp U.GammaFunc es =
      case es of
        [e] -> do e' <- checkType_ SReal e
                  return . TypedAST SProb $ syn (PrimOp_ GammaFunc :$ e' :* End)
        _   -> argumentNumberError

  inferPrimOp U.BetaFunc es =
      case es of
        [e1, e2] -> do e1' <- checkType_ SProb e1
                       e2' <- checkType_ SProb e2
                       return . TypedAST SProb $
                              syn (PrimOp_ BetaFunc :$ e1' :* e2' :* End)
        _        -> argumentNumberError

  inferPrimOp U.Equal es =
      case es of
        [_, _] -> do mode <- getMode
                     TypedASTs typ [e1', e2'] <-
                         case mode of
                           StrictMode -> inferOneCheckOthers_ es
                           _          -> inferLubType Nothing es
                     primop <- Equal <$> getHEq typ
                     return . TypedAST sBool $
                            syn (PrimOp_ primop :$ e1' :* e2' :* End)
        _      -> argumentNumberError

  inferPrimOp U.Less es =
      case es of
        [_, _] -> do mode <- getMode
                     TypedASTs typ [e1', e2'] <-
                         case mode of
                           StrictMode -> inferOneCheckOthers_ es
                           _          -> inferLubType Nothing es
                     primop <- Less <$> getHOrd typ
                     return . TypedAST sBool $
                            syn (PrimOp_ primop :$ e1' :* e2' :* End)
        _      -> argumentNumberError

  inferPrimOp U.NatPow es =
      case es of
        [e1, e2] -> do TypedAST typ e1' <- inferType_ e1
                       e2' <- checkType_ SNat e2
                       primop <- NatPow <$> getHSemiring typ
                       return . TypedAST typ $
                              syn (PrimOp_ primop :$ e1' :* e2' :* End)
        _        -> argumentNumberError

  inferPrimOp U.Negate es =
      case es of
        [e] -> do TypedAST typ e' <- inferType_ e
                  mode <- getMode
                  SomeRing ring c <- getHRing typ mode
                  primop <- Negate <$> return ring
                  let e'' = case c of
                              CNil -> e'
                              c'   -> unLC_ . coerceTo c' $ LC_ e'
                  return . TypedAST (sing_HRing ring) $
                         syn (PrimOp_ primop :$ e'' :* End)
        _   -> argumentNumberError

  inferPrimOp U.Abs es =
      case es of
        [e] -> do TypedAST typ e' <- inferType_ e
                  mode <- getMode
                  SomeRing ring c <- getHRing typ mode
                  primop <- Abs <$> return ring
                  let e'' = case c of
                              CNil -> e'
                              c'   -> unLC_ . coerceTo c' $ LC_ e'
                  return . TypedAST (sing_NonNegative ring) $
                         syn (PrimOp_ primop :$ e'' :* End)
        _   -> argumentNumberError

  inferPrimOp U.Signum es =
      case es of
        [e] -> do TypedAST typ e' <- inferType_ e
                  mode <- getMode
                  SomeRing ring c <- getHRing typ mode
                  primop <- Signum <$> return ring
                  let e'' = case c of
                              CNil -> e'
                              c'   -> unLC_ . coerceTo c' $ LC_ e'
                  return . TypedAST (sing_HRing ring) $
                         syn (PrimOp_ primop :$ e'' :* End)
        _   -> argumentNumberError

  inferPrimOp U.Recip es =
      case es of
        [e] -> do TypedAST typ e' <- inferType_ e
                  mode <- getMode
                  SomeFractional frac c <- getHFractional typ mode
                  primop <- Recip <$> return frac
                  let e'' = case c of
                              CNil -> e'
                              c'   -> unLC_ . coerceTo c' $ LC_ e'
                  return . TypedAST (sing_HFractional frac) $
                         syn (PrimOp_ primop :$ e'' :* End)
        _   -> argumentNumberError

  -- BUG: Only defined for HRadical_Prob
  inferPrimOp U.NatRoot es =
      case es of
        [e1, e2] -> do e1' <- checkType_ SProb e1
                       e2' <- checkType_ SNat  e2
                       return . TypedAST SProb $
                              syn (PrimOp_ (NatRoot HRadical_Prob)
                                   :$ e1' :* e2' :* End)
        _   -> argumentNumberError

  -- BUG: Only defined for HContinuous_Real
  inferPrimOp U.Erf es =
      case es of
        [e] -> do e' <- checkType_ SReal e
                  return . TypedAST SReal $
                         syn (PrimOp_ (Erf HContinuous_Real)
                              :$ e' :* End)
        _   -> argumentNumberError

  inferPrimOp x es
      | Just y <- lookup x
                 [(U.Sin  , Sin  ),
                  (U.Cos  , Cos  ),
                  (U.Tan  , Tan  ),
                  (U.Asin , Asin ),
                  (U.Acos , Acos ),
                  (U.Atan , Atan ),
                  (U.Sinh , Sinh ),
                  (U.Cosh , Cosh ),
                  (U.Tanh , Tanh ),
                  (U.Asinh, Asinh),
                  (U.Acosh, Acosh),
                  (U.Atanh, Atanh)] =
      case es of
        [e] -> do e' <- checkType_ SReal e
                  return . TypedAST SReal $
                         syn (PrimOp_ y :$ e' :* End)
        _   -> argumentNumberError

  inferPrimOp x _ = error ("TODO: inferPrimOp: " ++ show x)


  inferArrayOp :: U.ArrayOp
               -> [U.AST]
               -> TypeCheckMonad (TypedAST abt)
  inferArrayOp U.Index_ es =
      case es of
        [e1, e2] -> do TypedAST typ1 e1' <- inferType_ e1
                       case typ1 of
                         SArray typ2 -> do
                           e2' <- checkType_ SNat e2
                           return . TypedAST typ2 $
                                  syn (ArrayOp_ (Index typ2) :$ e1' :* e2' :* End)
                         _ -> typeMismatch Nothing (Left "HArray") (Right typ1)
        _        -> argumentNumberError

  inferArrayOp U.Size es =
      case es of
        [e] -> do TypedAST typ e' <- inferType_ e
                  case typ of
                    SArray typ1 -> do
                       return . TypedAST SNat $
                              syn (ArrayOp_ (Size typ1) :$ e' :* End)
                    _ -> typeMismatch Nothing (Left "HArray") (Right typ)
        _   -> argumentNumberError

  inferArrayOp U.Reduce es =
      case es of
        [e1, e2, e3] -> do
           TypedAST typ e1' <- inferType_ e1
           case typ of
             SFun typ1 typ2 -> do
               Refl <- jmEq1_ typ2 (SFun typ1 typ1)
               e2' <- checkType_ typ1 e2
               e3' <- checkType_ (SArray typ1) e3
               return . TypedAST typ1 $
                      syn (ArrayOp_ (Reduce typ1)
                           :$ e1' :* e2' :* e3' :* End)
             _ -> typeMismatch Nothing (Right typ) (Left "HFun")
        _            -> argumentNumberError

make_NaryOp :: Sing a -> U.NaryOp -> TypeCheckMonad (NaryOp a)
make_NaryOp a U.And  = isBool a >>= \Refl -> return And
make_NaryOp a U.Or   = isBool a >>= \Refl -> return Or
make_NaryOp a U.Xor  = isBool a >>= \Refl -> return Xor
make_NaryOp a U.Iff  = isBool a >>= \Refl -> return Iff
make_NaryOp a U.Min  = Min  <$> getHOrd a
make_NaryOp a U.Max  = Max  <$> getHOrd a
make_NaryOp a U.Sum  = Sum  <$> getHSemiring a
make_NaryOp a U.Prod = Prod <$> getHSemiring a

isBool :: Sing a -> TypeCheckMonad (TypeEq a HBool)
isBool typ =
    case jmEq1 typ sBool of
    Just proof -> return proof
    Nothing    -> typeMismatch Nothing (Left "HBool") (Right typ)

jmEq1_ :: Sing (a :: Hakaru)
       -> Sing (b :: Hakaru)
       -> TypeCheckMonad (TypeEq a b)
jmEq1_ typA typB =
    case jmEq1 typA typB of
    Just proof -> return proof
    Nothing    -> typeMismatch Nothing (Right typA) (Right typB)


getHEq :: Sing a -> TypeCheckMonad (HEq a)
getHEq typ =
    case hEq_Sing typ of
    Just theEq -> return theEq
    Nothing    -> missingInstance "HEq" typ Nothing

getHOrd :: Sing a -> TypeCheckMonad (HOrd a)
getHOrd typ =
    case hOrd_Sing typ of
    Just theOrd -> return theOrd
    Nothing     -> missingInstance "HOrd" typ Nothing

getHSemiring :: Sing a -> TypeCheckMonad (HSemiring a)
getHSemiring typ =
    case hSemiring_Sing typ of
    Just theSemi -> return theSemi
    Nothing      -> missingInstance "HSemiring" typ Nothing

getHRing :: Sing a -> TypeCheckMode -> TypeCheckMonad (SomeRing a)
getHRing typ mode =
    case mode of
    StrictMode -> case hRing_Sing typ of
                    Just theRing -> return (SomeRing theRing CNil)
                    Nothing      -> missingInstance "HRing" typ Nothing
    LaxMode    -> case findRing typ of
                    Just proof   -> return proof
                    Nothing      -> missingInstance "HRing" typ Nothing
    UnsafeMode -> case findRing typ of
                    Just proof   -> return proof
                    Nothing      -> missingInstance "HRing" typ Nothing

getHFractional :: Sing a -> TypeCheckMode -> TypeCheckMonad (SomeFractional a)
getHFractional typ mode =
    case mode of
    StrictMode -> case hFractional_Sing typ of
                    Just theFrac -> return (SomeFractional theFrac CNil)
                    Nothing      -> missingInstance "HFractional" typ Nothing
    LaxMode    -> case findFractional typ of
                    Just proof   -> return proof
                    Nothing      -> missingInstance "HFractional" typ Nothing
    UnsafeMode -> case findFractional typ of
                    Just proof   -> return proof
                    Nothing      -> missingInstance "HFractional" typ Nothing



----------------------------------------------------------------
data TypedASTs (abt :: [Hakaru] -> Hakaru -> *)
    = forall b. TypedASTs !(Sing b) [abt '[] b]

{-
instance Show2 abt => Show (TypedASTs abt) where
    showsPrec p (TypedASTs typ es) =
        showParen_1x p "TypedASTs" typ es
-}


-- TODO: can we make this lazier in the second component of 'TypedASTs'
-- so that we can perform case analysis on the type component before
-- actually evaluating 'checkOthers'? Problem is, even though we
-- have the type to return we don't know whether the whole thing
-- will succeed or not until after calling 'checkOthers'... We could
-- handle this by changing the return type to @TypeCheckMonad (exists
-- b. (Sing b, TypeCheckMonad [abt '[] b]))@ thereby making the
-- staging explicit.
--
-- | Given a list of terms which must all have the same type, try
-- inferring each term in order until one of them succeeds and then
-- check all the others against that type. This is appropriate for
-- 'StrictMode' where we won't need to insert coercions; for
-- 'LaxMode', see 'inferLubType' instead.
inferOneCheckOthers
    :: forall abt
    .  (ABT Term abt)
    => [U.AST]
    -> TypeCheckMonad (TypedASTs abt)
inferOneCheckOthers = inferOne []
    where
    inferOne :: [U.AST] -> [U.AST] -> TypeCheckMonad (TypedASTs abt)
    inferOne ls []
        | null ls   = ambiguousEmptyNary Nothing
        | otherwise = ambiguousMustCheckNary Nothing
    inferOne ls (e:rs) = do
        m <- try $ inferType e
        case m of
            Nothing                -> inferOne (e:ls) rs
            Just (TypedAST typ e') -> do
                ls' <- checkOthers typ ls
                rs' <- checkOthers typ rs
                return (TypedASTs typ (reverse ls' ++ e' : rs'))

    checkOthers
        :: forall a. Sing a -> [U.AST] -> TypeCheckMonad [abt '[] a]
    checkOthers typ = T.traverse (checkType typ)


-- TODO: some day we may want a variant of this function which
-- returns the error message in the event that the computation fails
-- (e.g., to provide all of them if 'inferOneCheckOthers' ultimately
-- fails.
--
-- | a la @optional :: Alternative f => f a -> f (Maybe a)@ but
-- without needing the 'empty' of the 'Alternative' class.
try :: TypeCheckMonad a -> TypeCheckMonad (Maybe a)
try m = TCM $ \ctx input mode -> Right $
    case unTCM m ctx input mode of
    Left  _ -> Nothing -- Don't worry; no side effects to unwind
    Right e -> Just e

-- | Tries to typecheck in a given mode
tryWith :: TypeCheckMode -> TypeCheckMonad a -> TypeCheckMonad (Maybe a)
tryWith mode m = TCM $ \ctx input _ -> Right $
    case unTCM m ctx input mode of
    Left  _ -> Nothing
    Right e -> Just e

-- | Given a list of terms which must all have the same type, infer
-- all the terms in order and coerce them to the lub of all their
-- types. This is appropriate for 'LaxMode' where we need to insert
-- coercions; for 'StrictMode', see 'inferOneCheckOthers' instead.
inferLubType
    :: forall abt
    .  (ABT Term abt)
    => Maybe U.SourceSpan
    -> [U.AST]
    -> TypeCheckMonad (TypedASTs abt)
inferLubType s = start
    where
    start :: [U.AST] -> TypeCheckMonad (TypedASTs abt)
    start []     = ambiguousEmptyNary Nothing
    start (u:us) = do
        TypedAST  typ1 e1 <- inferType u
        TypedASTs typ2 es <- F.foldlM step (TypedASTs typ1 [e1]) us
        return (TypedASTs typ2 (reverse es))

    -- TODO: inline 'F.foldlM' and then inline this, to unpack the first argument.
    step :: TypedASTs abt -> U.AST -> TypeCheckMonad (TypedASTs abt)
    step (TypedASTs typ1 es) u = do
        TypedAST typ2 e2 <- inferType u
        case findLub typ1 typ2 of
            Nothing              -> missingLub typ1 typ2 s
            Just (Lub typ c1 c2) ->
                let es' = map (unLC_ . coerceTo c1 . LC_) es
                    e2' = unLC_ . coerceTo c2 $ LC_ e2
                in return (TypedASTs typ (e2' : es'))


inferCaseStrict
    :: forall abt a
    .  (ABT Term abt)
    => Sing a
    -> abt '[] a
    -> [U.Branch]
    -> TypeCheckMonad (TypedAST abt)
inferCaseStrict typA e1 = inferOne []
    where
    inferOne :: [U.Branch] -> [U.Branch] -> TypeCheckMonad (TypedAST abt)
    inferOne ls []
        | null ls   = ambiguousEmptyNary Nothing
        | otherwise = ambiguousMustCheckNary Nothing
    inferOne ls (b@(U.Branch_ pat e):rs) = do
        SP pat' vars <- checkPattern typA pat
        m <- try $ inferBinders vars e $ \typ e' -> do
                    ls' <- checkOthers typ ls
                    rs' <- checkOthers typ rs
                    return (TypedAST typ $ syn (Case_ e1 (reverse ls' ++ (Branch pat' e') : rs')))
        case m of
            Nothing -> inferOne (b:ls) rs
            Just m' -> return m'

    checkOthers
        :: forall b. Sing b -> [U.Branch] -> TypeCheckMonad [Branch a abt b]
    checkOthers typ = T.traverse (checkBranch typA typ)

data SomeBranch a abt = forall b. SomeBranch !(Sing b) [Branch a abt b]

-- TODO: find a better name, and move to where 'LC_' is defined.
lc :: (LC_ abt a -> LC_ abt b) -> abt '[] a -> abt '[] b
lc f = unLC_ . f . LC_

coerceTo_nonLC :: (ABT Term abt) => Coercion a b -> abt xs a -> abt xs b
coerceTo_nonLC = underBinders . lc . coerceTo

coerceFrom_nonLC :: (ABT Term abt) => Coercion a b -> abt xs b -> abt xs a
coerceFrom_nonLC = underBinders . lc . coerceFrom

-- BUG: how to make this not an orphan, without dealing with cyclic imports between AST.hs (for the 'LC_' instance), Datum.hs, and Coercion.hs?
instance (ABT Term abt) => Coerce (Branch a abt) where
    coerceTo   c (Branch pat e) = Branch pat (coerceTo_nonLC   c e)
    coerceFrom c (Branch pat e) = Branch pat (coerceFrom_nonLC c e)


inferCaseLax
    :: forall abt a
    .  (ABT Term abt)
    => Maybe U.SourceSpan
    -> Sing a
    -> abt '[] a
    -> [U.Branch]
    -> TypeCheckMonad (TypedAST abt)
inferCaseLax s typA e1 = start
    where
    start :: [U.Branch] -> TypeCheckMonad (TypedAST abt)
    start []     = ambiguousEmptyNary Nothing
    start ((U.Branch_ pat e):us) = do
        SP pat' vars <- checkPattern typA pat
        inferBinders vars e $ \typ1 e' -> do
            SomeBranch typ2 bs <- F.foldlM step (SomeBranch typ1 [Branch pat' e']) us
            return . TypedAST typ2 . syn . Case_ e1 $ reverse bs

    -- TODO: inline 'F.foldlM' and then inline this, to unpack the first argument.
    step :: SomeBranch a abt
        -> U.Branch
        -> TypeCheckMonad (SomeBranch a abt)
    step (SomeBranch typB bs) (U.Branch_ pat e) = do
        SP pat' vars <- checkPattern typA pat
        inferBinders vars e $ \typE e' ->
            case findLub typB typE of
            Nothing                     -> missingLub typB typE s
            Just (Lub typLub coeB coeE) ->
                return $ SomeBranch typLub
                    ( Branch pat' (coerceTo_nonLC coeE e')
                    : map (coerceTo coeB) bs
                    )

----------------------------------------------------------------
----------------------------------------------------------------

-- HACK: we must add the constraints that 'LCs' and 'UnLCs' are inverses.
-- TODO: how can we do that in general rather than needing to repeat
-- it here and in the various constructors of 'SCon'?
checkSArgs
    :: (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => List1 Sing typs
    -> [U.AST]
    -> TypeCheckMonad (SArgs abt args)
checkSArgs Nil1             []     = return End
checkSArgs (Cons1 typ typs) (e:es) =
    (:*) <$> checkType typ e <*> checkSArgs typs es
checkSArgs _ _ =
    error "checkSArgs: the number of types and terms doesn't match up"


-- | Given a typing environment, a type, and a term, verify that
-- the term satisfies the type (and produce an elaborated term):
--
-- > Γ ⊢ τ ∋ e ⇒ e'
checkType
    :: forall abt a
    .  (ABT Term abt)
    => Sing a
    -> U.AST
    -> TypeCheckMonad (abt '[] a)
checkType = checkType_
    where
    -- HACK: to convince GHC to stop being stupid about resolving
    -- the \"choice\" of @abt'@. I'm not sure why we don't need to
    -- use this same hack when 'inferType' calls 'checkType', but whatevs.
    inferType_ :: U.AST -> TypeCheckMonad (TypedAST abt)
    inferType_ = inferType

    checkVariable
        :: forall b
        .  Sing b
        -> Maybe U.SourceSpan
        -> Variable 'U.U
        -> TypeCheckMonad (abt '[] b)
    checkVariable typ0 sourceSpan x = do
      TypedAST typ' e0' <- inferType_ (var x)
      mode <- getMode
      case mode of
        StrictMode ->
            case jmEq1 typ0 typ' of
              Just Refl -> return e0'
              Nothing   -> typeMismatch sourceSpan (Right typ0) (Right typ')
        LaxMode    -> checkOrCoerce       sourceSpan e0' typ' typ0
        UnsafeMode -> checkOrUnsafeCoerce sourceSpan e0' typ' typ0


    checkType_
        :: forall b. Sing b -> U.AST -> TypeCheckMonad (abt '[] b)
    checkType_ typ0 e0 =
      let s = getMetadata e0 in
      caseVarSyn e0 (checkVariable typ0 s) (go s)
      where
      go sourceSpan t =
        case t of
        -- Change of direction rule suggests this doesn't need to be here
        -- We keep it here in case, we later use a U.Lam which doesn't
        -- carry the type of its variable 
        U.Lam_ (U.SSing typ) e1 ->
            case typ0 of
            SFun typ1 typ2 ->
                case jmEq1 typ1 typ of
                  Just Refl -> do e1' <- checkBinder typ1 typ2 e1
                                  return $ syn (Lam_ :$ e1' :* End)
                  Nothing   -> typeMismatch sourceSpan (Right typ1) (Right typ)
            _ -> typeMismatch sourceSpan (Right typ0) (Left "function type")

        U.Let_ e1 e2 -> do
            TypedAST typ1 e1' <- inferType_ e1
            e2' <- checkBinder typ1 typ0 e2
            return $ syn (Let_ :$ e1' :* e2' :* End)

        U.CoerceTo_ (Some2 c) e1 ->
            case singCoerceDomCod c of
            Nothing -> do
                e1' <- checkType_ typ0 e1
                return $ syn (CoerceTo_ CNil :$  e1' :* End)
            Just (dom, cod) ->
                case jmEq1 typ0 cod of
                Just Refl -> do
                    e1' <- checkType_ dom e1
                    return $ syn (CoerceTo_ c :$ e1' :* End)
                Nothing -> typeMismatch sourceSpan (Right typ0) (Right cod)

        U.UnsafeTo_ (Some2 c) e1 ->
            case singCoerceDomCod c of
            Nothing -> do
                e1' <- checkType_ typ0 e1
                return $ syn (UnsafeFrom_ CNil :$  e1' :* End)
            Just (dom, cod) ->
                case jmEq1 typ0 dom of
                Just Refl -> do
                    e1' <- checkType_ cod e1
                    return $ syn (UnsafeFrom_ c :$ e1' :* End)
                Nothing -> typeMismatch sourceSpan (Right typ0) (Right dom)

        -- TODO: Find better place to put this logic
        U.PrimOp_ U.Infinity [] -> do
            case typ0 of
              SNat  -> return $
                         syn (PrimOp_ (Infinity HIntegrable_Nat) :$ End)
              SInt  -> checkOrCoerce sourceSpan (syn (PrimOp_ (Infinity HIntegrable_Nat) :$ End))
                         SNat
                         SInt
              SProb -> return $
                         syn (PrimOp_ (Infinity HIntegrable_Prob) :$ End)
              SReal -> checkOrCoerce sourceSpan (syn (PrimOp_ (Infinity HIntegrable_Prob) :$ End))
                         SProb
                         SReal
              _     -> failwith =<<
                        makeErrMsg
                         "Type Mismatch:"
                         sourceSpan
                         "infinity can only be checked against nat or prob"

        U.NaryOp_ op es -> do
            mode <- getMode
            case mode of
              StrictMode -> safeNaryOp typ0
              LaxMode    -> safeNaryOp typ0
              UnsafeMode -> do
                es' <- tryWith LaxMode (safeNaryOp typ0)
                case es' of
                  Just es'' -> return es''
                  Nothing   -> do
                    TypedAST typ e0' <- inferType (syn $ U.NaryOp_ op es)
                    checkOrUnsafeCoerce sourceSpan e0' typ typ0
            where
            safeNaryOp :: forall c. Sing c -> TypeCheckMonad (abt '[] c)
            safeNaryOp typ = do
                op'  <- make_NaryOp typ op
                es'  <- T.forM es $ checkType_ typ
                return $ syn (NaryOp_ op' (S.fromList es'))

        U.Empty_ ->
            case typ0 of
            SArray _ -> return $ syn (Empty_ typ0)
            _        -> typeMismatch sourceSpan (Right typ0) (Left "HArray")

        U.Pair_ e1 e2 ->
            case typ0 of
            SData (STyCon sym `STyApp` a `STyApp` b) _ ->
                case jmEq1 sym sSymbol_Pair of
                Just Refl  -> do
                    e1' <- checkType_ a e1
                    e2' <- checkType_ b e2
                    return $ syn (Datum_ $ dPair_ a b e1' e2')
                Nothing    -> typeMismatch sourceSpan (Right typ0) (Left "HPair")
            _              -> typeMismatch sourceSpan (Right typ0) (Left "HPair")

        U.Array_ e1 e2 ->
            case typ0 of
            SArray typ1 -> do
                e1' <- checkType_  SNat e1
                e2' <- checkBinder SNat typ1 e2
                return $ syn (Array_ e1' e2')
            _ -> typeMismatch sourceSpan (Right typ0) (Left "HArray")

        U.Datum_ (U.Datum hint d) ->
            case typ0 of
            SData _ typ2 ->
                (syn . Datum_ . Datum hint typ0)
                    <$> checkDatumCode typ0 typ2 d
            _ -> typeMismatch sourceSpan (Right typ0) (Left "HData")

        U.Case_ e1 branches -> do
            TypedAST typ1 e1' <- inferType_ e1
            branches' <- T.forM branches $ checkBranch typ1 typ0
            return $ syn (Case_ e1' branches')

        U.Dirac_ e1 ->
            case typ0 of
            SMeasure typ1 -> do
                e1' <- checkType_ typ1 e1
                return $ syn (Dirac :$ e1' :* End)
            _ -> typeMismatch sourceSpan (Right typ0) (Left "HMeasure")

        U.MBind_ e1 e2 ->
            case typ0 of
            SMeasure _ -> do
                TypedAST typ1 e1' <- inferType_ e1
                case typ1 of
                    SMeasure typ2 -> do
                        e2' <- checkBinder typ2 typ0 e2
                        return $ syn (MBind :$ e1' :* e2' :* End)
                    _ -> typeMismatch sourceSpan (Right typ0) (Right typ1)
            _ -> typeMismatch sourceSpan (Right typ0) (Left "HMeasure")

        U.Plate_ e1 e2 ->
            case typ0 of
            SMeasure typ1 -> do
                e1' <- checkType_ SNat e1
                case typ1 of
                    SArray typ2 -> do
                        e2' <- checkBinder SNat (SMeasure typ2) e2
                        return $ syn (Plate :$ e1' :* e2' :* End)
                    _ -> typeMismatch sourceSpan (Right typ1) (Left "HArray")
            _ -> typeMismatch sourceSpan (Right typ0) (Left "HMeasure")

        U.Chain_ e1 e2 e3 ->
            case typ0 of
            SMeasure (SData (STyCon sym `STyApp` (SArray a) `STyApp` s) _) ->
                case jmEq1 sym sSymbol_Pair of
                Just Refl -> do
                    e1' <- checkType_  SNat e1
                    e2' <- checkType_  s    e2
                    e3' <- checkBinder s    (SMeasure $ sPair a s) e3
                    return $ syn (Chain :$ e1' :* e2' :* e3' :* End)
                Nothing -> typeMismatch sourceSpan (Right typ0) (Left "HMeasure(HPair(HArray, s)")
            _           -> typeMismatch sourceSpan (Right typ0) (Left "HMeasure(HPair(HArray, s)")


        U.Expect_ e1 e2 ->
            case typ0 of
            SProb -> do
                TypedAST typ1 e1' <- inferType_ e1
                case typ1 of
                    SMeasure typ2 -> do
                        e2' <- checkBinder typ2 typ0 e2
                        return $ syn (Expect :$ e1' :* e2' :* End)
                    _ -> typeMismatch sourceSpan (Left "HMeasure") (Right typ1)
            _ -> typeMismatch sourceSpan (Right typ0) (Left "HProb")

        U.Observe_ e1 e2 ->
            case typ0 of
            SMeasure typ2 -> do
                e1' <- checkType_ typ0 e1
                e2' <- checkType_ typ2 e2
                return $ syn (Observe :$ e1' :* e2' :* End)
            _ -> typeMismatch sourceSpan (Right typ0) (Left "HMeasure")

        U.Superpose_ pes ->
            case typ0 of
            SMeasure _ ->
                fmap (syn . Superpose_) .
                    T.forM pes $ \(p,e) ->
                        (,) <$> checkType_ SProb p <*> checkType_ typ0 e
            _ -> typeMismatch sourceSpan (Right typ0) (Left "HMeasure")

        U.Reject_ ->
            case typ0 of
            SMeasure _ -> return $ syn (Reject_ typ0)
            _          -> typeMismatch sourceSpan (Right typ0) (Left "HMeasure")

        _   | inferable e0 -> do
                TypedAST typ' e0' <- inferType_ e0
                mode <- getMode
                case mode of
                  StrictMode ->
                    case jmEq1 typ0 typ' of
                    Just Refl -> return e0'
                    Nothing   -> typeMismatch sourceSpan (Right typ0) (Right typ')
                  LaxMode    -> checkOrCoerce       sourceSpan e0' typ' typ0
                  UnsafeMode -> checkOrUnsafeCoerce sourceSpan e0' typ' typ0
            | otherwise -> error "checkType: missing an mustCheck branch!"


    --------------------------------------------------------
    -- We make these local to 'checkType' for the same reason we have 'checkType_'
    -- TODO: can we combine these in with the 'checkBranch' functions somehow?
    checkDatumCode
        :: forall xss t
        .  Sing (HData' t)
        -> Sing xss
        -> U.DCode_
        -> TypeCheckMonad (DatumCode xss (abt '[]) (HData' t))
    checkDatumCode typA typ d =
        case d of
        U.Inr d2 ->
            case typ of
            SPlus _ typ2 -> Inr <$> checkDatumCode typA typ2 d2
            _            -> failwith_ "expected datum of `inr' type"
        U.Inl d1 ->
            case typ of
            SPlus typ1 _ -> Inl <$> checkDatumStruct typA typ1 d1
            _            -> failwith_ "expected datum of `inl' type"

    checkDatumStruct
        :: forall xs t
        .  Sing (HData' t)
        -> Sing xs
        -> U.DStruct_
        -> TypeCheckMonad (DatumStruct xs (abt '[]) (HData' t))
    checkDatumStruct typA typ d =
        case d of
        U.Et d1 d2 ->
            case typ of
            SEt typ1 typ2 -> Et
                <$> checkDatumFun    typA typ1 d1
                <*> checkDatumStruct typA typ2 d2
            _     -> failwith_ "expected datum of `et' type"
        U.Done ->
            case typ of
            SDone -> return Done
            _     -> failwith_ "expected datum of `done' type"

    checkDatumFun
        :: forall x t
        .  Sing (HData' t)
        -> Sing x
        -> U.DFun_
        -> TypeCheckMonad (DatumFun x (abt '[]) (HData' t))
    checkDatumFun typA typ d =
        case d of
        U.Ident e1 ->
            case typ of
            SIdent      -> Ident <$> checkType_ typA e1
            _           -> failwith_ "expected datum of `I' type"
        U.Konst e1 ->
            case typ of
            SKonst typ1 -> Konst <$> checkType_ typ1 e1
            _           -> failwith_ "expected datum of `K' type"


----------------------------------------------------------------
-- BUG: haddock doesn't like annotations on GADT constructors. So
-- here we'll avoid using the GADT syntax, even though it'd make
-- the data type declaration prettier\/cleaner.
-- <https://github.com/hakaru-dev/hakaru/issues/6>
data SomePattern (a :: Hakaru) =
    forall vars.
        SP  !(Pattern vars a)
            !(List1 Variable vars)

data SomePatternCode xss t =
    forall vars.
        SPC !(PDatumCode xss vars (HData' t))
            !(List1 Variable vars)

data SomePatternStruct xs t =
    forall vars.
        SPS !(PDatumStruct xs vars (HData' t))
            !(List1 Variable vars)

data SomePatternFun x t =
    forall vars.
        SPF !(PDatumFun x vars (HData' t))
            !(List1 Variable vars)

checkBranch
    :: (ABT Term abt)
    => Sing a
    -> Sing b
    -> U.Branch
    -> TypeCheckMonad (Branch a abt b)
checkBranch patTyp bodyTyp (U.Branch_ pat body) = do
    SP pat' vars <- checkPattern patTyp pat
    Branch pat' <$> checkBinders vars bodyTyp body

checkPattern
    :: Sing a
    -> U.Pattern
    -> TypeCheckMonad (SomePattern a)
checkPattern = \typA pat ->
    case pat of
    U.PVar x -> return $ SP PVar (Cons1 (makeVar (U.nameToVar x) typA) Nil1)
    U.PWild  -> return $ SP PWild Nil1
    U.PDatum hint pat1 ->
        case typA of
        SData _ typ1 -> do
            SPC pat1' xs <- checkPatternCode typA typ1 pat1
            return $ SP (PDatum hint pat1') xs
        _ -> typeMismatch Nothing (Right typA) (Left "HData")
    where
    checkPatternCode
        :: Sing (HData' t)
        -> Sing xss
        -> U.PCode
        -> TypeCheckMonad (SomePatternCode xss t)
    checkPatternCode typA typ pat =
        case pat of
        U.PInr pat2 ->
            case typ of
            SPlus _ typ2 -> do
                SPC pat2' xs <- checkPatternCode typA typ2 pat2
                return $ SPC (PInr pat2') xs
            _            -> failwith_ "expected pattern of `sum' type"
        U.PInl pat1 ->
            case typ of
            SPlus typ1 _ -> do
                SPS pat1' xs <- checkPatternStruct typA typ1 pat1
                return $ SPC (PInl pat1') xs
            _ -> failwith_ "expected pattern of `zero' type"

    checkPatternStruct
        :: Sing (HData' t)
        -> Sing xs
        -> U.PStruct
        -> TypeCheckMonad (SomePatternStruct xs t)
    checkPatternStruct  typA typ pat =
        case pat of
        U.PEt pat1 pat2 ->
            case typ of
            SEt typ1 typ2 -> do
                SPF pat1' xs <- checkPatternFun    typA typ1 pat1
                SPS pat2' ys <- checkPatternStruct typA typ2 pat2
                return $ SPS (PEt pat1' pat2') (append1 xs ys)
            _ -> failwith_ "expected pattern of `et' type"
        U.PDone ->
            case typ of
            SDone -> return $ SPS PDone Nil1
            _     -> failwith_ "expected pattern of `done' type"

    checkPatternFun
        :: Sing (HData' t)
        -> Sing x
        -> U.PFun
        -> TypeCheckMonad (SomePatternFun x t)
    checkPatternFun typA typ pat =
        case pat of
        U.PIdent pat1 ->
            case typ of
            SIdent -> do
                SP pat1' xs <- checkPattern typA pat1
                return $ SPF (PIdent pat1') xs
            _ -> failwith_ "expected pattern of `I' type"
        U.PKonst pat1 ->
            case typ of
            SKonst typ1 -> do
                SP pat1' xs <- checkPattern typ1 pat1
                return $ SPF (PKonst pat1') xs
            _ -> failwith_ "expected pattern of `K' type"

checkOrCoerce
    :: (ABT Term abt)
    => Maybe (U.SourceSpan)
    -> abt '[] a
    -> Sing a
    -> Sing b
    -> TypeCheckMonad (abt '[] b)
checkOrCoerce s e typA typB =
    case findCoercion typA typB of
    Just c  -> return . unLC_ . coerceTo c $ LC_ e
    Nothing -> typeMismatch s (Right typB) (Right typA)

checkOrUnsafeCoerce
    :: (ABT Term abt)
    => Maybe (U.SourceSpan)
    -> abt '[] a
    -> Sing a
    -> Sing b
    -> TypeCheckMonad (abt '[] b)
checkOrUnsafeCoerce s e typA typB =
    case findEitherCoercion typA typB of
    Just (Unsafe  c) ->
        return . unLC_ . coerceFrom c $ LC_ e
    Just (Safe    c) ->
        return . unLC_ . coerceTo c $ LC_ e
    Just (Mixed   (_, c1, c2)) ->
        return . unLC_ . coerceTo c2 . coerceFrom c1 $ LC_ e
    Nothing ->
        case (typA, typB) of
          -- mighty, mighty hack!
          (SMeasure typ1, SMeasure _) -> do
            let x = Variable (pack "") 0 U.SU
            e2' <- checkBinder typ1 typB (bind x $ syn $ U.Dirac_ (var x))
            return $ syn (MBind :$ e :* e2' :* End)
          (_ ,  _) -> typeMismatch s (Right typB) (Right typA)



----------------------------------------------------------------
----------------------------------------------------------- fin.