{-# LANGUAGE OverloadedStrings #-}
module Futhark.Data.Compare
( compareValues,
compareSeveralValues,
Tolerance (..),
Mismatch,
)
where
import Data.List (intersperse)
import qualified Data.Text as T
import qualified Data.Vector.Storable as SVec
import Futhark.Data
-- | A description of how an obtained value differs from the expected
-- one.  The 'Show' instance renders a human-readable explanation.
data Mismatch
  = -- | A single element differs.  Fields: the value number, the
    -- (multi-dimensional) index of the element within the value
    -- (empty when no index is to be reported), the text of the
    -- obtained element, and the text of the expected element.
    PrimValueMismatch Int [Int] T.Text T.Text
  | -- | The shapes differ.  Fields: the value number, the obtained
    -- shape, and the expected shape.
    ArrayShapeMismatch Int [Int] [Int]
  | -- | The element types differ.  Fields: the value number, the
    -- obtained type, and the expected type.
    TypeMismatch Int T.Text T.Text
  | -- | The number of values differs.  Fields: the obtained count and
    -- the expected count.
    ValueCountMismatch Int Int
-- | Render any 'Show'able value as 'T.Text' (that is, 'show' followed
-- by 'T.pack').
showText :: Show a => a -> T.Text
showText = T.pack . show
-- | Build the message @Value #\<i\>: expected \<what\>\<expected\>, got \<got\>@.
--
-- The arguments are, in order: the text identifying the value (and
-- possibly an index into it), a prefix describing what kind of thing
-- was expected (may be empty), the text of what was obtained, and the
-- text of what was expected.
explainMismatch :: T.Text -> T.Text -> T.Text -> T.Text -> T.Text
explainMismatch i what got expected =
  "Value #" <> i <> ": expected " <> what <> expected <> ", got " <> got
-- | Renders a 'Mismatch' as a human-readable explanation.
instance Show Mismatch where
  show = T.unpack . explain
    where
      -- An element mismatch with no index to report: only the value
      -- number is mentioned.
      explain (PrimValueMismatch vi [] got expected) =
        explainMismatch (showText vi) "" got expected
      -- An element mismatch inside an array: the index is appended to
      -- the value number as a comma-separated list in brackets.
      explain (PrimValueMismatch vi js got expected) =
        explainMismatch
          ( showText vi
              <> " index ["
              <> mconcat (intersperse "," (map showText js))
              <> "]"
          )
          ""
          got
          expected
      explain (ArrayShapeMismatch i got expected) =
        explainMismatch (showText i) "array of shape " (showText got) (showText expected)
      explain (TypeMismatch i got expected) =
        explainMismatch (showText i) "value of type " got expected
      explain (ValueCountMismatch got expected) =
        "Expected " <> showText expected <> " values, got " <> showText got
-- | The tolerance used when comparing floating-point values.  It is
-- converted to the element type of the values being compared before
-- use.
newtype Tolerance = Tolerance Double
  deriving (Eq, Ord, Show)
-- | Convert a 'Tolerance' to any floating-point type, going through
-- 'Rational'.
toleranceFloat :: RealFloat a => Tolerance -> a
toleranceFloat (Tolerance x) = fromRational (toRational x)
-- | Compare an obtained value (second argument) against an expected
-- value (third argument), returning every 'Mismatch' found; an empty
-- list means the values are considered equal.  Mismatches are
-- reported as belonging to value number 0.
compareValues :: Tolerance -> Value -> Value -> [Mismatch]
compareValues tol = compareValue tol 0
-- | As 'compareValues', but for lists of values compared pairwise.
-- Values are numbered from 0 in the reported mismatches.  If the two
-- lists differ in length, the result is a single 'ValueCountMismatch'
-- and no element-wise comparison is performed.
compareSeveralValues :: Tolerance -> [Value] -> [Value] -> [Mismatch]
compareSeveralValues tol got expected
  | numGot /= numExpected = [ValueCountMismatch numGot numExpected]
  | otherwise = concat (zipWith3 (compareValue tol) [0 ..] got expected)
  where
    numGot = length got
    numExpected = length expected
-- | Convert a flat (row-major) index into a multi-dimensional index,
-- given the shape of the array.  The result has one component per
-- dimension of the shape.
unflattenIndex :: [Int] -> Int -> [Int]
unflattenIndex = go . drop 1 . sliceSizes
  where
    -- For a shape [d0, d1, ..., dk] this is
    -- [d0*d1*...*dk, d1*...*dk, ..., dk, 1]; after dropping the head,
    -- the j'th element is the number of elements spanned by one step
    -- along dimension j.
    sliceSizes = scanr (*) 1

    go [] _ = []
    go (size : sizes) flat =
      let (q, r) = flat `quotRem` size
       in q : go sizes r
-- | Compare an obtained value against an expected value, reporting
-- mismatches under value number @i@.
--
-- * If the shapes differ, the result is a single 'ArrayShapeMismatch'.
--
-- * If the shapes agree but the constructors differ, the result is a
--   single 'TypeMismatch'.
--
-- * Otherwise the elements are compared pairwise in flat order, and
--   one 'PrimValueMismatch' is produced per differing element.
--   Integer and boolean elements are compared with '=='.
--   Floating-point elements are compared with an absolute tolerance
--   obtained by applying 'tolerance' to the converted 'Tolerance' and
--   the vector of /expected/ elements.
compareValue :: Tolerance -> Int -> Value -> Value -> [Mismatch]
compareValue tol i got_v expected_v
  | valueShape got_v == valueShape expected_v =
      case (got_v, expected_v) of
        (I8Value _ got_vs, I8Value _ expected_vs) ->
          compareNum got_vs expected_vs
        (I16Value _ got_vs, I16Value _ expected_vs) ->
          compareNum got_vs expected_vs
        (I32Value _ got_vs, I32Value _ expected_vs) ->
          compareNum got_vs expected_vs
        (I64Value _ got_vs, I64Value _ expected_vs) ->
          compareNum got_vs expected_vs
        (U8Value _ got_vs, U8Value _ expected_vs) ->
          compareNum got_vs expected_vs
        (U16Value _ got_vs, U16Value _ expected_vs) ->
          compareNum got_vs expected_vs
        (U32Value _ got_vs, U32Value _ expected_vs) ->
          compareNum got_vs expected_vs
        (U64Value _ got_vs, U64Value _ expected_vs) ->
          compareNum got_vs expected_vs
        (F16Value _ got_vs, F16Value _ expected_vs) ->
          compareFloat (tolerance (toleranceFloat tol) expected_vs) got_vs expected_vs
        (F32Value _ got_vs, F32Value _ expected_vs) ->
          compareFloat (tolerance (toleranceFloat tol) expected_vs) got_vs expected_vs
        (F64Value _ got_vs, F64Value _ expected_vs) ->
          compareFloat (tolerance (toleranceFloat tol) expected_vs) got_vs expected_vs
        -- Booleans use the same exact-equality comparison as the
        -- integer types.
        (BoolValue _ got_vs, BoolValue _ expected_vs) ->
          compareNum got_vs expected_vs
        _ ->
          [ TypeMismatch
              i
              (primTypeText $ valueElemType got_v)
              (primTypeText $ valueElemType expected_v)
          ]
  | otherwise =
      [ArrayShapeMismatch i (valueShape got_v) (valueShape expected_v)]
  where
    -- Turn a flat element position into a multi-dimensional index.
    -- The shapes are known to be equal wherever this is used, so
    -- either shape would do.
    unflatten = unflattenIndex (valueShape got_v)

    {-# INLINE compareGen #-}
    {-# INLINE compareNum #-}
    {-# INLINE compareFloat #-}
    {-# INLINE compareFloatElement #-}
    {-# INLINE compareElement #-}

    -- Element-wise comparison by exact equality.
    compareNum :: (SVec.Storable a, Eq a, Show a) => SVec.Vector a -> SVec.Vector a -> [Mismatch]
    compareNum = compareGen compareElement

    -- Element-wise comparison within the given absolute tolerance.
    compareFloat :: (SVec.Storable a, RealFloat a, Show a) => a -> SVec.Vector a -> SVec.Vector a -> [Mismatch]
    compareFloat = compareGen . compareFloatElement

    -- Apply @cmp@ to every position of @got@ paired with the same
    -- position of @expected@, collecting the results in increasing
    -- index order.  Only the length of @got@ is consulted; the caller
    -- has already checked that the shapes agree.
    compareGen ::
      (SVec.Storable a, SVec.Storable b) =>
      (Int -> a -> b -> Maybe m) ->
      SVec.Vector a ->
      SVec.Vector b ->
      [m]
    compareGen cmp got expected =
      let l = SVec.length got
          -- Mismatches are accumulated in reverse and flipped at the
          -- end.
          check acc j
            | j < l =
                case cmp j (got SVec.! j) (expected SVec.! j) of
                  Just mismatch -> check (mismatch : acc) (j + 1)
                  Nothing -> check acc (j + 1)
            | otherwise = acc
       in reverse $ check [] 0

    -- Exact comparison of the elements at flat position @j@.
    compareElement :: (Show a, Eq a) => Int -> a -> a -> Maybe Mismatch
    compareElement j got expected
      | got == expected = Nothing
      | otherwise =
          Just $ PrimValueMismatch i (unflatten j) (showText got) (showText expected)

    -- Approximate comparison of the elements at flat position @j@:
    -- two NaNs match, two infinities of the same sign match, and
    -- otherwise the absolute difference must not exceed @abstol@.
    compareFloatElement :: (Show a, RealFloat a) => a -> Int -> a -> a -> Maybe Mismatch
    compareFloatElement abstol j got expected
      | isNaN got, isNaN expected = Nothing
      | isInfinite got, isInfinite expected, signum got == signum expected = Nothing
      | abs (got - expected) <= abstol = Nothing
      | otherwise =
          Just $ PrimValueMismatch i (unflatten j) (showText got) (showText expected)
-- | Compute the absolute tolerance to use when comparing against the
-- given vector: the largest of @tol@ itself and @tol * v@ for every
-- element @v@ of the vector that is neither NaN nor infinite.
--
-- NOTE(review): @v@ is not passed through 'abs', so negative elements
-- never raise the tolerance above @tol@.  This looks unintended, but
-- it is left unchanged here since fixing it would alter comparison
-- results - confirm the intended behaviour.
tolerance :: (RealFloat a, SVec.Storable a) => a -> Vector a -> a
tolerance tol = SVec.foldl' widen tol . SVec.filter (not . nanOrInf)
  where
    -- Strict left fold step: keep the larger of the running tolerance
    -- and the tolerance scaled by this element.
    widen t v = max t (tol * v)
    nanOrInf x = isInfinite x || isNaN x