{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.Classes.RealFloat (
RealFloat(..),
) where
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Language ( cond, while )
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Data.Bits
import Data.Array.Accelerate.Classes.Eq
import Data.Array.Accelerate.Classes.Floating
import Data.Array.Accelerate.Classes.FromIntegral
import Data.Array.Accelerate.Classes.Num
import Data.Array.Accelerate.Classes.Ord
import Data.Array.Accelerate.Classes.RealFrac
import Text.Printf
import Prelude ( (.), ($), String, error, undefined, unlines, otherwise )
import qualified Prelude as P
class (RealFrac a, Floating a) => RealFloat a where
floatRadix :: Exp a -> Exp Int64
default floatRadix :: P.RealFloat a => Exp a -> Exp Int64
floatRadix _ = P.fromInteger (P.floatRadix (undefined::a))
floatDigits :: Exp a -> Exp Int
default floatDigits :: P.RealFloat a => Exp a -> Exp Int
floatDigits _ = constant (P.floatDigits (undefined::a))
floatRange :: Exp a -> (Exp Int, Exp Int)
default floatRange :: P.RealFloat a => Exp a -> (Exp Int, Exp Int)
floatRange _ = let (m,n) = P.floatRange (undefined::a)
in (constant m, constant n)
decodeFloat :: Exp a -> (Exp Int64, Exp Int)
encodeFloat :: Exp Int64 -> Exp Int -> Exp a
default encodeFloat :: (FromIntegral Int a, FromIntegral Int64 a) => Exp Int64 -> Exp Int -> Exp a
encodeFloat x e = fromIntegral x * (fromIntegral (floatRadix (undefined :: Exp a)) ** fromIntegral e)
exponent :: Exp a -> Exp Int
exponent x = let (m,n) = decodeFloat x
in cond (m == 0)
0
(n + floatDigits x)
significand :: Exp a -> Exp a
significand x = let (m,_) = decodeFloat x
in encodeFloat m (negate (floatDigits x))
scaleFloat :: Exp Int -> Exp a -> Exp a
scaleFloat k x =
cond (k == 0 || isFix) x
$ encodeFloat m (n + clamp b)
where
isFix = x == 0 || isNaN x || isInfinite x
(m,n) = decodeFloat x
(l,h) = floatRange x
d = floatDigits x
b = h - l + 4*d
clamp bd = max (-bd) (min bd k)
isNaN :: Exp a -> Exp Bool
isInfinite :: Exp a -> Exp Bool
isDenormalized :: Exp a -> Exp Bool
isNegativeZero :: Exp a -> Exp Bool
isIEEE :: Exp a -> Exp Bool
default isIEEE :: P.RealFloat a => Exp a -> Exp Bool
isIEEE _ = constant (P.isIEEE (undefined::a))
atan2 :: Exp a -> Exp a -> Exp a
instance RealFloat Half where
atan2 = mkAtan2
isNaN = mkIsNaN
isInfinite = mkIsInfinite
isDenormalized = ieee754 "isDenormalized" (ieee754_f16_is_denormalized . mkBitcast)
isNegativeZero = ieee754 "isNegativeZero" (ieee754_f16_is_negative_zero . mkBitcast)
decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f16_decode (mkBitcast x)
in (fromIntegral m, n))
instance RealFloat Float where
atan2 = mkAtan2
isNaN = mkIsNaN
isInfinite = mkIsInfinite
isDenormalized = ieee754 "isDenormalized" (ieee754_f32_is_denormalized . mkBitcast)
isNegativeZero = ieee754 "isNegativeZero" (ieee754_f32_is_negative_zero . mkBitcast)
decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f32_decode (mkBitcast x)
in (fromIntegral m, n))
instance RealFloat Double where
atan2 = mkAtan2
isNaN = mkIsNaN
isInfinite = mkIsInfinite
isDenormalized = ieee754 "isDenormalized" (ieee754_f64_is_denormalized . mkBitcast)
isNegativeZero = ieee754 "isNegativeZero" (ieee754_f64_is_negative_zero . mkBitcast)
decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f64_decode (mkBitcast x)
in (m, n))
instance RealFloat CFloat where
atan2 = mkAtan2
isNaN = mkIsNaN . mkBitcast @Float
isInfinite = mkIsInfinite . mkBitcast @Float
isDenormalized = ieee754 "isDenormalized" (ieee754_f32_is_denormalized . mkBitcast)
isNegativeZero = ieee754 "isNegativeZero" (ieee754_f32_is_negative_zero . mkBitcast)
decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f32_decode (mkBitcast x)
in (fromIntegral m, n))
encodeFloat x e = mkBitcast (encodeFloat @Float x e)
instance RealFloat CDouble where
atan2 = mkAtan2
isNaN = mkIsNaN . mkBitcast @Double
isInfinite = mkIsInfinite . mkBitcast @Double
isDenormalized = ieee754 "isDenormalized" (ieee754_f64_is_denormalized . mkBitcast)
isNegativeZero = ieee754 "isNegativeZero" (ieee754_f64_is_negative_zero . mkBitcast)
decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f64_decode (mkBitcast x)
in (m, n))
encodeFloat x e = mkBitcast (encodeFloat @Double x e)
instance RealFloat a => P.RealFloat (Exp a) where
floatRadix = preludeError "floatRadix"
floatDigits = preludeError "floatDigits"
floatRange = preludeError "floatRange"
decodeFloat = preludeError "decodeFloat"
encodeFloat = preludeError "encodeFloat"
isNaN = preludeError "isNaN"
isInfinite = preludeError "isInfinite"
isDenormalized = preludeError "isDenormalized"
isNegativeZero = preludeError "isNegativeZero"
isIEEE = preludeError "isIEEE"
preludeError :: String -> a
preludeError x
= error
$ unlines [ printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" x x
, ""
, "These Prelude.RealFloat instances are present only to fulfil superclass"
, "constraints for subsequent classes in the standard Haskell numeric hierarchy."
]
ieee754 :: forall a b. HasCallStack => P.RealFloat a => String -> (Exp a -> b) -> Exp a -> b
ieee754 name f x
| P.isIEEE (undefined::a) = f x
| otherwise = internalError (printf "%s: Not implemented for non-IEEE floating point" name)
ieee754_f64_is_denormalized :: Exp Word64 -> Exp Bool
ieee754_f64_is_denormalized x =
ieee754_f64_mantissa x == 0 &&
ieee754_f64_exponent x /= 0
ieee754_f32_is_denormalized :: Exp Word32 -> Exp Bool
ieee754_f32_is_denormalized x =
ieee754_f32_mantissa x == 0 &&
ieee754_f32_exponent x /= 0
ieee754_f16_is_denormalized :: Exp Word16 -> Exp Bool
ieee754_f16_is_denormalized x =
ieee754_f16_mantissa x == 0 &&
ieee754_f16_exponent x /= 0
ieee754_f64_is_negative_zero :: Exp Word64 -> Exp Bool
ieee754_f64_is_negative_zero x =
ieee754_f64_negative x &&
ieee754_f64_exponent x == 0 &&
ieee754_f64_mantissa x == 0
ieee754_f32_is_negative_zero :: Exp Word32 -> Exp Bool
ieee754_f32_is_negative_zero x =
ieee754_f32_negative x &&
ieee754_f32_exponent x == 0 &&
ieee754_f32_mantissa x == 0
ieee754_f16_is_negative_zero :: Exp Word16 -> Exp Bool
ieee754_f16_is_negative_zero x =
ieee754_f16_negative x &&
ieee754_f16_exponent x == 0 &&
ieee754_f16_mantissa x == 0
ieee754_f64_mantissa :: Exp Word64 -> Exp Word64
ieee754_f64_mantissa x = x .&. 0xFFFFFFFFFFFFF
ieee754_f64_exponent :: Exp Word64 -> Exp Word16
ieee754_f64_exponent x = fromIntegral (x `unsafeShiftR` 52) .&. 0x7FF
ieee754_f64_negative :: Exp Word64 -> Exp Bool
ieee754_f64_negative x = testBit x 63
ieee754_f32_mantissa :: Exp Word32 -> Exp Word32
ieee754_f32_mantissa x = x .&. 0x7FFFFF
ieee754_f32_exponent :: Exp Word32 -> Exp Word8
ieee754_f32_exponent x = fromIntegral (x `unsafeShiftR` 23)
ieee754_f32_negative :: Exp Word32 -> Exp Bool
ieee754_f32_negative x = testBit x 31
ieee754_f16_mantissa :: Exp Word16 -> Exp Word16
ieee754_f16_mantissa x = x .&. 0x3FF
ieee754_f16_exponent :: Exp Word16 -> Exp Word8
ieee754_f16_exponent x = fromIntegral (x `unsafeShiftR` 10) .&. 0x1F
ieee754_f16_negative :: Exp Word16 -> Exp Bool
ieee754_f16_negative x = testBit x 15
ieee754_f16_decode :: Exp Word16 -> Exp (Int16, Int)
ieee754_f16_decode i =
let
_HHIGHBIT = 0x0400
_HMSBIT = 0x8000
_HMINEXP = ((_HALF_MIN_EXP) - (_HALF_MANT_DIG) - 1)
_HALF_MANT_DIG = floatDigits (undefined::Exp Half)
(_HALF_MIN_EXP, _HALF_MAX_EXP) = floatRange (undefined::Exp Half)
high1 = fromIntegral i
high2 = high1 .&. (_HHIGHBIT - 1)
exp1 = ((fromIntegral high1 `unsafeShiftR` 10) .&. 0x1F) + _HMINEXP
exp2 = exp1 + 1
T2 high3 exp3
= cond (exp1 /= _HMINEXP)
(T2 (high2 .|. _HHIGHBIT) exp1)
(while (\(T2 h _) -> (h .&. _HHIGHBIT) /= 0 )
(\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1))
(T2 high2 exp2))
high4 = cond (fromIntegral i < (0 :: Exp Int16)) (-high3) high3
in
cond (high1 .&. complement _HMSBIT == 0)
(T2 0 0)
(T2 high4 exp3)
ieee754_f32_decode :: Exp Word32 -> Exp (Int32, Int)
ieee754_f32_decode i =
let
_FHIGHBIT = 0x00800000
_FMSBIT = 0x80000000
_FMINEXP = ((_FLT_MIN_EXP) - (_FLT_MANT_DIG) - 1)
_FLT_MANT_DIG = floatDigits (undefined::Exp Float)
(_FLT_MIN_EXP, _FLT_MAX_EXP) = floatRange (undefined::Exp Float)
high1 = fromIntegral i
high2 = high1 .&. (_FHIGHBIT - 1)
exp1 = ((fromIntegral high1 `unsafeShiftR` 23) .&. 0xFF) + _FMINEXP
exp2 = exp1 + 1
T2 high3 exp3
= cond (exp1 /= _FMINEXP)
(T2 (high2 .|. _FHIGHBIT) exp1)
(while (\(T2 h _) -> (h .&. _FHIGHBIT) /= 0 )
(\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1))
(T2 high2 exp2))
high4 = cond (fromIntegral i < (0 :: Exp Int32)) (-high3) high3
in
cond (high1 .&. complement _FMSBIT == 0)
(T2 0 0)
(T2 high4 exp3)
ieee754_f64_decode :: Exp Word64 -> Exp (Int64, Int)
ieee754_f64_decode i =
let T4 s h l e = ieee754_f64_decode2 i
in T2 (fromIntegral s * (fromIntegral h `unsafeShiftL` 32 .|. fromIntegral l)) e
ieee754_f64_decode2 :: Exp Word64 -> Exp (Int, Word32, Word32, Int)
ieee754_f64_decode2 i =
let
_DHIGHBIT = 0x00100000
_DMSBIT = 0x80000000
_DMINEXP = ((_DBL_MIN_EXP) - (_DBL_MANT_DIG) - 1)
_DBL_MANT_DIG = floatDigits (undefined::Exp Double)
(_DBL_MIN_EXP, _DBL_MAX_EXP) = floatRange (undefined::Exp Double)
low = fromIntegral i
high = fromIntegral (i `unsafeShiftR` 32)
iexp = (fromIntegral ((high `unsafeShiftR` 20) .&. 0x7FF) + _DMINEXP)
sign = cond (fromIntegral i < (0 :: Exp Int64)) (-1) 1
high2 = high .&. (_DHIGHBIT - 1)
iexp2 = iexp + 1
T3 hi lo ie
= cond (iexp2 /= _DMINEXP)
(T3 (high2 .|. _DHIGHBIT) low iexp)
(while (\(T3 h _ _) -> (h .&. _DHIGHBIT) /= 0)
(\(T3 h l e) ->
let h1 = h `unsafeShiftL` 1
h2 = cond ((l .&. _DMSBIT) /= 0) (h1+1) h1
in T3 h2 (l `unsafeShiftL` 1) (e-1))
(T3 high2 low iexp2))
in
cond (low == 0 && (high .&. (complement _DMSBIT)) == 0)
(T4 1 0 0 0)
(T4 sign hi lo ie)