module Numeric.GSL.ODE (
odeSolve, odeSolveV, odeSolveVWith, ODEMethod(..), Jacobian, StepControl(..)
) where
import Control.Exception (bracket, onException)
import Foreign.Ptr(FunPtr, nullFunPtr, freeHaskellFunPtr)
import Foreign.C.Types
import System.IO.Unsafe(unsafePerformIO)

import Numeric.LinearAlgebra.HMatrix

import Numeric.GSL.Internal
-- Shorthands for the C-callable signatures used with 'ode_c'. 'TV', 'TM' and
-- 'Res' come from "Numeric.GSL.Internal"; presumably each 'TV' adds a vector
-- argument and each 'TM' a matrix argument in front of a 'Res' status
-- result (TODO confirm against that module).
type TVV = TV (TV Res)              -- two vectors: signature of the xdot callback passed to 'ode_c'
type TVM = TV (TM Res)              -- a vector and a matrix: signature of the Jacobian callback
type TVVM = TV (TV (TM Res))        -- NOTE(review): not referenced anywhere in this module
type TVVVM = TV (TV (TV (TM Res)))  -- three vectors and a matrix: trailing arguments of 'ode_c'

-- | Jacobian of the right-hand side: given @t@ and @x@, the @n x n@ matrix of
-- partial derivatives of @xdot@ with respect to @x@. Its size is checked
-- against the state dimension each time it is evaluated.
type Jacobian = Double -> Vector Double -> Matrix Double
-- | Integration method for 'odeSolveV' and 'odeSolveVWith'.
--
-- 'odeSolveVWith' turns each constructor into the integer code shown below,
-- which is handed to the C driver. Constructors carrying a 'Jacobian' are
-- the ones for which a Jacobian callback is registered.
--
-- NOTE(review): the descriptions follow the GSL manual's @gsl_odeiv2@ step
-- types and assume the C driver maps codes 0..10 to them in this order;
-- confirm against the C source.
data ODEMethod
    = RK2
      -- ^ code 0: embedded Runge-Kutta (2, 3)
    | RK4
      -- ^ code 1: classical 4th-order Runge-Kutta
    | RKf45
      -- ^ code 2: embedded Runge-Kutta-Fehlberg (4, 5); the method used by 'odeSolve'
    | RKck
      -- ^ code 3: embedded Runge-Kutta Cash-Karp (4, 5)
    | RK8pd
      -- ^ code 4: embedded Runge-Kutta Prince-Dormand (8, 9)
    | RK2imp Jacobian
      -- ^ code 5: implicit 2nd-order Runge-Kutta at Gaussian points
    | RK4imp Jacobian
      -- ^ code 6: implicit 4th-order Runge-Kutta at Gaussian points
    | BSimp Jacobian
      -- ^ code 7: implicit Bulirsch-Stoer (Bader and Deuflhard)
    | RK1imp Jacobian
      -- ^ code 8: implicit 1st-order Runge-Kutta (backward Euler)
    | MSAdams
      -- ^ code 9: variable-coefficient linear multistep Adams method
    | MSBDF Jacobian
      -- ^ code 10: variable-coefficient linear multistep backward differentiation formula
-- | Step-size control for 'odeSolveVWith'.
--
-- 'odeSolveVWith' turns each constructor into a control code, an absolute
-- tolerance, a relative tolerance, two weights @aX@ and @aX'@, and an
-- optional per-component scale vector.
--
-- NOTE(review): presumably these feed GSL's
-- @gsl_odeiv2_control_standard_new@ / @gsl_odeiv2_control_scaled_new@, where
-- the error allowed in component @i@ is
-- @epsAbs * s_i + epsRel * (aX * |x_i| + aX' * h * |x'_i|)@ (with @s_i = 1@
-- unless a scale vector is given); confirm against the C driver.
data StepControl
    = X Double Double
      -- ^ absolute and relative tolerance, applied to @x(t)@ only (@aX = 1@, @aX' = 0@)
    | X' Double Double
      -- ^ absolute and relative tolerance, applied to @x'(t)@ only (@aX = 0@, @aX' = 1@)
    | XX' Double Double Double Double
      -- ^ absolute tolerance, relative tolerance, and explicit weights @aX@ and @aX'@
    | ScXX' Double Double Double Double (Vector Double)
      -- ^ as 'XX'', plus a per-component scale vector (selects control code 1);
      -- the vector must have one entry per state component
-- | Solves a system of ordinary differential equations @x'(t) = xdot(t, x)@
-- given in plain list form, using default settings.
--
-- This is 'odeSolveV' with method 'RKf45', absolute and relative tolerances
-- of @1.49012e-08@, and an initial step size of one hundredth of the first
-- requested time interval. The result has one row per entry of the time
-- vector and one column per state component.
--
-- The time vector must hold at least two strictly increasing values;
-- otherwise the call fails with @"odeSolve requires increasing times"@.
odeSolve
    :: (Double -> [Double] -> [Double])  -- ^ @xdot(t,x)@
    -> [Double]                          -- ^ initial conditions
    -> Vector Double                     -- ^ desired solution times
    -> Matrix Double                     -- ^ solution
