{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.RootFinding
(
Root(..)
, fromRoot
, Tolerance(..)
, withinTolerance
, IterationStep(..)
, findRoot
, RiddersParam(..)
, ridders
, riddersIterations
, RiddersStep(..)
, NewtonParam(..)
, newtonRaphson
, newtonRaphsonIterations
, NewtonStep(..)
) where
import Control.Applicative (Alternative(..), Applicative(..))
import Control.Monad (MonadPlus(..), ap)
import Control.DeepSeq (NFData(..))
import Data.Data (Data, Typeable)
import Data.Monoid (Monoid(..))
import Data.Foldable (Foldable)
import Data.Traversable (Traversable)
import Data.Default.Class
#if __GLASGOW_HASKELL__ > 704
import GHC.Generics (Generic)
#endif
import Numeric.MathFunctions.Comparison (within,eqRelErr)
import Numeric.MathFunctions.Constants (m_epsilon)
-- | Outcome of a root search.
data Root a
  = NotBracketed
    -- ^ The function has the same sign at both ends of the initial
    --   bracket. Also used as 'empty' by the 'Alternative' instance.
  | SearchFailed
    -- ^ No step satisfied the tolerance before the iteration limit was
    --   reached or the sequence of steps ended.
  | Root !a
    -- ^ A root was found.
  deriving ( Eq, Read, Show, Typeable, Data, Foldable, Traversable
#if __GLASGOW_HASKELL__ > 704
           , Generic
#endif
           )
-- | Deep evaluation forces only the payload of 'Root'; the other
--   constructors carry no data.
instance (NFData a) => NFData (Root a) where
  rnf r = case r of
    Root x -> rnf x
    _      -> ()
-- | Maps over a found root; failure markers pass through unchanged.
instance Functor Root where
  fmap f r = case r of
    Root x       -> Root (f x)
    NotBracketed -> NotBracketed
    SearchFailed -> SearchFailed
-- | 'pure' wraps a value as a found 'Root'.
instance Applicative Root where
  pure  = Root
  (<*>) = ap

-- | Sequencing short-circuits on the first failure marker
--   ('NotBracketed' or 'SearchFailed'), which is propagated unchanged.
instance Monad Root where
  NotBracketed >>= _ = NotBracketed
  SearchFailed >>= _ = SearchFailed
  Root a       >>= f = f a
  -- Canonical form: 'return' is defined via 'pure', not the other way
  -- round (avoids -Wnoncanonical-monad-instances and future breakage).
  return = pure
-- | Delegates entirely to the 'Alternative' instance.
instance MonadPlus Root where
  mplus a b = a <|> b
  mzero     = empty
-- | Choice prefers whichever operand is a 'Root', checking the left one
--   first. When neither is, 'NotBracketed' acts as the identity, so
--   'SearchFailed' wins over 'NotBracketed'.
instance Alternative Root where
  empty = NotBracketed
  l <|> r = case (l, r) of
    (Root{}, _)       -> l
    (_, Root{})       -> r
    (NotBracketed, _) -> r
    (_, NotBracketed) -> l
    _                 -> r
-- | Extract the root, or return a fallback value when the search did
--   not produce one.
fromRoot :: a       -- ^ Value to return when no root was found
         -> Root a  -- ^ Search result
         -> a
fromRoot fallback r = case r of
  Root x -> x
  _      -> fallback
-- | Criterion deciding when two approximations are close enough
--   (see 'withinTolerance').
data Tolerance
  = RelTol !Double
    -- ^ Relative tolerance, checked with 'eqRelErr'.
  | AbsTol !Double
    -- ^ Absolute tolerance: accepted when @abs (a - b) < tol@.
  deriving ( Eq, Read, Show, Typeable, Data
#if __GLASGOW_HASKELL__ > 704
           , Generic
#endif
           )
-- | Check whether two approximations agree under a 'Tolerance'.
--   Pairs accepted by @'within' 1@ (one ULP in math-functions' terms)
--   always pass, whatever the tolerance is.
withinTolerance :: Tolerance -> Double -> Double -> Bool
withinTolerance tol a b = within 1 a b || meets tol
  where
    meets (RelTol eps) = eqRelErr eps a b
    meets (AbsTol t)   = abs (a - b) < t
-- | A single step of an iterative root-finding method that can be
--   inspected for convergence by 'findRoot'.
class IterationStep a where
  -- | @Just r@ when this step settles the search; 'Nothing' when
  --   iteration must continue.
  matchRoot :: Tolerance            -- ^ Convergence criterion
            -> a                    -- ^ Step to inspect
            -> Maybe (Root Double)
-- | Walk a sequence of iteration steps and return the first result that
--   'matchRoot' reports. At most the given number of steps are
--   inspected. If none settles the search, or the list ends first, the
--   result is 'SearchFailed'.
findRoot :: IterationStep a
         => Int          -- ^ Maximum number of steps to inspect
         -> Tolerance    -- ^ Convergence criterion
         -> [a]          -- ^ Iteration steps
         -> Root Double
findRoot maxN tol steps =
  case [found | Just found <- map (matchRoot tol) (take maxN steps)] of
    result : _ -> result
    []         -> SearchFailed
{-# INLINABLE findRoot #-}
{-# SPECIALIZE findRoot :: Int -> Tolerance -> [RiddersStep] -> Root Double #-}
{-# SPECIALIZE findRoot :: Int -> Tolerance -> [NewtonStep] -> Root Double #-}
-- | Parameters for 'ridders'.
data RiddersParam = RiddersParam
  { riddersMaxIter :: !Int
    -- ^ Maximum number of steps inspected by 'findRoot'.
  , riddersTol     :: !Tolerance
    -- ^ Convergence criterion.
  }
  deriving ( Eq, Read, Show, Typeable, Data
#if __GLASGOW_HASKELL__ > 704
           , Generic
#endif
           )
-- | Defaults: at most 100 steps, relative tolerance @4 * m_epsilon@.
instance Default RiddersParam where
  def = RiddersParam
    { riddersTol     = RelTol (4 * m_epsilon)
    , riddersMaxIter = 100
    }
-- | One step of Ridders' method, as produced by 'riddersIterations'.
data RiddersStep
  = RiddersStep !Double !Double
    -- ^ Bracket endpoints after a Ridders update. The first step of a
    --   sequence holds the initial bracket.
  | RiddersBisect !Double !Double
    -- ^ Bracket endpoints after a fallback bisection.
  | RiddersRoot !Double
    -- ^ Point where the function evaluated to exactly zero.
  | RiddersNoBracket
    -- ^ The function has the same sign at both initial endpoints.
  deriving ( Eq, Read, Show, Typeable, Data
#if __GLASGOW_HASKELL__ > 704
           , Generic
#endif
           )
-- | Every field is a strict 'Double', so weak head normal form is
--   already normal form.
instance NFData RiddersStep where
  rnf = (`seq` ())
-- | Exact roots and missing brackets settle the search immediately.
--   Bracket steps settle it once both endpoints agree within tolerance,
--   and their midpoint is reported as the root.
instance IterationStep RiddersStep where
  matchRoot tol step = case step of
      RiddersRoot x     -> Just (Root x)
      RiddersNoBracket  -> Just NotBracketed
      RiddersStep a b   -> midpointIfClose a b
      RiddersBisect a b -> midpointIfClose a b
    where
      midpointIfClose lo hi
        | withinTolerance tol lo hi = Just (Root ((lo + hi) / 2))
        | otherwise                 = Nothing
-- | Find a root of a function with Ridders' method, starting from a
--   bracketing interval.
ridders
  :: RiddersParam        -- ^ Iteration limit and tolerance
  -> (Double,Double)     -- ^ Bracket; the endpoints may be given in either order
  -> (Double -> Double)  -- ^ Function whose root is sought
  -> Root Double
ridders p bracket fun = findRoot maxIter tol steps
  where
    maxIter = riddersMaxIter p
    tol     = riddersTol p
    steps   = riddersIterations bracket fun
-- | Lazily generate the steps of Ridders' method on the bracket
--   @(lo, hi)@, whose endpoints may be given in either order.
--
--   The list ends with 'RiddersRoot' when the function hits exactly
--   zero, or is the singleton 'RiddersNoBracket' when @f lo@ and @f hi@
--   have the same sign. Otherwise it is infinite and must be cut off,
--   e.g. by 'findRoot'.
riddersIterations :: (Double,Double) -> (Double -> Double) -> [RiddersStep]
riddersIterations (lo,hi) f
  | flo == 0    = [RiddersRoot lo]
  | fhi == 0    = [RiddersRoot hi]
  | flo*fhi > 0 = [RiddersNoBracket]
  -- 'go' requires a < b, so orient the initial bracket.
  | lo < hi     = RiddersStep lo hi : go lo flo hi fhi
  | otherwise   = RiddersStep lo hi : go hi fhi lo flo
  where
    flo = f lo
    fhi = f hi
    -- go a fa b fb: one iteration on the bracket [a, b] with a < b,
    -- fa = f a and fb = f b.
    go !a !fa !b !fb
      | fm == 0          = [RiddersRoot m]
      | fn == 0          = [RiddersRoot n]
      -- Ridders' estimate fell on or outside the bracket: bisect instead.
      | n <= a || n >= b = if fm*fa < 0
                             then recBisect a fa m fm
                             else recBisect m fm b fb
      -- Root lies between n and m. n may lie on either side of the
      -- midpoint, so order the endpoints to keep a < b. Otherwise the
      -- bracket test above always fires and the method degrades to
      -- bisection.
      | fn*fm < 0        = if n < m
                             then recRidders n fn m fm
                             else recRidders m fm n fn
      | fn*fa < 0        = recRidders a fa n fn
      | otherwise        = recRidders n fn b fb
      where
        recBisect  x fx y fy = RiddersBisect x y : go x fx y fy
        recRidders x fx y fy = RiddersStep   x y : go x fx y fy
        -- Half-width of the bracket
        dm = (b - a) * 0.5
        -- Midpoint
        m  = (a + b) / 2
        fm = f m
        -- Ridders' exponential-fit update
        n  = m - signum (fb - fa) * dm * fm / sqrt (fm*fm - fa*fb)
        fn = f n
-- | Parameters for 'newtonRaphson'.
data NewtonParam = NewtonParam
  { newtonMaxIter :: !Int
    -- ^ Maximum number of steps inspected by 'findRoot'.
  , newtonTol     :: !Tolerance
    -- ^ Convergence criterion.
  }
  deriving ( Eq, Read, Show, Typeable, Data
#if __GLASGOW_HASKELL__ > 704
           , Generic
#endif
           )
-- | Defaults: at most 50 steps, relative tolerance @4 * m_epsilon@.
instance Default NewtonParam where
  def = NewtonParam
    { newtonTol     = RelTol (4 * m_epsilon)
    , newtonMaxIter = 50
    }
-- | One step of the safeguarded Newton-Raphson search, as produced by
--   'newtonRaphsonIterations'.
data NewtonStep
  = NewtonStep !Double !Double
    -- ^ Newton update: the current point and the new iterate.
  | NewtonBisection !Double !Double
    -- ^ Bracket endpoints used for a fallback bisection; the next point
    --   is their midpoint.
  | NewtonRoot !Double
    -- ^ Point where the function evaluated to exactly zero.
  | NewtonNoBracket
    -- ^ The function has the same sign at both initial endpoints.
  deriving ( Eq, Read, Show, Typeable, Data
#if __GLASGOW_HASKELL__ > 704
           , Generic
#endif
           )
-- | Every field is a strict 'Double', so weak head normal form is
--   already normal form.
instance NFData NewtonStep where
  rnf = (`seq` ())
-- | Exact roots and missing brackets settle the search immediately.
--   A Newton step settles it when successive iterates agree within
--   tolerance, reporting the newer one. A bisection step settles it when
--   its endpoints agree, reporting their midpoint.
instance IterationStep NewtonStep where
  matchRoot _ (NewtonRoot x) = Just (Root x)
  matchRoot _ NewtonNoBracket = Just NotBracketed
  matchRoot tol (NewtonStep x x')
    | withinTolerance tol x x' = Just (Root x')
  matchRoot tol (NewtonBisection a b)
    | withinTolerance tol a b = Just (Root ((a + b) / 2))
  -- Reached only when one of the tolerance guards above failed.
  matchRoot _ _ = Nothing
  {-# INLINE matchRoot #-}
-- | Find a root with the Newton-Raphson method. It falls back to
--   bisection when the derivative is zero or a Newton step would leave
--   the current bracket.
newtonRaphson
  :: NewtonParam                  -- ^ Iteration limit and tolerance
  -> (Double,Double,Double)       -- ^ @(lo, guess, hi)@: bracket and initial guess
  -> (Double -> (Double,Double))  -- ^ Function returning its value and derivative
  -> Root Double
newtonRaphson p guess fun = findRoot maxIter tol steps
  where
    maxIter = newtonMaxIter p
    tol     = newtonTol p
    steps   = newtonRaphsonIterations guess fun
-- | Lazily generate the steps of a Newton-Raphson search safeguarded by
--   bisection.
--
--   If the initial guess lies outside the bracket, the bracket midpoint
--   is used instead. The list ends with 'NewtonRoot' when the function
--   hits exactly zero, or is the singleton 'NewtonNoBracket' when the
--   function has the same sign at both endpoints. Otherwise it is
--   infinite and must be cut off, e.g. by 'findRoot'.
newtonRaphsonIterations
  :: (Double,Double,Double)       -- ^ @(lo, guess, hi)@: bracket and initial guess
  -> (Double -> (Double,Double))  -- ^ Function returning its value and derivative
  -> [NewtonStep]
newtonRaphsonIterations (lo,guess,hi) function
  | flo == 0    = [NewtonRoot lo]
  | fhi == 0    = [NewtonRoot hi]
  | flo*fhi > 0 = [NewtonNoBracket]
  -- 'go' expects the function to be negative at its first argument and
  -- positive at its last, so orient the bracket. Both branches must
  -- start from the clamped guess'. Otherwise an out-of-bracket guess
  -- widens the bracket beyond [lo, hi].
  | flo > 0     = go hi guess' lo
  | otherwise   = go lo guess' hi
  where
    (flo,_) = function lo
    (fhi,_) = function hi
    -- Initial guess, replaced by the midpoint when outside the bracket
    guess'
      | guess >= lo && guess <= hi = guess
      | guess >= hi && guess <= lo = guess
      | otherwise                  = (lo + hi) / 2
    -- go xA x xB: current point x inside a bracket whose function value
    -- is negative at xA and positive at xB.
    go xA x xB
      | f  == 0   = [NewtonRoot x]
      | f' == 0   = bisectionStep
      -- Accept the Newton iterate only if it is strictly inside the bracket.
      | (x' - xA) * (x' - xB) < 0 = newtonStep
      | otherwise = bisectionStep
      where
        (f,f') = function x
        x' = x - f / f'
        -- x replaces the endpoint whose function value has the same sign.
        newtonStep
          | f > 0     = NewtonStep x x' : go xA x' x
          | otherwise = NewtonStep x x' : go x  x' xB
        bisectionStep
          | f > 0     = NewtonBisection xA x : go xA ((xA + x) / 2) x
          | otherwise = NewtonBisection x xB : go x ((x + xB) / 2) xB