module Numeric.RootFinding
(
Root(..)
, fromRoot
, ridders
, newtonRaphson
) where
import Control.Applicative (Alternative(..), Applicative(..))
import Control.Monad (MonadPlus(..), ap)
import Data.Data (Data, Typeable)
#if __GLASGOW_HASKELL__ > 704
import GHC.Generics (Generic)
#endif
import Numeric.MathFunctions.Comparison (within)
-- | Outcome of a root search.
data Root a
  = NotBracketed
    -- ^ The function does not change sign between the two bounds
    --   supplied to the search (see 'ridders').
  | SearchFailed
    -- ^ The search stopped without producing a root (for example an
    --   iteration limit was reached). Also the 'empty' \/ 'mzero' value.
  | Root a
    -- ^ A root was found.
  deriving ( Eq, Read, Show, Typeable, Data
#if __GLASGOW_HASKELL__ > 704
           , Generic
#endif
           )
-- | Transforms a found root; both failure constructors pass through
-- untouched.
instance Functor Root where
  fmap g result = case result of
    Root x       -> Root (g x)
    NotBracketed -> NotBracketed
    SearchFailed -> SearchFailed
-- | Sequencing short-circuits on the first failure: 'NotBracketed' and
-- 'SearchFailed' are propagated unchanged, and only a 'Root' value is
-- passed on to the continuation.
instance Monad Root where
  NotBracketed >>= _ = NotBracketed
  SearchFailed >>= _ = SearchFailed
  Root a       >>= k = k a
  -- Canonical definition: 'return' must coincide with 'pure'
  -- (non-canonical definitions are flagged by
  -- -Wnoncanonical-monad-instances).
  return = pure
-- | Defined in terms of the 'Alternative' instance so the two stay in
-- sync. 'mzero' is the failure value 'SearchFailed'. 'mplus' keeps its
-- left operand when that is a 'Root' and otherwise yields the right
-- operand.
instance MonadPlus Root where
  mzero = empty
  mplus = (<|>)
-- | 'pure' wraps a value as a found root. '<*>' applies a found function
-- to a found argument. A failure in the function position wins;
-- otherwise a failure in the argument is propagated.
instance Applicative Root where
  pure = Root
  NotBracketed <*> _   = NotBracketed
  SearchFailed <*> _   = SearchFailed
  Root g       <*> arg = fmap g arg
-- | 'empty' is 'SearchFailed'. '<|>' is left-biased: a 'Root' on the
-- left is kept, and any failure on the left (either 'NotBracketed' or
-- 'SearchFailed') is replaced by the right operand.
instance Alternative Root where
  empty = SearchFailed
  lhs <|> rhs = case lhs of
    Root _ -> lhs
    _      -> rhs
-- | Unwrap a root-search result, falling back to a default value when
-- no root was found.
fromRoot
  :: a       -- ^ Value returned for 'NotBracketed' or 'SearchFailed'.
  -> Root a  -- ^ Result of a root search.
  -> a
fromRoot fallback result = case result of
  Root x -> x
  _      -> fallback
-- | Find a root of a function using Ridders' method.
--
-- The function must change sign between the two bounds; otherwise the
-- result is 'NotBracketed'. An exact zero at either bound is returned
-- immediately. The search gives up with 'SearchFailed' after 100
-- iterations.
ridders
  :: Double              -- ^ Absolute tolerance on the width of the bracket.
  -> (Double,Double)     -- ^ Lower and upper bounds of the search.
  -> (Double -> Double)  -- ^ Function whose root is sought.
  -> Root Double
ridders tol (lo,hi) f
  | flo == 0    = Root lo
  | fhi == 0    = Root hi
  | flo*fhi > 0 = NotBracketed -- no sign change: root is not bracketed
  | otherwise   = go lo flo hi fhi 0
  where
    go !a !fa !b !fb !i
      -- Bounds are within 1 unit of 'within' (presumably 1 ulp): the
      -- bracket cannot be narrowed any further.
      | within 1 a b      = Root a
      -- Exact zeros at the midpoint or at the Ridders estimate.
      | fm == 0           = Root m
      | fn == 0           = Root n
      -- Bracket is narrower than the requested tolerance.
      | d < tol           = Root n
      -- Too many iterations: give up.
      | i >= (100 :: Int) = SearchFailed
      -- Ridders' estimate coincides with an old bound: fall back to
      -- bisection on whichever half still has a sign change.
      | n == a || n == b  =
          if fm*fa < 0
            then go a fa m fm (i+1)
            else go m fm b fb (i+1)
      -- Normal step: keep the sub-interval that still brackets the root.
      | fn*fm < 0         = go n fn m fm (i+1)
      | fn*fa < 0         = go a fa n fn (i+1)
      | otherwise         = go n fn b fb (i+1)
      where
        d   = abs (b - a)
        dm  = (b - a) * 0.5
        !m  = a + dm
        !fm = f m
        !dn = signum (fb - fa) * dm * fm / sqrt (fm*fm - fa*fb)
        -- Limit the step so the estimate stays strictly inside the bracket.
        !n  = m - signum dn * min (abs dn) (abs dm - 0.5 * tol)
        !fn = f n
    !flo = f lo
    !fhi = f hi
-- | Find a root of a function using the Newton-Raphson method,
-- safeguarded by bisection.
--
-- The search succeeds when the function is exactly zero at an iterate,
-- or when the relative step @|dx / x|@ falls below the requested
-- precision. When a Newton step would land on or outside the current
-- bounds, the next iterate is instead the midpoint between @x@ and the
-- violated bound.
--
-- The result is 'SearchFailed' in two cases:
--
-- * the derivative is zero where the function is non-zero, so no Newton
--   step exists;
-- * no convergence after 100 iterations (the same cap as 'ridders').
--   This also covers NaN iterates and a precision that can never be met.
newtonRaphson
  :: Double                       -- ^ Required relative precision.
  -> (Double,Double,Double)       -- ^ (lower bound, initial guess, upper bound).
  -> (Double -> (Double,Double))  -- ^ Function returning (value, derivative) at a point.
  -> Root Double
newtonRaphson !prec (!low,!guess,!hi) function
  = go (0 :: Int) low guess hi
  where
    go !i !xMin !x !xMax
      | f  == 0             = Root x
      -- No Newton step is possible; previously this called 'error'.
      | f' == 0             = SearchFailed
      | abs (dx / x) < prec = Root x
      | i >= 100            = SearchFailed
      | otherwise           = go (i + 1) xMin' x' xMax'
      where
        (f, f') = function x
        -- Newton step; only forced after the f' == 0 guard above.
        delta = f / f'
        -- If the Newton estimate leaves (xMin, xMax), step halfway
        -- towards the violated bound instead.
        (dx, x') | z <= xMin = let s = 0.5 * (x - xMin) in (s, x - s)
                 | z >= xMax = let s = 0.5 * (x - xMax) in (s, x - s)
                 | otherwise = (delta, z)
          where z = x - delta
        -- The current point becomes the bound on the side we move away from.
        xMin' | dx < 0    = x
              | otherwise = xMin
        xMax' | dx > 0    = x
              | otherwise = xMax