{-# LANGUAGE DataKinds #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE PostfixOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE BangPatterns #-}
module Data.CReal.Internal
(
CReal(..)
, Cache(..)
, atPrecision
, crealPrecision
, plusInteger
, mulBounded
, (.*.)
, mulBoundedL
, (.*)
, (*.)
, recipBounded
, shiftL
, shiftR
, square
, squareBounded
, expBounded
, expPosNeg
, logBounded
, atanBounded
, sinBounded
, cosBounded
, crMemoize
, powerSeries
, alternateSign
, (/.)
, (/^)
, log2
, log10
, isqrt
, showAtPrecision
, decimalDigitsAtPrecision
, rationalToDecimal
) where
import Data.List (scanl')
import qualified Data.Bits as B
import Data.Bits hiding (shiftL, shiftR)
import GHC.Base (Int(..))
import GHC.Integer.Logarithms (integerLog2#, integerLogBase#)
import GHC.Real (Ratio(..), (%))
import GHC.TypeLits
import Text.Read
import qualified Text.Read.Lex as L
import System.Random (Random(..), RandomGen(..))
import Control.Concurrent.MVar
import Control.Exception
import System.IO.Unsafe (unsafePerformIO)
-- The ANN pragma below tells HLint to suppress its "Reduce duplication"
-- hint for this whole module.
{-# ANN module ("HLint: ignore Reduce duplication" :: String) #-}
-- An empty default declaration turns off numeric literal defaulting: every
-- numeric literal in this module must get its type from context, so none
-- silently becomes 'Integer' or 'Double'.
default ()
-- | The memoised approximation attached to a 'CReal'.
data Cache
  = Never
    -- ^ No approximation has been cached yet.
  | Current {-# UNPACK #-} !Int !Integer
    -- ^ @Current p v@: @v@ is the approximation that was computed at
    -- precision @p@. 'atPrecision' can reuse it for any precision @<= p@.
  deriving (Show)
-- | A real number tagged with a type-level precision @n@. Instances that
-- need a finite cut-off ('Show', 'Eq', 'Ord', ...) read @n@ through
-- 'crealPrecision'.
--
-- The fields are:
--
-- * a mutable cache holding the most recently stored approximation;
-- * the function that produces an approximation at a requested precision
--   @p@.
--
-- Judging from 'fromInteger' (@B.shiftL i p@) and 'showAtPrecision' (which
-- divides by @2^p@), that function returns roughly the value times @2^p@.
data CReal (n :: Nat) = CR {-# UNPACK #-} !(MVar Cache) (Int -> Integer)
-- | Wrap an approximation function in a fresh 'CReal' whose cache starts
-- out empty ('Never').
--
-- 'unsafePerformIO' is used only to allocate the cache 'MVar'. This assumes
-- @fn@ is pure, so the cache merely avoids recomputation (see 'atPrecision')
-- and cannot change any observable result.
crMemoize :: (Int -> Integer) -> CReal n
crMemoize fn = unsafePerformIO $ do
  mvc <- newMVar Never
  return $ CR mvc fn
-- | The type-level precision @n@ of a 'CReal', as an 'Int'. The argument is
-- only used as a proxy for its type.
crealPrecision :: KnownNat n => CReal n -> Int
crealPrecision x = fromInteger (natVal x)
-- | @x \`atPrecision\` p@ returns the approximation of @x@ at precision @p@.
--
-- The cache is consulted first. If it holds an approximation at some
-- precision @j >= p@, that value is scaled down to precision @p@ with '/^'
-- (round to nearest). Otherwise the approximation function is evaluated at
-- @p@ and the result replaces the cached entry.
--
-- 'modifyMVar' serialises access to the cache and restores the previous
-- contents if evaluation throws.
--
-- NOTE(review): if evaluating @f p@ demands this same 'CReal' again (a
-- cyclic definition), 'modifyMVar' blocks forever on the already-taken
-- 'MVar'.
atPrecision :: CReal n -> Int -> Integer
(CR mvc f) `atPrecision` (!p) = unsafePerformIO $ modifyMVar mvc $ \vc -> do
  vc' <- evaluate vc
  case vc' of
    Current j v | j >= p -> do
      pure (vc', v /^ (j - p))
    _ -> do
      -- Force the new approximation while the lock is held, so the cache
      -- never stores an unevaluated thunk.
      v <- evaluate $ f p
      let !vcn = Current p v
      pure (vcn, v)
{-# INLINABLE atPrecision #-}
-- | Renders the number in decimal via 'showAtPrecision', at the number's
-- own type-level precision.
instance KnownNat n => Show (CReal n) where
  show x = showAtPrecision prec x
    where
      prec = crealPrecision x
-- | Reads a numeric literal, optionally preceded by a minus sign.
--
-- The negated form is only accepted at precedence 6 or lower, matching
-- the precedence of prefix negation.
instance Read (CReal n) where
  readPrec = parens (lexP >>= fromLexeme)
    where
      fromLexeme (Number num) = pure (fromRational (L.numberToRational num))
      fromLexeme (Symbol "-") = prec 6 (lexP >>= negated)
      fromLexeme _ = pfail
      negated (Number num) = pure (fromRational (negate (L.numberToRational num)))
      negated _ = pfail
  {-# INLINE readPrec #-}
  readListPrec = readListPrecDefault
  {-# INLINE readListPrec #-}
  readsPrec = readPrec_to_S readPrec
  {-# INLINE readsPrec #-}
  readList = readPrec_to_S readListPrec 0
  {-# INLINE readList #-}
-- | Ring operations on 'CReal'.
--
-- Binary operations fetch their operands at a higher precision (extra
-- guard bits) and scale the result back with '/^'. 'negate' and 'abs' also
-- carry over any cached approximation of their argument.
instance Num (CReal n) where
  {-# INLINE fromInteger #-}
  -- The cache starts out holding the exact value at precision 0. The
  -- approximation at precision p is i * 2^p.
  fromInteger i = let
      !vc = Current 0 i
    in unsafePerformIO $ do
      mvc <- newMVar vc
      return $ CR mvc (B.shiftL i)
  {-# INLINE negate #-}
  -- 'tryReadMVar' does not block. If the argument's cache is currently
  -- taken, the new cache simply starts out empty.
  negate (CR mvc fn) = unsafePerformIO $ do
    vcc <- tryReadMVar mvc
    let
      !vcn = case vcc of
        Nothing -> Never
        Just Never -> Never
        Just (Current p v) -> Current p (negate v)
    mvn <- newMVar vcn
    return $ CR mvn (negate . fn)
  {-# INLINE abs #-}
  -- Uses the same cache-carrying scheme as 'negate'.
  abs (CR mvc fn) = unsafePerformIO $ do
    vcc <- tryReadMVar mvc
    let
      !vcn = case vcc of
        Nothing -> Never
        Just Never -> Never
        Just (Current p v) -> Current p (abs v)
    mvn <- newMVar vcn
    return $ CR mvn (abs . fn)
  {-# INLINE (+) #-}
  -- Operands are taken at precision p + 2; the sum is scaled back by 2 bits.
  x1 + x2 = crMemoize (\p ->
    let n1 = atPrecision x1 (p + 2)
        n2 = atPrecision x2 (p + 2)
    in (n1 + n2) /^ 2)
  {-# INLINE (-) #-}
  -- Operands are taken at precision p + 2; the difference is scaled back by
  -- 2 bits.
  x1 - x2 = crMemoize (\p ->
    let n1 = atPrecision x1 (p + 2)
        n2 = atPrecision x2 (p + 2)
    in (n1 - n2) /^ 2)
  {-# INLINE (*) #-}
  -- s1 and s2 are guard bits, computed once from the rough size of each
  -- operand (its approximation at precision 0). Each operand is fetched with
  -- the guard bits of the other one, and the product is scaled back to
  -- precision p.
  x1 * x2 = let
      s1 = log2 (abs (atPrecision x1 0) + 2) + 3
      s2 = log2 (abs (atPrecision x2 0) + 2) + 3
    in crMemoize (\p ->
      let n1 = atPrecision x1 (p + s2)
          n2 = atPrecision x2 (p + s1)
      in (n1 * n2) /^ (p + s1 + s2))
  -- Sign of the approximation at precision p, scaled to precision p.
  -- NOTE(review): values within about 2^-p of zero may report 0 here.
  signum x = crMemoize (\p -> B.shiftL (signum (x `atPrecision` p)) p)
-- | Division on 'CReal'.
instance Fractional (CReal n) where
  {-# INLINE fromRational #-}
  -- Scale the numerator by 2^p and divide, rounding with 'roundD'.
  fromRational (num :% den) = crMemoize $ \p -> roundD (B.shiftL num p) den
  {-# INLINE recip #-}
  -- s is computed once per argument with 'findFirstMonotonic', which
  -- searches on the predicate "approximation of x has magnitude >= 3".
  -- The numerator 2^(2p + 2s + 2) is then divided by x taken at precision
  -- p + 2s + 2.
  recip x = crMemoize $ \p ->
      let guardBits = 2 * s + 2
      in bit (2 * p + guardBits) /. atPrecision x (p + guardBits)
    where
      s = findFirstMonotonic ((3 <=) . abs . atPrecision x)
-- | Transcendental functions on 'CReal'.
--
-- Most functions reduce their argument to a small range and then call one
-- of the @*Bounded@ power-series helpers.
instance Floating (CReal n) where
  -- pi/4 comes from Machin's formula ('piBy4'); multiply it by 4.
  pi = piBy4 `shiftL` 2

  -- Range reduction. o = 2 * (x * recip (2 * ln 2)), so l is an integer
  -- approximation of x / ln 2, and exp x = 2^l * exp (x - l * ln 2).
  exp x =
    let o = shiftL (x *. recipBounded (shiftL ln2 1)) 1
        l = atPrecision o 0
        y = x - fromInteger l *. ln2
    in if l == 0
         then expBounded x
         else expBounded y `shiftL` fromInteger l

  -- Range reduction on ln (a * 2^l) = ln a + l * ln 2, where l comes from
  -- the bit length of v, the approximation of 4x (precision 2).
  --
  -- v is only passed to 'log2' when it is positive, because 'integerLog2#'
  -- requires a strictly positive argument. Assuming approximations are
  -- within 1 unit, a negative v implies x <= 0, where log is undefined.
  -- That case is reported as an error instead of recursing forever through
  -- @log (recip x)@.
  log x =
    let v = atPrecision x 2
        l = log2 v - 2
    in if
         | v < 0 -> error "Data.CReal.Internal log: argument is not positive"
         -- small x: ln x = -ln (1/x)
         | v == 0 || l < 0 -> - log (recip x)
         -- 4 <= v < 8: x is already close to the range of logBounded
         | l == 0 -> logBounded x
         -- large x: divide out 2^l
         | otherwise ->
             let a = x `shiftR` l
             in logBounded a + fromIntegral l *. ln2

  -- Floor square root of the approximation at twice the precision.
  sqrt x = crMemoize (\p ->
    let n = atPrecision x (2 * p)
    in isqrt n)

  -- Defined through 'log'. A negative base therefore fails in 'log'; a zero
  -- base may still diverge, because it needs @recip 0@.
  x ** y = exp (log x * y)

  logBase x y = log y / log x

  sin x = cos (x - piBy2)

  -- Range reduction to an octant. s is an integer approximation of
  -- 4x / pi, offset = x - s * pi/4, and s mod 8 (always in [0, 7], so the
  -- list index below is in range) selects which bounded series and shift
  -- to apply to offset.
  cos x =
    let o = shiftL (x *. recipBounded pi) 2
        s = atPrecision o 1 /^ 1
        octant = fromInteger $ s .&. 7
        offset = x - (fromIntegral s *. piBy4)
        fs = [ cosBounded
             , negate . sinBounded . subtract piBy4
             , negate . sinBounded
             , negate . cosBounded . (piBy4-)
             , negate . cosBounded
             , sinBounded . subtract piBy4
             , sinBounded
             , cosBounded . (piBy4-)]
    in (fs !! octant) offset

  tan x = sin x .* recip (cos x)

  -- asin x = 2 * atan (x / (1 + sqrt (1 - x^2)))
  asin x = (atan (x .*. recipBounded (1 + sqrt (1 - squareBounded x)))) `shiftL` 1

  acos x = piBy2 - asin x

  -- q approximates 4x at precision 2. The cases choose a reduction so that
  -- atanBounded only sees a small argument.
  atan x =
    let q = x `atPrecision` 2
    in if
         | q < -4 -> atanBounded (negate (recipBounded x)) - piBy2
         | q == -4 -> -(piBy4 + atanBounded ((x + 1) .*. recipBounded (x - 1)))
         | q == 4 -> piBy4 + atanBounded ((x - 1) .*. recipBounded (x + 1))
         | q > 4 -> piBy2 - atanBounded (recipBounded x)
         | otherwise -> atanBounded x

  sinh x =
    let (expX, expNegX) = expPosNeg x
    in (expX - expNegX) `shiftR` 1
  cosh x =
    let (expX, expNegX) = expPosNeg x
    in (expX + expNegX) `shiftR` 1
  tanh x =
    let e2x = exp (x `shiftL` 1)
    in (e2x - 1) *. recipBounded (e2x + 1)

  asinh x = log (x + sqrt (square x + 1))
  acosh x = log (x + sqrt (x + 1) * sqrt (x - 1))
  atanh x = (log (1 + x) - log (1 - x)) `shiftR` 1
-- | The exact rational equal to the approximation at the number's
-- type-level precision.
instance KnownNat n => Real (CReal n) where
  toRational x = atPrecision x prec % bit prec
    where
      prec = crealPrecision x
-- | Rounding operations. All of them work on the approximation at the
-- number's type-level precision ('crealPrecision').
instance KnownNat n => RealFrac (CReal n) where
  properFraction x =
    let (fl, frac) = floorParts x
        whole = towardZero fl frac
    in (fromInteger whole, plusInteger x (negate whole))

  truncate x =
    let (fl, frac) = floorParts x
    in fromInteger (towardZero fl frac)

  round x = fromInteger (atPrecision x prec /^ prec)
    where
      prec = crealPrecision x

  ceiling x =
    let (fl, frac) = floorParts x
    in fromInteger (if frac /= 0 then fl + 1 else fl)

  floor x = fromInteger (fst (floorParts x))

-- | Split the approximation of @x@ at its type-level precision @p@ into two
-- parts:
--
-- * the integer part, rounded towards negative infinity;
-- * the low @p@ bits, which are always non-negative.
floorParts :: KnownNat n => CReal n -> (Integer, Integer)
floorParts x = (unsafeShiftR (v - r) p, r)
  where
    p = crealPrecision x
    v = x `atPrecision` p
    r = v .&. (bit p - 1)

-- | Turn a floor and its fractional bits into the integer part rounded
-- towards zero.
towardZero :: Integer -> Integer -> Integer
towardZero fl frac = if fl < 0 && frac /= 0 then fl + 1 else fl
-- | 'RealFloat' view of a 'CReal' at its type-level precision.
--
-- 'floatDigits', 'floatRange', 'exponent' and 'significand' are not
-- supported and call 'error'. 'isNaN', 'isInfinite', 'isDenormalized',
-- 'isNegativeZero' and 'isIEEE' all return 'False'.
instance KnownNat n => RealFloat (CReal n) where
  floatRadix _ = 2
  floatDigits _ = error "Data.CReal.Internal floatDigits"
  floatRange _ = error "Data.CReal.Internal floatRange"
  -- The mantissa is the approximation at precision p; the exponent is -p.
  decodeFloat x =
    let p = crealPrecision x
    in (x `atPrecision` p, -p)
  -- m * 2^n, built exactly from a rational.
  encodeFloat m n = if n <= 0
    then fromRational (m % bit (negate n))
    else fromRational (unsafeShiftL m n :% 1)
  exponent = error "Data.CReal.Internal exponent"
  significand = error "Data.CReal.Internal significand"
  -- scaleFloat k x multiplies x by 2^k.
  scaleFloat = flip shiftL
  isNaN _ = False
  isInfinite _ = False
  isDenormalized _ = False
  isNegativeZero _ = False
  isIEEE _ = False
  -- The quadrant is chosen from the approximations of y and x at precision
  -- p, and the chosen expression is then itself evaluated at precision p.
  -- The guards cover every sign combination of (x', y'), so the final error
  -- case is not expected to be reached.
  atan2 y x = crMemoize (\p ->
    let y' = y `atPrecision` p
        x' = x `atPrecision` p
        θ = if
              | x' > 0 -> atan (y/x)
              | x' == 0 && y' > 0 -> piBy2
              | x' < 0 && y' > 0 -> pi + atan (y/x)
              -- x' <= 0, y' < 0: reflect through the x axis
              | x' <= 0 && y' < 0 -> -atan2 (-y) x
              | y' == 0 && x' < 0 -> pi
              | x'==0 && y'==0 -> 0
              | otherwise -> error "Data.CReal.Internal atan2"
    in θ `atPrecision` p)
-- | Equality at the type-level precision.
--
-- Two values that share the same cache 'MVar' are equal without any
-- computation. Otherwise their difference is approximated at precision
-- @n + 2@, scaled back by 2 bits, and compared with 0.
instance KnownNat n => Eq (CReal n) where
  x@(CR mvx _) == y@(CR mvy _)
    | mvx == mvy = True
    | otherwise = (atPrecision x prec - atPrecision y prec) /^ 2 == 0
    where
      prec = crealPrecision x + 2
-- | Ordering at the type-level precision.
--
-- 'compare' uses the same scheme as '==': the difference is approximated at
-- precision @n + 2@ and scaled back by 2 bits. 'max' and 'min' pick the
-- larger or smaller approximation at every precision.
instance KnownNat n => Ord (CReal n) where
  compare x@(CR mvx _) y@(CR mvy _)
    | mvx == mvy = EQ
    | otherwise = compare (delta /^ 2) 0
    where
      prec = crealPrecision x + 2
      delta = atPrecision x prec - atPrecision y prec
  max x y = crMemoize $ \p -> atPrecision x p `max` atPrecision y p
  min x y = crMemoize $ \p -> atPrecision x p `min` atPrecision y p
-- | Random values built from a random integer numerator over a power of two.
--
-- 'randomR' draws k from @[0, 2^bits]@ and returns
-- @lo + (k / 2^bits) * (hi - lo)@. 'random' draws k from
-- @[0, max 0 (2^bits - 2)]@ and returns @k / 2^bits@.
instance KnownNat n => Random (CReal n) where
  randomR (lo, hi) g =
    let width = hi - lo
        bits = 1 + log2 (abs width `atPrecision` 0) + crealPrecision lo
        (k, g') = randomR (0, 2 ^ bits) g
        unit = fromRational (k % 2 ^ bits)
    in (unit * width + lo, g')
  random g =
    let bits = 1 + crealPrecision (undefined :: CReal n)
        (k, g') = randomR (0, max 0 (2 ^ bits - 2)) g
    in (fromRational (k % 2 ^ bits), g')
-- | pi/4 by Machin's formula: 4 * atan (1/5) - atan (1/239).
piBy4 :: CReal n
piBy4 = shiftL (atanBounded (recipBounded 5)) 2 - atanBounded (recipBounded 239)
-- | pi/2, obtained by doubling 'piBy4'.
piBy2 :: CReal n
piBy2 = shiftL piBy4 1
-- | The natural logarithm of 2, computed directly by 'logBounded'.
ln2 :: CReal n
ln2 = logBounded (fromInteger 2)
infixl 7 `mulBounded`, `mulBoundedL`, .*, *., .*.

-- | @x .* y@ is @'mulBoundedL' x y@: the left operand is the bounded one.
(.*) :: CReal n -> CReal n -> CReal n
x .* y = mulBoundedL x y

-- | @x *. y@ is @'mulBoundedL' y x@: the right operand is the bounded one.
(*.) :: CReal n -> CReal n -> CReal n
x *. y = mulBoundedL y x

-- | @x .*. y@ is @'mulBounded' x y@: both operands are bounded.
(.*.) :: CReal n -> CReal n -> CReal n
x .*. y = mulBounded x y
-- | A product in which the left operand is treated as bounded.
--
-- '*' derives guard bits from the size of each operand. Here the guard bits
-- for the left operand are fixed at 4 instead, so the left operand is
-- presumably expected to be small in magnitude (TODO confirm the exact
-- bound). The right operand's guard bits are still measured from its
-- approximation at precision 0.
mulBoundedL :: CReal n -> CReal n -> CReal n
mulBoundedL x1 x2 = crMemoize approx
  where
    s1 = 4 :: Int
    s2 = log2 (abs (atPrecision x2 0) + 2) + 3
    approx p = (atPrecision x1 (p + s2) * atPrecision x2 (p + s1)) /^ (p + s1 + s2)
-- | A product in which both operands are treated as bounded.
--
-- Both operands are fetched with a fixed 4 guard bits instead of measuring
-- their magnitudes, so both are presumably expected to be small (TODO
-- confirm the exact bound).
mulBounded :: CReal n -> CReal n -> CReal n
mulBounded x1 x2 = crMemoize approx
  where
    guardBits = 4 :: Int
    approx p =
      let n1 = atPrecision x1 (p + guardBits)
          n2 = atPrecision x2 (p + guardBits)
      in (n1 * n2) /^ (p + 2 * guardBits)
-- | A reciprocal that uses a fixed s = 2 instead of searching for s as
-- 'recip' does.
--
-- The argument is presumably expected to stay away from zero (TODO confirm
-- the exact bound).
recipBounded :: CReal n -> CReal n
recipBounded x = crMemoize approx
  where
    s = 2 :: Int
    approx p =
      let extra = 2 * s + 2
      in bit (2 * p + extra) /. atPrecision x (p + extra)
{-# INLINABLE square #-}
-- | The square of a 'CReal'.
--
-- The guard bits @s@ are computed once from the rough size of the argument
-- (its approximation at precision 0), as '*' does.
square :: CReal n -> CReal n
square x = crMemoize sq
  where
    s = log2 (abs (atPrecision x 0) + 2) + 3
    sq p =
      let n = atPrecision x (p + s)
      in (n * n) /^ (p + 2 * s)
{-# INLINABLE squareBounded #-}
-- | Squares a 'CReal' using a fixed 4 guard bits, instead of measuring the
-- argument's magnitude as 'square' does.
--
-- The argument is presumably expected to be small in magnitude (TODO
-- confirm the exact bound).
--
-- If the argument already has a cached approximation at precision @j@, the
-- result's cache is seeded with the square at precision @j - 4@. This only
-- happens when @j - 4@ is non-negative, and uses the same formula as @fn'@.
squareBounded :: CReal n -> CReal n
squareBounded x@(CR mvc _) = unsafePerformIO $ do
  -- Non-blocking peek at the argument's cache.
  vcc <- tryReadMVar mvc
  let
    !s = 4
    !vcn = case vcc of
      Nothing -> Never
      Just Never -> Never
      Just (Current j n) -> case j - s of
        p | p < 0 -> Never
        p -> Current p ((n * n) /^ (p + 2 * s))
    fn' !p =
      let n = atPrecision x (p + s)
      in (n * n) /^ (p + 2 * s)
  mvn <- newMVar vcn
  return $ CR mvn fn'
-- | exp via its Taylor series, with coefficients 1/k!.
--
-- At precision p it sums @max 5 p + 1@ terms. It is presumably only
-- accurate for small arguments; 'exp' reduces the range before calling it.
expBounded :: CReal n -> CReal n
expBounded x = powerSeries [1 % f | f <- scanl' (*) 1 [1 ..]] (max 5) x
-- | The natural logarithm via ln x = -ln (1 - y) = sum_{k>=1} y^k / k, with
-- y = (x - 1) / x.
--
-- It is used by 'log' after range reduction and by 'ln2'.
logBounded :: CReal n -> CReal n
logBounded x = y .* powerSeries (map (1 %) [1 ..]) id y
  where
    y = (x - 1) .* recip x
-- | @(exp x, exp (-x))@ computed with a single shared range reduction.
--
-- Let l be an integer approximation of x / ln 2 and y = x - l * ln 2. The
-- pair is then (2^l * exp y, 2^-l * exp (-y)).
expPosNeg :: CReal n -> (CReal n, CReal n)
expPosNeg x
  | l == 0 = (expBounded x, expBounded (-x))
  | otherwise = (expBounded y `shiftL` k, expBounded (negate y) `shiftR` k)
  where
    l = atPrecision (x / ln2) 0
    y = x - fromInteger l * ln2
    k :: Int
    k = fromInteger l
-- | sin via its Taylor series: x * sum_k (-1)^k x^(2k) / (2k+1)!.
--
-- Each coefficient is built from the previous one by dividing by
-- (2k)(2k+1).
sinBounded :: CReal n -> CReal n
sinBounded x = x .* powerSeries coeffs (max 1) (squareBounded x)
  where
    coeffs = alternateSign (scanl' (*) 1 [1 % (k * (k + 1)) | k <- [2, 4 ..]])
-- | cos via its Taylor series: sum_k (-1)^k x^(2k) / (2k)!.
--
-- Each coefficient is built from the previous one by dividing by
-- (2k-1)(2k).
cosBounded :: CReal n -> CReal n
cosBounded x = powerSeries coeffs (max 1) (squareBounded x)
  where
    coeffs = alternateSign (scanl' (*) 1 [1 % (k * (k + 1)) | k <- [1, 3 ..]])
-- | atan via Euler's series:
--
-- > atan x = (x / (1 + x^2)) * sum_k c_k * (x^2 / (1 + x^2))^k
--
-- with c_0 = 1 and c_k = c_{k-1} * 2k / (2k + 1).
atanBounded :: CReal n -> CReal n
atanBounded x = (x .*. rd) .* powerSeries coeffs (+ 1) (sq .*. rd)
  where
    coeffs = scanl' (*) 1 [k % (k + 1) | k <- [2, 4 ..]]
    sq = squareBounded x
    rd = recipBounded (plusInteger sq 1)
infixl 6 `plusInteger`
{-# INLINE plusInteger #-}
-- | Add an 'Integer' to a 'CReal'.
--
-- Adding 0 returns the argument unchanged. Otherwise the new value's
-- approximation at precision @p@ is @fn p + n * 2^p@. If the argument's
-- cache can be read without blocking, its cached approximation is shifted
-- the same way to seed the new cache.
plusInteger :: CReal n -> Integer -> CReal n
plusInteger x 0 = x
plusInteger (CR mvc fn) n = unsafePerformIO $ do
  vcc <- tryReadMVar mvc
  let
    !vcn = case vcc of
      Nothing -> Never
      Just Never -> Never
      -- NOTE(review): unsafeShiftL assumes the cached precision j >= 0.
      Just (Current j v) -> Current j (v + unsafeShiftL n j)
    fn' !p = fn p + B.shiftL n p
  mvc' <- newMVar vcn
  return $ CR mvc' fn'
infixl 8 `shiftL`, `shiftR`
-- | Divide by 2^n; a negative n multiplies instead.
--
-- Requested precisions that would fall below 0 are served from the
-- precision-0 approximation, scaled with '/^'.
shiftR :: CReal n -> Int -> CReal n
shiftR x n = crMemoize approx
  where
    approx p
      | p' >= 0 = atPrecision x p'
      | otherwise = atPrecision x 0 /^ negate p'
      where
        p' = p - n
-- | Multiply by 2^n, implemented as 'shiftR' by the negated amount.
shiftL :: CReal n -> Int -> CReal n
shiftL x n = shiftR x (negate n)
-- | Show the approximation at binary precision @p@ in decimal.
--
-- It uses enough decimal places ('decimalDigitsAtPrecision') to cover the
-- binary precision.
showAtPrecision :: Int -> CReal n -> String
showAtPrecision p x = rationalToDecimal (decimalDigitsAtPrecision p) (atPrecision x p % bit p)
-- | The number of decimal places used to print a value with @p@ binary
-- fractional digits: 0 when @p@ is 0, otherwise floor (log10 2^p) + 1.
decimalDigitsAtPrecision :: Int -> Int
decimalDigitsAtPrecision p
  | p == 0 = 0
  | otherwise = 1 + log10 (bit p)
-- | Render a rational in decimal with exactly @places@ digits after the
-- point.
--
-- The last digit is rounded with 'roundD' (ties to even). No point is
-- printed when @places@ is 0, and a leading "0" is added for values below
-- one.
rationalToDecimal :: Int -> Rational -> String
rationalToDecimal places (n :% d) = concat [sign, intPart, fracPart]
  where
    sign = if n < 0 then "-" else ""
    digits = show (roundD (abs n * 10 ^ places) d)
    len = length digits
    (intPart, fracDigits)
      | len <= places = ("0", replicate (places - len) '0' ++ digits)
      | otherwise = splitAt (len - places) digits
    fracPart = if places > 0 then '.' : fracDigits else ""
-- | The error raised by '/.' when dividing by zero.
--
-- It is marked NOINLINE, presumably so that the error call stays out of
-- line instead of being copied into each use site.
divZeroErr :: a
divZeroErr = error "Division by zero"
{-# NOINLINE divZeroErr #-}
-- | @roundD n d@ is n / d rounded to the nearest integer, with ties going
-- to the even neighbour.
--
-- It expects a positive divisor; '/.' normalises the sign before calling
-- it.
roundD :: Integer -> Integer -> Integer
roundD n d
  | twiceRem < d = q
  | twiceRem > d = q + 1
  | odd q = q + 1
  | otherwise = q
  where
    (q, r) = n `divMod` d
    twiceRem = unsafeShiftL r 1
{-# INLINE roundD #-}
infixl 7 /.
-- | Integer division rounded to nearest, with ties to even (via 'roundD').
--
-- A negative divisor is handled by negating both operands; a zero divisor
-- raises 'divZeroErr'.
(/.) :: Integer -> Integer -> Integer
(!n) /. (!d)
  | d > 0 = roundD n d
  | d < 0 = roundD (negate n) (negate d)
  | otherwise = divZeroErr
{-# INLINABLE (/.) #-}
infixl 7 /^
-- | @n /^ p@ divides by @2^p@, rounding to nearest with ties to even.
--
-- A negative @p@ multiplies exactly by @2^(-p)@ instead.
(/^) :: Integer -> Int -> Integer
(!n) /^ (!p)
  | p < 0 = unsafeShiftL n (negate p)
  | p == 0 = n
  | otherwise = case unsafeShiftL low 1 `compare` unit of
      LT -> high
      GT -> high + 1
      EQ -> if testBit high 0 then high + 1 else high
  where
    unit = bit p
    -- the low p bits of n (non-negative, even for negative n)
    low = n .&. (unit - 1)
    -- floor (n / 2^p), computed exactly
    high = unsafeShiftR (n - low) p
{-# INLINE log2 #-}
-- | The floor of the base-2 logarithm, computed with 'integerLog2#'.
--
-- The primitive does not check its argument, so callers should pass a
-- strictly positive value.
log2 :: Integer -> Int
log2 x = case integerLog2# x of
  r# -> I# r#
{-# INLINE log10 #-}
-- | The floor of the base-10 logarithm, computed with 'integerLogBase#'.
log10 :: Integer -> Int
log10 x = case integerLogBase# 10 x of
  r# -> I# r#
{-# INLINABLE isqrt #-}
-- | The integer square root, rounded down.
--
-- It uses Newton's iteration, starting from a power of two near the root.
-- Negative input raises an error.
isqrt :: Integer -> Integer
isqrt x
  | x < 0 = error "Sqrt applied to negative Integer"
  | x == 0 = 0
  | otherwise = go (bit (log2 x `unsafeShiftR` 1))
  where
    go r
      | isFloorRoot r = r
      | otherwise = go ((r + x `div` r) `unsafeShiftR` 1)
    -- r*r <= x <= r*r + 2r, i.e. r*r <= x < (r + 1)^2
    isFloorRoot r = let sq = r * r in sq <= x && sq + unsafeShiftL r 1 >= x
{-# INLINABLE findFirstMonotonic #-}
-- | For a predicate that is presumably monotone (False up to some point,
-- then True), return one less than the smallest @i >= 1@ at which it holds.
-- The result is 0 whenever @pred 1@ holds.
--
-- The search doubles an upper bound until the predicate holds there, then
-- binary-searches the bracket. It does not terminate if the predicate never
-- holds.
findFirstMonotonic :: (Int -> Bool) -> Int
findFirstMonotonic p = grow 0 1
  where
    grow !lo !hi
      | p hi = narrow lo hi
      | otherwise = grow hi (hi * 2)
    narrow !lo !hi
      | lo + 1 == hi = lo
      | p mid = narrow lo mid
      | otherwise = narrow mid hi
      where
        !mid = lo + (hi - lo) `div` 2
{-# INLINABLE alternateSign #-}
-- | Negate every second element, starting with the second one. Works lazily
-- on infinite lists.
alternateSign :: Num a => [a] -> [a]
alternateSign = zipWith ($) (cycle [id, negate])
-- | Evaluate @sum_{i=0}^{t} q_i * x^i@ as a 'CReal', where
-- @t = termsAtPrecision p@ when approximating at precision @p@.
--
-- With @d = log2 t + 2@ guard bits, the evaluation proceeds as follows:
--
-- * x is fetched at precision @p + 2d@;
-- * its powers are kept at precision @p + d@;
-- * each term is multiplied by @2^d@ and rounded;
-- * the sum is scaled back by @2d@ bits.
powerSeries :: [Rational] -> (Int -> Int) -> CReal n -> CReal n
powerSeries q termsAtPrecision x = crMemoize approx
  where
    approx p =
      let t = termsAtPrecision p
          d = log2 (toInteger t) + 2
          p' = p + d
          p'' = p' + d
          m = atPrecision x p''
          powers = map (% 1) (iterate (\e -> m * e /^ p'') (bit p'))
          scaled = [round (c * e * fromInteger (bit d)) | (c, e) <- zip q powers]
      in sum (take (t + 1) scaled) /^ (2 * d)