{-# LANGUAGE CPP, BangPatterns, ScopedTypeVariables, ForeignFunctionInterface #-}
module Numeric.SpecFunctions.Internal
( erf
, erfc
, invErf
, invErfc
, log2
) where
import Data.Bits ((.&.), (.|.), shiftR)
import Data.Word (Word64)
import qualified Data.Vector.Unboxed as U
import Numeric.MathFunctions.Constants
-- | The error function, evaluated by the C library's @erf@ through the
-- 'c_erf' foreign import.
erf :: Double -> Double
erf = c_erf
{-# INLINE erf #-}
-- | The complementary error function (@1 - erf x@ mathematically),
-- evaluated by the C library's @erfc@ through the 'c_erfc' foreign import.
erfc :: Double -> Double
erfc = c_erfc
{-# INLINE erfc #-}
-- Bindings to the C library's double-precision erf/erfc (C99 <math.h>).
-- These routines are short-running, do not block and never call back into
-- Haskell, so the cheaper 'unsafe' calling convention is appropriate; a
-- 'safe' call adds per-call overhead (saving state and, in the threaded
-- RTS, releasing/reacquiring the capability) for no benefit here.
foreign import ccall unsafe "math.h erf"  c_erf  :: Double -> Double
foreign import ccall unsafe "math.h erfc" c_erfc :: Double -> Double
-- | Inverse of 'erf': for @p@ in @[-1, 1]@ returns @x@ with @erf x == p@
-- (up to rounding). The endpoints go through 'invErfc', so
-- @invErf (-1)@ is 'm_neg_inf' and @invErf 1@ is 'm_pos_inf'. Any input
-- outside @[-1, 1]@ (including NaN) is rejected via 'modErr'.
--
-- For @|p| >= 0.01@ this uses the identity @invErf p = invErfc (1 - p)@.
-- Near zero that identity is numerically poor: forming @1 - p@ discards
-- the low-order digits of @p@ (for tiny @p@ it rounds to exactly 1), so
-- the result would lose relative precision. The Maclaurin series is used
-- in that range instead.
invErf :: Double -- ^ /p/ ∈ [-1,1]
       -> Double
invErf p
  | abs p < 0.01      = series
  | p >= -1 && p <= 1 = invErfc (1 - p)
  | otherwise         = modErr $ "invErf: p must be in [-1,1] got " ++ show p
  where
    -- Maclaurin series (DLMF 7.17.2):
    --   inverf p = t + t^3/3 + 7 t^5/30 + 127 t^7/630 + ...,  t = (sqrt pi / 2) p
    -- For |p| < 0.01 the first omitted term (~0.19 t^9) is below 1e-17
    -- relative to t, i.e. under one ulp, so four terms suffice.
    t, t2, series :: Double
    t      = 0.88622692545275801365 * p   -- sqrt pi / 2
    t2     = t * t
    series = t * (1 + t2 * (1/3 + t2 * (7/30 + t2 * (127/630))))
-- | Inverse of 'erfc': for @p@ in @[0, 2]@ returns @x@ with @erfc x == p@
-- (up to rounding). @p == 0@ gives 'm_pos_inf' and @p == 2@ gives
-- 'm_neg_inf'. Any other value outside the open interval @(0, 2)@,
-- including NaN, is rejected via 'modErr'.
--
-- The work is done on the lower half @q@ (@p@ itself when @p <= 1@,
-- otherwise @2 - p@), so @q@ lies in @(0, 1]@. The method:
--
-- * forms a rational initial guess in @t = sqrt (-2 log (q/2))@;
-- * refines it with two Halley steps on @erfc x - q@;
-- * negates the root when @p > 1@, using @erfc (-x) = 2 - erfc x@.
--
-- NOTE(review): the constants look like the Numerical Recipes @inverfc@
-- scheme; confirm before changing them.
invErfc :: Double -- ^ /p/ ∈ [0,2]
        -> Double
invErfc p
  | p > 0 && p < 2 = if p > 1 then negate root else root
  | p == 0         = m_pos_inf
  | p == 2         = m_neg_inf
  | otherwise      = modErr $ "invErfc: p must be in [0,2] got " ++ show p
  where
    -- Reflect the upper half onto (0, 1].
    q :: Double
    q = if p <= 1 then p else 2 - p

    -- Initial approximation.
    s, guess :: Double
    s     = sqrt (-2 * log (0.5 * q))
    guess = -0.70711 * ((2.30753 + s * 0.27061) / (1 + s * (0.99229 + s * 0.04481)) - s)

    -- One Halley step for f x = erfc x - q. The constant is 2 / sqrt pi,
    -- so its product with exp (-x*x) is -f' x; the @- x * e@ term comes
    -- from f'' x = -2 x f' x.
    halley :: Double -> Double
    halley x =
      let e = erfc x - q
      in  x + e / (1.12837916709551257 * exp (-x * x) - x * e)

    -- Exactly two refinement steps.
    root :: Double
    root = halley (halley guess)
-- | Base-2 logarithm of a positive 'Int', rounded down: the index of its
-- most significant set bit. Calls 'modErr' for input @<= 0@.
--
-- This is the O(log N) mask-and-shift method (cf. Sean Anderson's "Bit
-- Twiddling Hacks", log base 2 of an N-bit integer). Masks are tried from
-- the widest (32 bits) down to the narrowest (1 bit). The mask for width
-- @w@ covers bits @w .. 2w-1@. Whenever the remaining value has a bit set
-- there, it is shifted right by @w@ and @w@ is OR-ed into the answer.
-- On a 32-bit 'Int' the 64-bit mask truncates to 0, so that step is a
-- no-op.
log2 :: Int -> Int
log2 v0
  | v0 > 0    = step 5 0 v0
  | otherwise = modErr $ "log2: nonpositive input, got " ++ show v0
  where
    -- step idx acc v: try mask number @idx@ against the remaining value
    -- @v@, accumulating shift widths into @acc@; done after index 0.
    step :: Int -> Int -> Int -> Int
    step !idx !acc !v
      | idx < 0         = acc
      | v .&. mask /= 0 = step (idx - 1) (acc .|. width) (v `shiftR` width)
      | otherwise       = step (idx - 1) acc v
      where
        mask  = U.unsafeIndex masks idx
        width = U.unsafeIndex widths idx
    !masks  = U.fromList [ 0x02, 0x0c, 0xf0, 0xff00
                         , fromIntegral (0xffff0000 :: Word64)
                         , fromIntegral (0xffffffff00000000 :: Word64) ]
    !widths = U.fromList [1, 2, 4, 8, 16, 32]
-- | Abort with an 'error' whose message is the given text prefixed by
-- @"Numeric.SpecFunctions."@, so callers pass e.g. @"log2: ..."@.
modErr :: String -> a
modErr msg = error ("Numeric.SpecFunctions." ++ msg)