odeSolve xdot xi ts = odeSolveV RKf45 hi epsAbs epsRel (l2v xdot) (fromList xi) ts
  where
    -- Initial step: 1/100 of the first output interval (t1 - t0).
    hi = (ts!1 - ts!0)/100
    epsAbs = 1.49012e-08
    epsRel = epsAbs
    -- Adapt the list-based right-hand side to the vector interface.
    l2v f = \t -> fromList . f t . toList
-- | Solves a vector ODE system with a chosen 'ODEMethod' and the standard
-- step control.
--
-- Equivalent to 'odeSolveVWith' with @'XX'' epsAbs epsRel 1 1@, so both the
-- state and its derivative enter the error estimate with unit weight.
odeSolveV
    :: ODEMethod                                   -- ^ integration method
    -> Double                                      -- ^ initial step size
    -> Double                                      -- ^ absolute tolerance
    -> Double                                      -- ^ relative tolerance
    -> (Double -> Vector Double -> Vector Double)  -- ^ @xdot(t,x)@
    -> Vector Double                               -- ^ initial conditions
    -> Vector Double                               -- ^ desired solution times
    -> Matrix Double                               -- ^ solution
odeSolveV meth hi epsAbs epsRel f xi ts =
    odeSolveVWith meth stepControl hi f xi ts
  where
    stepControl = XX' epsAbs epsRel 1 1
