module Math.Optimization.SPSA.Optimize (
runSPSA
) where
import Control.Monad.State (runState)
import Numeric.LinearAlgebra (Vector,scale,scaleRecip)
import Math.Optimization.SPSA.Types (
StateSPSA, defaultSPSA, checkSPSA,
getLoss, getConstraint,
peelAll,
getStop, shouldStop,
getIterations, incrementIteration
)
-- | Run the SPSA optimizer and return the final parameter estimate.
--
-- The supplied configuration action is executed against 'defaultSPSA'
-- (its result value is ignored). The configuration is then validated against
-- the initial estimate with 'checkSPSA', and the main loop is started from
-- that estimate. Only the final estimate is returned; the final optimizer
-- state is discarded.
runSPSA :: StateSPSA a -> Vector Double -> Vector Double
runSPSA configure start = estimate
  where
    (estimate, _finalState) = runState pipeline defaultSPSA
    pipeline = do
      _ <- configure
      _ <- checkSPSA start
      runSPSA' start
-- | Main optimization loop: repeatedly take an SPSA step until a stopping
-- criterion fires, then return the newest estimate.
--
-- Each round does three things in order:
--
-- 1. Compute the next estimate with 'singleIteration'.
-- 2. Ask 'checkStop' whether to halt, given the previous and new estimates.
-- 3. Increment the iteration counter.
--
-- Because the counter is incremented after the stop check, the criteria see
-- the count as it was before this round's increment.
runSPSA' :: Vector Double -> StateSPSA (Vector Double)
runSPSA' current =
  singleIteration current >>= \next ->
    checkStop current next >>= \finished ->
      incrementIteration >>
        if finished then return next else runSPSA' next
-- | Perform one SPSA step from the current estimate @t@.
--
-- 1. Sample the loss at two symmetric points, @t + c*d@ and @t - c*d@.
-- 2. Approximate the gradient element-wise as
--    @(y_plus - y_minus) / (2 * c * d_i)@.
-- 3. Move against that gradient by the step size @a@.
-- 4. Pass the result through the configured constraint function.
singleIteration :: Vector Double -> StateSPSA (Vector Double)
singleIteration t = do
  -- NOTE(review): a and c are presumably the current gain-sequence values
  -- and d the random perturbation vector for this step; peelAll is defined
  -- elsewhere, so confirm whether it also advances those sequences.
  (a, c, d) <- peelAll
  lossF <- getLoss
  constrainF <- getConstraint
  let cd = c `scale` d
      ya = lossF (t + cd) -- loss at the "plus" perturbation
      yb = lossF (t - cd) -- loss at the "minus" perturbation
      -- scaleRecip x v divides the scalar x by each element of v, giving
      -- the element-wise SPSA gradient estimate (ya - yb) / (2 * c * d_i).
      grad = ((ya - yb) / 2) `scaleRecip` cd
  -- Gradient-descent step: move *against* the estimated gradient.
  return $ constrainF (t - (a `scale` grad))
-- | Decide whether the optimizer should halt after a step from @t@ to @t'@.
--
-- The function reads the configured stopping criteria and the current
-- iteration count, in that order. It returns 'True' as soon as any criterion
-- is satisfied according to 'shouldStop'.
--
-- NOTE(review): an empty criteria collection makes 'any' return 'False', so
-- the main loop would never terminate. Confirm that 'checkSPSA' or the
-- defaults guarantee at least one criterion.
checkStop :: Vector Double -> Vector Double -> StateSPSA Bool
checkStop t t' = anyTriggered <$> getStop <*> getIterations
  where
    anyTriggered criteria iteration =
      any (\criterion -> shouldStop criterion iteration t t') criteria