{-# LANGUAGE BangPatterns, DeriveDataTypeable, DeriveGeneric #-}
module Statistics.Math.RootFinding
(
Root(..)
, fromRoot
, ridders
) where
import Control.Applicative
import Control.Monad (MonadPlus(..), ap)
import Data.Data (Data, Typeable)
import GHC.Generics (Generic)
import Numeric.MathFunctions.Comparison (within)
import Prelude
-- | The outcome of a root search.
--
-- As a 'Monad', 'NotBracketed' and 'SearchFailed' short-circuit and a
-- 'Root' passes its value on. As an 'Alternative' / 'MonadPlus', choice
-- is left-biased: the first 'Root' wins, and 'SearchFailed' is the
-- identity ('empty' / 'mzero').
data Root a
  = NotBracketed
    -- ^ The function has the same sign at both ends of the search
    -- interval, so no root is known to be bracketed.
  | SearchFailed
    -- ^ The search did not produce a root (e.g. 'ridders' gives this
    -- once its iteration limit is reached).
  | Root a
    -- ^ A root was found.
  deriving (Eq, Read, Show, Typeable, Data, Generic)

instance Functor Root where
  fmap _ NotBracketed = NotBracketed
  fmap _ SearchFailed = SearchFailed
  fmap f (Root a)     = Root (f a)

instance Applicative Root where
  pure  = Root
  (<*>) = ap

-- 'return' is deliberately not defined: its default is 'pure', and a
-- separate definition is non-canonical (-Wnoncanonical-monad-instances).
instance Monad Root where
  NotBracketed >>= _ = NotBracketed
  SearchFailed >>= _ = SearchFailed
  Root a       >>= k = k a

-- Left-biased choice: a 'Root' on the left wins; otherwise the right
-- operand is returned unchanged (which may itself be a failure).
instance Alternative Root where
  empty = SearchFailed
  r@(Root _) <|> _ = r
  _          <|> p = p

-- Defined via 'Alternative' so the two instances can never disagree.
instance MonadPlus Root where
  mzero = empty
  mplus = (<|>)
-- | Unwrap the result of a root search, substituting a fallback value
-- when the search yielded 'NotBracketed' or 'SearchFailed'.
fromRoot :: a      -- ^ Fallback used when no root was found.
         -> Root a -- ^ Result of the search.
         -> a
fromRoot fallback result =
  case result of
    Root x -> x
    _      -> fallback
-- | Find a root of a function using Ridders' method.
--
-- The function must have opposite signs at the two ends of the search
-- interval, otherwise 'NotBracketed' is returned. The search returns a
-- 'Root' when an exact zero is hit, when the bracket has shrunk to
-- within 1 ulp, or when the bracket width drops below the (absolute)
-- tolerance. It returns 'SearchFailed' after 100 iterations.
ridders :: Double             -- ^ Absolute error tolerance.
        -> (Double,Double)    -- ^ Lower and upper bounds of the search.
        -> (Double -> Double) -- ^ Function whose root is sought.
        -> Root Double
ridders tol (lo,hi) f
  | flo == 0    = Root lo
  | fhi == 0    = Root hi
  | flo*fhi > 0 = NotBracketed
  | otherwise   = go lo flo hi fhi 0
  where
    go !a !fa !b !fb !i
      -- Bracket has collapsed to within 1 ulp; no further progress possible.
      | within 1 a b      = Root a
      -- Exact zero at the midpoint. Checking it here also ensures a zero
      -- of f is never handed on as a bracket end.
      | fm == 0           = Root m
      -- Bracket is already narrower than the tolerance, so the midpoint
      -- is within tol/2 of the root. This must be tested before the
      -- Ridders point is used: here the step cap (abs dm - 0.5*tol) is
      -- negative, which would reverse the step and could put n (and the
      -- call f n) outside [a,b].
      | d < tol           = Root m
      | fn == 0           = Root n
      | i >= (100 :: Int) = SearchFailed
      -- Ridders' estimate coincides with an old bound: bisect instead.
      | n == a || n == b  =
          if fm*fa < 0 then go a fa m fm (i+1)
                       else go m fm b fb (i+1)
      -- Keep a sub-interval whose end values have opposite signs.
      | fn*fm < 0         = go n fn m fm (i+1)
      | fn*fa < 0         = go a fa n fn (i+1)
      | otherwise         = go n fn b fb (i+1)
      where
        d   = abs (b - a)
        dm  = (b - a) * 0.5
        !m  = a + dm
        !fm = f m
        -- Ridders' correction from the midpoint. These are left lazy so
        -- they are only evaluated once the guards above have ruled out
        -- d < tol, i.e. only when the step cap below is non-negative.
        dn  = signum (fb - fa) * dm * fm / sqrt (fm*fm - fa*fb)
        n   = m - signum dn * min (abs dn) (abs dm - 0.5 * tol)
        fn  = f n
    !flo = f lo
    !fhi = f hi