{-# LANGUAGE CPP #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Root finding using Halley's rational method (the second in
-- the class of Householder methods). Assumes the function is three
-- times continuously differentiable and converges cubically when
-- progress can be made.
--
-----------------------------------------------------------------------------

module Numeric.AD.Rank1.Halley
  (
  -- * Halley's Method (Tower AD)
    findZero
  , findZeroNoEq
  , inverse
  , inverseNoEq
  , fixedPoint
  , fixedPointNoEq
  , extremum
  , extremumNoEq
  ) where

import Prelude hiding (all)
import Numeric.AD.Internal.Forward (Forward)
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Tower (Tower)
import Numeric.AD.Mode
import Numeric.AD.Rank1.Tower (diffs0)
import Numeric.AD.Rank1.Forward (diff)
import Numeric.AD.Internal.Combinators (takeWhileDifferent)

-- $setup
-- >>> import Data.Complex

-- | The 'findZero' function finds a zero of a scalar function using
-- Halley's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.) If the stream becomes constant
-- ("it converges"), no further elements are returned.
--
-- This is 'findZeroNoEq' with its output passed through
-- 'takeWhileDifferent', which is why an 'Eq' instance is required.
--
-- Examples:
--
-- >>> take 10 $ findZero (\x->x^2-4) 1
-- [1.0,1.8571428571428572,1.9997967892704736,1.9999999999994755,2.0]
--
-- >>> last $ take 10 $ findZero ((+1).(^2)) (1 :+ 1)
-- 0.0 :+ 1.0
findZero :: (Fractional a, Eq a) => (Tower a -> Tower a) -> a -> [a]
findZero f = takeWhileDifferent . findZeroNoEq f
{-# INLINE findZero #-}

-- | The 'findZeroNoEq' function behaves the same as 'findZero' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
--
-- The result is an infinite 'iterate' stream starting at the initial guess.
-- Each step reads the value and the first two derivatives of @f@ at the
-- current point (the first three elements of 'diffs0') and applies Halley's
-- update
--
-- > x - 2*y*y' / (2*y'*y' - y*y'')
--
-- The update is unguarded: if the denominator is zero, the next iterate is
-- whatever '/' produces for the element type.
findZeroNoEq :: Fractional a => (Tower a -> Tower a) -> a -> [a]
findZeroNoEq f = iterate go where
  go x = case diffs0 f x of
    (y:y':y'':_) -> x - 2*y*y'/(2*y'*y'-y*y'') -- 9.606671960457536 bits error
    -- Algebraically equivalent forms of the same update:
    --   x - recip (y'/y - y''/(2*y'))
    --   x - y' / (y'*y'/y - y''/2)
    -- NOTE(review): the alternatives previously recorded here,
    -- `recip (y'/y - y''/ y')` and `y' / (y'/y/y' - y''/2)`, each drop a
    -- factor and are not equal to the update above; their "improved error"
    -- figures (6.640625e-2 and 1.4 bits) were measured for those expressions.
    _ -> error "findZeroNoEq: Impossible (diffs0 should produce an infinite list)"
-- Under -DHERBIE, tag this binding "NoHerbie" (presumably so the Herbie
-- floating-point rewriter leaves the update formula untouched).
#ifdef HERBIE
{-# ANN findZeroNoEq "NoHerbie" #-}
#endif
{-# INLINE findZeroNoEq #-}

-- | The 'inverse' function inverts a scalar function using
-- Halley's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.) If the stream becomes constant
-- ("it converges"), no further elements are returned.
--
-- @inverse f x0 y@ searches for an @x@ with @f x == y@, starting from the
-- initial guess @x0@. It is 'inverseNoEq' with its output passed through
-- 'takeWhileDifferent'.
--
-- Note: the @take 10 $ inverse sqrt 1 (sqrt 10)@ example that works for
-- Newton's method fails with Halley's method because Halley's
-- preconditions do not hold there.
inverse :: (Fractional a, Eq a) => (Tower a -> Tower a) -> a -> a -> [a]
inverse f x0 = takeWhileDifferent . inverseNoEq f x0
{-# INLINE inverse  #-}

-- | The 'inverseNoEq' function behaves the same as 'inverse' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
--
-- @inverseNoEq f x0 y@ runs 'findZeroNoEq' from the initial guess @x0@ on
-- @\\x -> f x - auto y@. Here 'auto' lifts the target value @y@ into the
-- 'Tower' mode as a constant, so it contributes no derivative terms.
inverseNoEq :: Fractional a => (Tower a -> Tower a) -> a -> a -> [a]
inverseNoEq f x0 y = findZeroNoEq (\x -> f x - auto y) x0
{-# INLINE inverseNoEq #-}

-- | The 'fixedPoint' function finds a fixed point of a scalar
-- function using Halley's method; its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
--
-- If the stream becomes constant ("it converges"), no further
-- elements are returned. This is 'fixedPointNoEq' with its output passed
-- through 'takeWhileDifferent'.
--
-- >>> last $ take 10 $ fixedPoint cos 1
-- 0.7390851332151607
fixedPoint :: (Fractional a, Eq a) => (Tower a -> Tower a) -> a -> [a]
fixedPoint f = takeWhileDifferent . fixedPointNoEq f
{-# INLINE fixedPoint #-}

-- | The 'fixedPointNoEq' function behaves the same as 'fixedPoint' except that
-- it doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
--
-- A fixed point of @f@ is a zero of @\\x -> f x - x@, so this is just
-- 'findZeroNoEq' applied to that function.
fixedPointNoEq :: Fractional a => (Tower a -> Tower a) -> a -> [a]
fixedPointNoEq f = findZeroNoEq (\x -> f x - x)
{-# INLINE fixedPointNoEq #-}

-- | The 'extremum' function finds an extremum of a scalar
-- function using Halley's method; produces a stream of increasingly
-- accurate results.  (Modulo the usual caveats.) If the stream becomes
-- constant ("it converges"), no further elements are returned.
--
-- This is 'extremumNoEq' with its output passed through
-- 'takeWhileDifferent'.
--
-- >>> take 10 $ extremum cos 1
-- [1.0,0.29616942658570555,4.59979519460002e-3,1.6220740159042513e-8,0.0]
extremum :: (Fractional a, Eq a) => (On (Forward (Tower a)) -> On (Forward (Tower a))) -> a -> [a]
extremum f = takeWhileDifferent . extremumNoEq f
{-# INLINE extremum #-}

-- | The 'extremumNoEq' function behaves the same as 'extremum' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
--
-- The function @f@ is differentiated once in 'Forward' mode (via 'diff').
-- The resulting derivative, computed over 'Tower' values, is then handed to
-- 'findZeroNoEq', so the iterates approach a zero of @f'@. The 'On' / 'off'
-- wrapping lets @f@ be written against the nested @Forward (Tower a)@ mode.
extremumNoEq :: Fractional a => (On (Forward (Tower a)) -> On (Forward (Tower a))) -> a -> [a]
extremumNoEq f = findZeroNoEq (diff (off . f . On))
{-# INLINE extremumNoEq #-}