-- | Solves a vector ODE system with an explicit 'StepControl'.
--
-- This only flattens its two high-level arguments into the parameters of
-- 'odeSolveVWith'':
--
-- * the method becomes an integer code (0 for 'RK2' up to 10 for 'MSBDF')
--   plus the Jacobian, if the method carries one;
-- * the control becomes a control code (1 for 'ScXX'', 0 otherwise), the
--   two tolerances, the weights @aX@ and @aX'@, and an optional scale vector.
--
-- The initial step size, @xdot@, initial conditions and solution times are
-- passed straight through.
odeSolveVWith
    :: ODEMethod                                   -- ^ integration method
    -> StepControl                                 -- ^ step-size control
    -> Double                                      -- ^ initial step size
    -> (Double -> Vector Double -> Vector Double)  -- ^ @xdot(t,x)@
    -> Vector Double                               -- ^ initial conditions
    -> Vector Double                               -- ^ desired solution times
    -> Matrix Double                               -- ^ solution
odeSolveVWith method control =
    odeSolveVWith' code mbJac ctrlCode absTol relTol wX wX' mbScale
  where
    (code, mbJac) = stepperOf method
    (ctrlCode, absTol, relTol, wX, wX', mbScale) = controlOf control

    -- Integer method code understood by the C driver, plus the Jacobian
    -- for the methods that take one.
    stepperOf :: ODEMethod -> (CInt, Maybe Jacobian)
    stepperOf RK2          = (0, Nothing)
    stepperOf RK4          = (1, Nothing)
    stepperOf RKf45        = (2, Nothing)
    stepperOf RKck         = (3, Nothing)
    stepperOf RK8pd        = (4, Nothing)
    stepperOf (RK2imp jac) = (5, Just jac)
    stepperOf (RK4imp jac) = (6, Just jac)
    stepperOf (BSimp jac)  = (7, Just jac)
    stepperOf (RK1imp jac) = (8, Just jac)
    stepperOf MSAdams      = (9, Nothing)
    stepperOf (MSBDF jac)  = (10, Just jac)

    -- (control code, epsAbs, epsRel, aX, aX', optional scale vector)
    controlOf :: StepControl -> (CInt, Double, Double, Double, Double, Maybe (Vector Double))
    controlOf (X ea er)               = (0, ea, er, 1, 0, Nothing)
    controlOf (X' ea er)              = (0, ea, er, 0, 1, Nothing)
    controlOf (XX' ea er ax ax')      = (0, ea, er, ax, ax', Nothing)
    controlOf (ScXX' ea er ax ax' sc) = (1, ea, er, ax, ax', Just sc)
-- | Low-level driver shared by the public solvers.
--
-- It registers the Haskell callbacks with the C side and runs 'ode_c'. It
-- returns a matrix with @size ts@ rows and @size xiv@ columns.
--
-- The 'FunPtr's created for @f@ and for the optional Jacobian are managed
-- with 'bracket', so they are freed even when an error is raised after they
-- are allocated. Examples are a scale vector of the wrong length,
-- non-increasing times (via 'checkTimes'), or a failure reported by
-- 'createMIO' for the C call.
--
-- NOTE(review): an error raised /inside/ a callback (e.g. a dimension
-- mismatch in the result of @f@) has to unwind through the C solver. Confirm
-- how the RTS handles that; 'bracket' cannot intercept it there.
--
-- 'unsafePerformIO' is used because the IO is confined to allocating and
-- freeing those callbacks and filling a freshly created result matrix. The
-- result depends only on the arguments.
odeSolveVWith'
    :: CInt                                              -- ^ method code (see 'odeSolveVWith')
    -> Maybe (Double -> Vector Double -> Matrix Double)  -- ^ Jacobian, if the method needs one
    -> CInt                                              -- ^ control code: 1 when a scale vector is used, 0 otherwise
    -> Double                                            -- ^ absolute tolerance
    -> Double                                            -- ^ relative tolerance
    -> Double                                            -- ^ weight @aX@
    -> Double                                            -- ^ weight @aX'@
    -> Maybe (Vector Double)                             -- ^ optional per-component scale vector
    -> Double                                            -- ^ initial step size
    -> (Double -> Vector Double -> Vector Double)        -- ^ @xdot(t,x)@
    -> Vector Double                                     -- ^ initial conditions
    -> Vector Double                                     -- ^ desired solution times
    -> Matrix Double                                     -- ^ solution
odeSolveVWith' method mbjac control epsAbs epsRel aX aX' mbsc h f xiv ts =
    unsafePerformIO $ bracket acquire release $ \(fp, jp) ->
        vec sc $ \sc' -> vec xiv $ \xiv' ->
        vec (checkTimes ts) $ \ts' ->
        createMIO (size ts) n
            (ode_c method control h epsAbs epsRel aX aX' fp jp
                // sc' // xiv' // ts' )
            "ode"
  where
    n = size xiv
    -- Without a scale vector, xiv fills that argument slot; presumably the
    -- C side only reads it when control == 1 (TODO confirm).
    sc = case mbsc of
        Just scv -> checkdim1 n scv
        Nothing  -> xiv
    -- Create both callbacks. If creating the Jacobian callback fails, the
    -- already-created xdot callback is released before rethrowing.
    acquire = do
        fp <- mkDoubleVecVecfun (\t -> aux_vTov (checkdim1 n . f t))
        jp <- mkJacobian `onException` freeHaskellFunPtr fp
        return (fp, jp)
    mkJacobian = case mbjac of
        Just jac -> mkDoubleVecMatfun (\t -> aux_vTom (checkdim2 n . jac t))
        Nothing  -> return nullFunPtr
    -- The Jacobian slot holds nullFunPtr when no Jacobian was supplied, and
    -- that must not be freed.
    release (fp, jp) = do
        freeHaskellFunPtr fp
        if jp /= nullFunPtr then freeHaskellFunPtr jp else pure ()
-- | Binding to the C solver entry point @ode@.
--
-- Argument order, as supplied by 'odeSolveVWith'':
--
-- 1. method code
-- 2. control code
-- 3. initial step size
-- 4. absolute tolerance
-- 5. relative tolerance
-- 6. weight @aX@
-- 7. weight @aX'@
-- 8. the xdot callback
-- 9. the Jacobian callback (or 'nullFunPtr')
--
-- These are followed ('TVVVM') by the scale vector, the initial conditions
-- and the time vector, then the output matrix. Those are passed through
-- hmatrix's @//@ / 'createMIO' plumbing. The call must be @safe@ rather
-- than @unsafe@ because the C code re-enters Haskell through the FunPtr
-- callbacks.
foreign import ccall safe "ode"
    ode_c :: CInt -> CInt -> Double
          -> Double -> Double -> Double -> Double
          -> FunPtr (Double -> TVV) -> FunPtr (Double -> TVM) -> TVVVM
-- | Returns the vector unchanged when it has exactly @n@ components, and
-- raises an error otherwise.
--
-- Used on each result of the user's @xdot@ callback and on the optional
-- 'ScXX'' scale vector. NOTE(review): the message only mentions the former
-- case and is kept verbatim for compatibility.
checkdim1 :: Int -> Vector Double -> Vector Double
checkdim1 n v
    | size v == n = v
    | otherwise   = error $ "Error: "++ show n
                        ++ " components expected in the result of the function supplied to odeSolve"
-- | Returns the matrix unchanged when it is @n x n@, and raises an error
-- otherwise. Applied to each result of the user-supplied Jacobian.
checkdim2 :: Int -> Matrix Double -> Matrix Double
checkdim2 n m
    | rows m == n && cols m == n = m
    | otherwise                  = error $ "Error: "++ show n ++ "x" ++ show n
                                       ++ " Jacobian expected in odeSolve"
-- | Returns the time vector unchanged when it holds at least two entries in
-- strictly increasing order, and raises an error otherwise.
--
-- Consecutive pairs are compared with '(<)' over @drop 1@ rather than the
-- partial 'tail'. For IEEE doubles, @a < b@ holds exactly when @b - a > 0@,
-- so this accepts the same inputs as a positive-difference check, NaNs
-- included.
checkTimes :: Vector Double -> Vector Double
checkTimes ts
    | size ts > 1 && and (zipWith (<) ts' (drop 1 ts')) = ts
    | otherwise = error "odeSolve requires increasing times"
  where
    ts' = toList ts