{- |
Example of an HMM with continuous (Gaussian) emissions.
We train a model to recognize sine waves of a certain frequency.

There are four hidden states, one for each quarter of the sine period:
'Rising', 'High', 'Falling', 'Low'.
-}
module Math.HiddenMarkovModel.Example.SineWave
{-# WARNING "do not import that module, it is only intended for demonstration" #-}
   where

import qualified Math.HiddenMarkovModel as HMM
import qualified Math.HiddenMarkovModel.Distribution as Distr
import Math.HiddenMarkovModel.Utility
         (normalizeProb, squareFromLists, hermitianFromList, singleton)

import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Vector (Vector)

import qualified Data.Array.Comfort.Boxed as Array
import qualified Data.Array.Comfort.Shape as Shape

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.NonEmpty as NonEmpty
import Data.Function.HT (nest)
import Data.Tuple.HT (mapSnd)



-- | Hidden states of the model, one per quarter of a sine period.
--
-- 'Show' is derived so that the demonstration results (e.g. 'revealed')
-- can actually be printed, for instance in GHCi.
data State = Rising | High | Falling | Low
   deriving (Eq, Ord, Enum, Bounded, Show)

-- | Array/vector shape whose indices are the constructors of 'State'.
type StateSet = Shape.Enumeration State

-- | The value of the 'StateSet' shape, used to size all per-state
-- vectors and matrices of the model.
stateSet :: StateSet
stateSet = Shape.Enumeration


-- | Model type of this example: Gaussian emissions, hidden states indexed
-- by 'StateSet', 'Double' numbers.
-- NOTE(review): the emission shape is @()@; presumably each observation is
-- a one-element vector, consistent with the use of 'singleton' in this
-- module — confirm against "Math.HiddenMarkovModel".
type HMM = HMM.Gaussian () StateSet Double

-- | Hand-made starting model, used as the initial guess for unsupervised
-- training.
--
-- * Initial distribution: uniform over all states.
--
-- * Transitions: 0.9 on the diagonal, 0.1 just below the diagonal and
--   0.1 in the top-right corner. Assuming the library's convention that
--   column @j@ holds the probabilities of leaving state @j@ (NOTE(review):
--   confirm), every state stays with probability 0.9 and advances with 0.1
--   along the cycle 'Rising' -> 'High' -> 'Falling' -> 'Low' -> 'Rising'.
--
-- * Emissions: one Gaussian per state, covariance 1, with mean
--   0 ('Rising'), 1 ('High'), 0 ('Falling') and -1 ('Low').
hmm :: HMM
hmm =
   HMM.Cons {
      HMM.initial = uniform,
      HMM.transition = squareFromLists stateSet transitionRows,
      HMM.distribution = Distr.gaussian $ Array.fromList stateSet emissions
   }
   where
      uniform = normalizeProb $ Vector.constant stateSet 1
      stay = 0.9
      move = 0.1
      transitionRows =
         [stateVector stay 0    0    move,
          stateVector move stay 0    0,
          stateVector 0    move stay 0,
          stateVector 0    0    move stay]
      unitCovariance = hermitianFromList () [1]
      -- means in constructor order: Rising, High, Falling, Low
      emissions =
         [(singleton mean, unitCovariance) | mean <- [0, 1, 0, -1]]

-- | Build a per-state vector from its four components, given in the order
-- of the 'State' constructors: 'Rising', 'High', 'Falling', 'Low'.
stateVector :: Double -> Double -> Double -> Double -> Vector StateSet Double
stateVector rising high falling low =
   Vector.fromList stateSet [rising, high, falling, low]

-- | 201 samples of @sin x@ for @x = 0, 0.1, 0.2, ...@ (generated by
-- repeatedly adding 0.1), each paired with the phase it belongs to.
--
-- The label is @floor (2x/pi + 1/2) mod 4@ mapped through 'toEnum'. Each
-- label therefore covers a quarter period centred on a multiple of @pi/2@:
-- 'Rising' around 0, 'High' around @pi/2@, 'Falling' around @pi@ and
-- 'Low' around @3pi/2@ (all modulo @2pi@).
sineWaveLabeled :: NonEmpty.T [] (State, Double)
sineWaveLabeled =
   fmap label $ NonEmpty.mapTail (take 200) $ NonEmptyC.iterate (0.1+) 0
   where
      label x = (phase x, sin x)
      phase x = toEnum (floor (x*2/pi + 0.5) `mod` 4)

-- | The observations of 'sineWaveLabeled' with the state labels removed.
sineWave :: NonEmpty.T [] Double
sineWave = snd <$> sineWaveLabeled

-- | State sequence that 'HMM.reveal' assigns to 'sineWave' under the
-- supervised model 'hmmTrainedSupervised'. According to the library
-- documentation, this is the most likely state path.
revealed :: NonEmpty.T [] State
revealed = HMM.reveal hmmTrainedSupervised (singleton <$> sineWave)

-- | Model estimated from the labeled samples of 'sineWaveLabeled', where
-- the hidden state of every observation is given.
hmmTrainedSupervised :: HMM
hmmTrainedSupervised =
   let labeledEmissions = fmap (mapSnd singleton) sineWaveLabeled
   in  HMM.finishTraining (HMM.trainSupervised stateSet labeledEmissions)

-- | Result of a single unsupervised training step that starts from the
-- hand-made 'hmm' and uses only the unlabeled samples of 'sineWave'.
hmmTrainedUnsupervised :: HMM
hmmTrainedUnsupervised =
   HMM.finishTraining (HMM.trainUnsupervised hmm observations)
   where observations = singleton <$> sineWave

-- | Model obtained by applying the unsupervised training step 100 times,
-- starting from 'hmm'. Every round re-estimates the current model from
-- the unlabeled samples of 'sineWave'.
hmmIterativelyTrained :: HMM
hmmIterativelyTrained = nest 100 step hmm
   where
      step :: HMM -> HMM
      step model =
         HMM.finishTraining $ HMM.trainUnsupervised model observations
      observations = singleton <$> sineWave