{-# LANGUAGE OverloadedStrings #-}

-- | Facilities for comparing values for equality.  While 'Eq'
-- instances are defined, these are not useful when NaNs are involved,
-- and do not *explain* the differences.
module Futhark.Data.Compare
  ( compareValues,
    compareSeveralValues,
    Tolerance (..),
    Mismatch,
  )
where

import Data.List (intersperse)
import qualified Data.Text as T
import qualified Data.Vector.Storable as SVec
import Futhark.Data

-- | A description of one way in which two values differ.  The 'Show'
-- instance produces a human-readable explanation.
data Mismatch
  = -- | A single element differs.  The fields are: the value number,
    -- the multi-dimensional index of the element within the array
    -- (empty when there is no index to report), the textual form of
    -- the element that was obtained, and the textual form of the
    -- element that was expected.
    PrimValueMismatch Int [Int] T.Text T.Text
  -- | The shapes differ.  The fields are: the value number, the
  -- obtained shape, and the expected shape.
  | ArrayShapeMismatch Int [Int] [Int]
  -- | The element types differ.  The fields are: the value number,
  -- the obtained type name, and the expected type name.
  | TypeMismatch Int T.Text T.Text
  -- | The number of values differs.  The fields are: the obtained
  -- count and the expected count.
  | ValueCountMismatch Int Int

-- | Render any 'Show'able value as 'T.Text'.
showText :: Show a => a -> T.Text
showText = T.pack . show

-- | A human-readable description of how two values are not the same.
--
-- The arguments are, in order: the text identifying the offending
-- value (placed right after @"Value #"@), a phrase saying what kind of
-- thing was expected (may be empty; it is placed directly in front of
-- the expected text, so it should end in a space), the text of what
-- was obtained, and the text of what was expected.
explainMismatch :: T.Text -> T.Text -> T.Text -> T.Text -> T.Text
explainMismatch i what got expected =
  "Value #" <> i <> ": expected " <> what <> expected <> ", got " <> got

-- | Renders each kind of mismatch as an English sentence.  All cases
-- except 'ValueCountMismatch' go through 'explainMismatch'.
instance Show Mismatch where
  -- No index to report: only the value number identifies the element.
  show (PrimValueMismatch vi [] got expected) =
    T.unpack $ explainMismatch (showText vi) "" got expected
  -- Otherwise the comma-separated index is appended to the value number.
  show (PrimValueMismatch vi js got expected) =
    T.unpack $ explainMismatch location "" got expected
    where
      indexText = mconcat (intersperse "," (map showText js))
      location = showText vi <> " index [" <> indexText <> "]"
  show (ArrayShapeMismatch i got expected) =
    T.unpack $
      explainMismatch (showText i) "array of shape " (showText got) (showText expected)
  show (TypeMismatch i got expected) =
    T.unpack $ explainMismatch (showText i) "value of type " got expected
  show (ValueCountMismatch got expected) =
    T.unpack $
      "Expected " <> showText expected <> " values, got " <> showText got

-- | The maximum relative tolerance used for comparing floating-point
-- results.  0.002 (0.2%) is a fine default if you have no particular
-- opinion.
newtype Tolerance = Tolerance Double
  deriving (Eq, Ord, Show)

-- | Unwrap a 'Tolerance' and convert it to whichever floating-point
-- type is being compared.
toleranceFloat :: RealFloat a => Tolerance -> a
toleranceFloat (Tolerance x) = realToFrac x

-- | Compare two Futhark values for equality.  The first value is the
-- one obtained and the second the one expected.  Any mismatches are
-- reported with value number 0; an empty list means the values are
-- considered equal.
compareValues :: Tolerance -> Value -> Value -> [Mismatch]
compareValues tol = compareValue tol 0

-- | As 'compareValues', but compares several values pairwise.  The
-- first list holds the obtained values and the second the expected
-- ones; values are numbered from 0 in the reported mismatches.
--
-- If the two lists differ in length, the only thing reported is a
-- single mismatch stating the two counts; no values are compared.
compareSeveralValues :: Tolerance -> [Value] -> [Value] -> [Mismatch]
compareSeveralValues tol got expected
  | n /= m = [ValueCountMismatch n m]
  | otherwise = concat $ zipWith3 (compareValue tol) [0 ..] got expected
  where
    n = length got
    m = length expected

-- | Convert a flat, row-major element offset into a multi-dimensional
-- index.  The first argument is the array shape and the second the
-- flat offset; the result has one component per dimension.  An empty
-- shape yields an empty index.
unflattenIndex :: [Int] -> Int -> [Int]
unflattenIndex shape = go sliceSizes
  where
    -- For each dimension, the number of elements covered by one step
    -- along it, i.e. the product of all later dimension sizes.
    sliceSizes = drop 1 (scanr (*) 1 shape)

    go [] _ = []
    go (size : sizes) offset =
      let (q, r) = offset `quotRem` size
       in q : go sizes r

-- | Compare one obtained value against one expected value.  The 'Int'
-- is the value number recorded in every reported 'Mismatch'.
--
-- If the shapes differ, a single 'ArrayShapeMismatch' is reported.  If
-- the shapes agree but the element types differ, a single
-- 'TypeMismatch' is reported.  Otherwise the elements are compared
-- one by one, and every differing element is reported in order of
-- increasing flat index.  Integer and boolean elements must be equal.
-- Floating-point elements are accepted when both are NaN, when both
-- are infinities of the same sign, or when their absolute difference
-- is at most the bound that 'tolerance' derives from the expected
-- elements.
compareValue :: Tolerance -> Int -> Value -> Value -> [Mismatch]
compareValue tol i got_v expected_v
  | valueShape got_v == valueShape expected_v =
    case (got_v, expected_v) of
      (I8Value _ got_vs, I8Value _ expected_vs) ->
        compareExact got_vs expected_vs
      (I16Value _ got_vs, I16Value _ expected_vs) ->
        compareExact got_vs expected_vs
      (I32Value _ got_vs, I32Value _ expected_vs) ->
        compareExact got_vs expected_vs
      (I64Value _ got_vs, I64Value _ expected_vs) ->
        compareExact got_vs expected_vs
      (U8Value _ got_vs, U8Value _ expected_vs) ->
        compareExact got_vs expected_vs
      (U16Value _ got_vs, U16Value _ expected_vs) ->
        compareExact got_vs expected_vs
      (U32Value _ got_vs, U32Value _ expected_vs) ->
        compareExact got_vs expected_vs
      (U64Value _ got_vs, U64Value _ expected_vs) ->
        compareExact got_vs expected_vs
      (F16Value _ got_vs, F16Value _ expected_vs) ->
        compareFloat (tolerance (toleranceFloat tol) expected_vs) got_vs expected_vs
      (F32Value _ got_vs, F32Value _ expected_vs) ->
        compareFloat (tolerance (toleranceFloat tol) expected_vs) got_vs expected_vs
      (F64Value _ got_vs, F64Value _ expected_vs) ->
        compareFloat (tolerance (toleranceFloat tol) expected_vs) got_vs expected_vs
      (BoolValue _ got_vs, BoolValue _ expected_vs) ->
        compareExact got_vs expected_vs
      _ ->
        [ TypeMismatch
            i
            (primTypeText $ valueElemType got_v)
            (primTypeText $ valueElemType expected_v)
        ]
  | otherwise =
    [ArrayShapeMismatch i (valueShape got_v) (valueShape expected_v)]
  where
    -- Turn a flat element offset into a multi-dimensional index.  Only
    -- used once the shapes are known to agree, so the shape of the
    -- obtained value serves for both.
    unflatten = unflattenIndex (valueShape got_v)

    -- The mismatch to report for the element at flat offset @j@.
    mismatchAt :: Show a => Int -> a -> a -> Mismatch
    mismatchAt j got expected =
      PrimValueMismatch i (unflatten j) (showText got) (showText expected)

    {-# INLINE compareGen #-}
    {-# INLINE compareExact #-}
    {-# INLINE compareFloat #-}
    {-# INLINE compareFloatElement #-}
    {-# INLINE compareElement #-}

    -- Element-wise comparison requiring exact equality.
    compareExact :: (SVec.Storable a, Eq a, Show a) => SVec.Vector a -> SVec.Vector a -> [Mismatch]
    compareExact = compareGen compareElement

    -- Element-wise comparison accepting differences of at most the
    -- given absolute bound.
    compareFloat :: (SVec.Storable a, RealFloat a, Show a) => a -> SVec.Vector a -> SVec.Vector a -> [Mismatch]
    compareFloat = compareGen . compareFloatElement

    -- Walk both vectors by index, collecting whatever the element
    -- comparison reports.  Only the length of the first vector is
    -- consulted; the second is indexed at the same positions.
    -- Results are accumulated in reverse and flipped at the end, so
    -- they come out in index order.
    compareGen cmp got expected =
      let l = SVec.length got
          check acc j
            | j < l =
              case cmp j (got SVec.! j) (expected SVec.! j) of
                Just mismatch ->
                  check (mismatch : acc) (j + 1)
                Nothing ->
                  check acc (j + 1)
            | otherwise =
              acc
       in reverse $ check [] 0

    compareElement :: (Show a, Eq a) => Int -> a -> a -> Maybe Mismatch
    compareElement j got expected
      | got == expected = Nothing
      | otherwise = Just $ mismatchAt j got expected

    compareFloatElement :: (Show a, RealFloat a) => a -> Int -> a -> a -> Maybe Mismatch
    compareFloatElement abstol j got expected
      -- NaN is not equal to itself, so it needs its own case.
      | isNaN got,
        isNaN expected =
        Nothing
      -- Subtracting same-signed infinities yields NaN, so the
      -- difference test below cannot accept them.
      | isInfinite got,
        isInfinite expected,
        signum got == signum expected =
        Nothing
      | abs (got - expected) <= abstol = Nothing
      | otherwise = Just $ mismatchAt j got expected

-- | Compute the absolute error bound to use when comparing against
-- the given vector of expected elements.  The first argument is the
-- relative tolerance.
--
-- The result is the largest of the relative tolerance itself and the
-- relative tolerance times the magnitude of each element.  NaNs and
-- infinities are ignored.  The magnitude is used so that negative
-- elements widen the bound just like positive ones do.
tolerance :: (RealFloat a, SVec.Storable a) => a -> Vector a -> a
tolerance tol = SVec.foldl' widen tol
  where
    widen bound v
      | isInfinite v || isNaN v = bound
      | otherwise = max bound (tol * abs v)