{-# LANGUAGE OverloadedStrings #-}

module Clash.Core.EqSolver where

import Data.Maybe (catMaybes, mapMaybe)

import Clash.Core.Name (Name(nameOcc))
import Clash.Core.Term
import Clash.Core.TyCon
import Clash.Core.Type
import Clash.Core.Var

-- | Outcome of trying to solve a single type equation.
data TypeEqSolution
  = Solution (TyVar, Type)
  -- ^ The equation binds a type variable to the given type (for example a
  -- numeric literal, or a type constant such as @Int@).
  | AbsurdSolution
  -- ^ The equation can never hold: it requires a negative natural, or it
  -- equates two distinct literals, type constants, or type constructors.
  | NoSolution
  -- ^ The given type wasn't an equation of a supported form, or it couldn't
  -- be solved.
  deriving (Eq, Show)

-- | Keep only the variable bindings carried by 'Solution's, preserving their
-- order. 'AbsurdSolution' and 'NoSolution' entries are dropped.
catSolutions :: [TypeEqSolution] -> [(TyVar, Type)]
catSolutions sols = mapMaybe fromSolution sols
 where
  fromSolution :: TypeEqSolution -> Maybe (TyVar, Type)
  fromSolution sol =
    case sol of
      Solution binding -> Just binding
      AbsurdSolution -> Nothing
      NoSolution -> Nothing

-- | Solve every given equation with both 'solveAdd' and 'solveEq', keeping
-- only the variable bindings that were found. Absurd outcomes and equations
-- without a solution contribute nothing.
--
-- Results follow the order of the input equations. For each equation, the
-- 'solveAdd' result comes before the 'solveEq' results.
solveNonAbsurds :: TyConMap -> [(Type, Type)] -> [(TyVar, Type)]
solveNonAbsurds tcm = concatMap solutionsFor
 where
  solutionsFor :: (Type, Type) -> [(TyVar, Type)]
  solutionsFor eq = catSolutions (solveAdd eq : solveEq tcm eq)

-- | Solve simple equalities such as:
--
--   * a ~ 3
--   * 3 ~ a
--   * a ~ Int
--   * Int ~ a
--   * SomeType a b ~ SomeType 3 5
--   * SomeType 3 5 ~ SomeType a b
--   * SomeType a 5 ~ SomeType 3 b
--
-- Both sides are first normalised with 'coreView'. Then:
--
--   * a type variable equated to a type constant or to a type-level literal
--     yields a 'Solution';
--   * two distinct type constants, two distinct literals, or applications of
--     two different type constructors yield 'AbsurdSolution';
--   * applications of the same type constructor are solved argument-wise;
--   * anything else, including equations where either side is a (stuck)
--     type family application, yields no results.
solveEq :: TyConMap -> (Type, Type) -> [TypeEqSolution]
solveEq tcm (left0, right0) =
  case (left, right) of
    (VarTy tyVar, ConstTy {}) ->
      -- a ~ Int
      [Solution (tyVar, right)]
    (ConstTy {}, VarTy tyVar) ->
      -- Int ~ a
      [Solution (tyVar, left)]
    (VarTy tyVar, LitTy {}) ->
      -- a ~ 3
      -- Type-level literals are 'LitTy', not 'ConstTy'. Without this case the
      -- documented @a ~ 3@ example was never solved.
      [Solution (tyVar, right)]
    (LitTy {}, VarTy tyVar) ->
      -- 3 ~ a
      [Solution (tyVar, left)]
    (ConstTy {}, ConstTy {}) ->
      -- Int /= Char
      if left /= right then [AbsurdSolution] else []
    (LitTy {}, LitTy {}) ->
      -- 3 /= 5
      if left /= right then [AbsurdSolution] else []
    _
      -- The 'coreView' calls below should have reduced all solvable type
      -- families. If we encounter one here, the type family is stuck and we
      -- shouldn't compare it to anything.
      | any (isTypeFamilyApplication tcm) [left, right] -> []
      | otherwise ->
          case (tyView left, tyView right) of
            (TyConApp leftNm leftTys, TyConApp rightNm rightTys)
              -- SomeType a b ~ SomeType 3 5 (or the other way around)
              | leftNm == rightNm ->
                  concatMap (solveEq tcm) (zip leftTys rightTys)
              | otherwise ->
                  [AbsurdSolution]
            _ ->
              []
 where
  left = coreView tcm left0
  right = coreView tcm right0

-- | Solve an equation of a shape understood by 'normalizeAdd': a numeric
-- literal equated to the sum of another numeric literal and a type variable.
--
-- Returns a 'Solution' that binds the variable to the difference of the two
-- literals. Returns 'AbsurdSolution' when that difference, or either literal,
-- would be negative. Returns 'NoSolution' when 'normalizeAdd' does not
-- recognise the equation, or when the remaining operand is not a type
-- variable.
solveAdd
  :: (Type, Type)
  -> TypeEqSolution
solveAdd ab =
  case normalizeAdd ab of
    Just (total, addend, VarTy tyVar)
      | total < 0 || addend < 0 || total < addend ->
          AbsurdSolution
      | otherwise ->
          Solution (tyVar, LitTy (NumTy (total - addend)))
    _ ->
      NoSolution

-- | Given the left and right side of an equation, normalize it such that
-- equations of the following forms:
--
--     * 5     ~ n + 2
--     * 5     ~ 2 + n
--     * n + 2 ~ 5
--     * 2 + n ~ 5
--
-- are returned as @(5, 2, n)@.
--
-- The sum must be an application of @GHC.TypeNats.+@ to exactly two
-- arguments, one of which is a numeric literal. The other operand is returned
-- unchanged and need not be a variable. Any other shape yields 'Nothing'.
normalizeAdd
  :: (Type, Type)
  -> Maybe (Integer, Integer, Type)
normalizeAdd (a, b) =
  case splitLit a b of
    Nothing ->
      Nothing
    Just (n, sumTy) ->
      case tyView sumTy of
        TyConApp opNm [opL, opR]
          | nameOcc opNm == "GHC.TypeNats.+"
          , Just (m, other) <- splitLit opL opR
          -> Just (n, m, other)
        _ ->
          Nothing
 where
  -- Pick out a numeric literal from a pair of types, returning it together
  -- with the other type. The second argument is checked first.
  splitLit :: Type -> Type -> Maybe (Integer, Type)
  splitLit x (LitTy (NumTy k)) = Just (k, x)
  splitLit (LitTy (NumTy k)) y = Just (k, y)
  splitLit _ _ = Nothing

-- | Tests whether a case alternative is unreachable because at least one of
-- the equality constraints bound by its pattern (as collected by 'altEqs')
-- is absurd according to 'isAbsurdEq'.
isAbsurdAlt
  :: TyConMap
  -> Alt
  -> Bool
isAbsurdAlt tcm = any (isAbsurdEq tcm) . altEqs tcm

-- | Determines whether an equation (as produced by 'altEqs' or 'typeEq') is
-- absurd. It is absurd if 'solveAdd' reports that it requires a negative
-- natural, or if 'solveEq' finds two types that are definitely not equal
-- asserted to be equal.
isAbsurdEq
  :: TyConMap
  -> (Type, Type)
  -> Bool
isAbsurdEq tcm (left0, right0) =
  solveAdd lr == AbsurdSolution || AbsurdSolution `elem` solveEq tcm lr
 where
  -- Normalise both sides with 'coreView' before handing them to the solvers.
  lr = (coreView tcm left0, coreView tcm right0)

-- | Get the constraint equations bound by an alternative's pattern. For every
-- term variable bound by the pattern whose type is an equality (see
-- 'typeEq'), this returns the two sides of that equality. The alternative's
-- body is ignored.
altEqs
  :: TyConMap
  -> Alt
  -> [(Type, Type)]
altEqs tcm (pat, _term) =
  catMaybes [typeEq tcm (varType tmVar) | tmVar <- tmVars]
 where
  (_tyVars, tmVars) = patIds pat

-- | If the type (after 'coreView') is an application of @GHC.Prim.~#@ to four
-- arguments, return its last two arguments, each normalised with 'coreView'.
-- Otherwise return 'Nothing'.
typeEq
  :: TyConMap
  -> Type
  -> Maybe (Type, Type)
typeEq tcm ty =
  case tyView (coreView tcm ty) of
    TyConApp tcNm [_, _, left, right]
      -- The first two arguments are ignored (presumably the kinds of both
      -- sides; NOTE(review): confirm against GHC's definition of ~#).
      | nameOcc tcNm == "GHC.Prim.~#" ->
          Just (coreView tcm left, coreView tcm right)
    _ ->
      Nothing