{-# LANGUAGE TypeFamilies #-}
{- |
Hidden Markov models consisting of an initial state vector,
a transition matrix and an emission distribution.

The model type 'T', the training accumulator 'Trained' and 'mergeTrained'
come from "Math.HiddenMarkovModel.Private";
@logLikelihood@, @reveal@ and @trainUnsupervised@ are re-exported
from "Math.HiddenMarkovModel.Normalized".
-}
module Math.HiddenMarkovModel (
T(..),
Discrete, DiscreteTrained,
Gaussian, GaussianTrained,
uniform,
generate,
generateLabeled,
probabilitySequence,
Normalized.logLikelihood,
Normalized.reveal,
Trained(..),
trainSupervised,
Normalized.trainUnsupervised,
mergeTrained, finishTraining, trainMany,
deviation,
toCSV,
fromCSV,
) where
import qualified Math.HiddenMarkovModel.Distribution as Distr
import qualified Math.HiddenMarkovModel.Normalized as Normalized
import qualified Math.HiddenMarkovModel.CSV as HMMCSV
import Math.HiddenMarkovModel.Private
(T(..), Trained(..), mergeTrained, toCells, parseCSV)
import Math.HiddenMarkovModel.Utility
(SquareMatrix, squareConstant,
randomItemProp, normalizeProb, attachOnes)
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import qualified Numeric.Netlib.Class as Class
import qualified Data.Array.Comfort.Storable as StorableArray
import qualified Data.Array.Comfort.Shape as Shape
import qualified Data.Array.Comfort.Boxed as Array
import qualified Text.CSV.Lazy.String as CSV
import qualified System.Random as Rnd
import qualified Control.Monad.Exception.Synchronous as ME
import qualified Control.Monad.Trans.State as MS
import qualified Control.Monad.HT as Monad
import qualified Data.NonEmpty as NonEmpty
import Data.Traversable (Traversable, mapAccumL)
import Data.Foldable (Foldable)
-- | Training accumulator ('Trained') whose emission part is a
-- 'Distr.DiscreteTrained' distribution over @symbol@.
type DiscreteTrained symbol sh prob =
Trained (Distr.DiscreteTrained symbol sh prob) sh prob
-- | Model ('T') whose emission part is a 'Distr.Discrete' distribution
-- over @symbol@; @sh@ is the state shape and @prob@ the probability type.
type Discrete symbol sh prob = T (Distr.Discrete symbol sh prob) sh prob
-- | Training accumulator ('Trained') whose emission part is a
-- 'Distr.GaussianTrained' distribution; @emiSh@ is the emission shape
-- and @stateSh@ the state shape.
type GaussianTrained emiSh stateSh a =
Trained (Distr.GaussianTrained emiSh stateSh a) stateSh a
-- | Model ('T') whose emission part is a 'Distr.Gaussian' distribution;
-- @emiSh@ is the emission shape and @stateSh@ the state shape.
type Gaussian emiSh stateSh a = T (Distr.Gaussian emiSh stateSh a) stateSh a
-- | Build a model around the given emission distribution
-- in which every entry of the initial vector and of the transition matrix
-- holds the same value, namely the reciprocal of the number of states.
--
-- The state shape is taken from the distribution via 'Distr.statesShape';
-- the distribution itself is stored unchanged.
uniform ::
   (Distr.Info distr, Distr.StateShape distr ~ sh, Shape.C sh,
    Distr.Probability distr ~ prob) =>
   distr -> T distr sh prob
uniform distr =
   Cons {
      initial = Vector.constant states weight,
      transition = squareConstant states weight,
      distribution = distr
   }
   where
      -- shape (and thereby number) of hidden states
      states = Distr.statesShape distr
      -- one over the number of states
      weight = recip (fromIntegral (Shape.size states))
-- | Walk along a sequence of (state, emission) pairs
-- and compute one probability per element.
--
-- For the first element the factor for entering the state
-- is looked up in the initial vector.
-- For every later element it is looked up in the transition matrix
-- at index @(current state, previous state)@.
-- That factor is multiplied with 'Distr.emissionStateProb'
-- of the emission in the current state.
probabilitySequence ::
   (Traversable f, Distr.EmissionProb distr,
    Distr.StateShape distr ~ sh, Shape.Indexed sh, Shape.Index sh ~ state,
    Distr.Probability distr ~ prob, Distr.Emission distr ~ emission) =>
   T distr sh prob -> f (state, emission) -> f prob
probabilitySequence hmm labeled =
   snd (mapAccumL step (initial hmm StorableArray.!) labeled)
   where
      -- The accumulator is a lookup function that yields the probability
      -- of entering a given state at the current position.
      step enter (st, emi) =
         (\next -> transition hmm StorableArray.! (next, st),
          enter st * Distr.emissionStateProb (distribution hmm) emi st)
-- | Generate a random sequence of emissions from the model.
--
-- This is 'generateLabeled' with the state labels dropped.
generate ::
   (Rnd.RandomGen g, Ord prob, Rnd.Random prob, Distr.Generate distr,
    Distr.StateShape distr ~ sh, Shape.Indexed sh, Shape.Index sh ~ state,
    Distr.Probability distr ~ prob, Distr.Emission distr ~ emission) =>
   T distr sh prob -> g -> [emission]
generate hmm gen = map snd (generateLabeled hmm gen)
-- | Generate a random sequence of (state, emission) pairs from the model.
--
-- Two state monads are stacked.
-- The outer one ('MS.StateT') carries the probability vector
-- from which the next state is drawn;
-- it starts with the initial vector of the model.
-- The inner one carries the random generator.
-- The sampling step is repeated with 'Monad.repeat'.
-- NOTE(review): the result is presumably an unbounded lazy list -
-- confirm against the semantics of 'Monad.repeat'.
generateLabeled ::
   (Rnd.RandomGen g, Ord prob, Rnd.Random prob, Distr.Generate distr,
    Distr.StateShape distr ~ sh, Shape.Indexed sh, Shape.Index sh ~ state,
    Distr.Probability distr ~ prob, Distr.Emission distr ~ emission) =>
   T distr sh prob -> g -> [(state, emission)]
generateLabeled hmm =
   MS.evalState (MS.evalStateT (Monad.repeat step) (initial hmm))
   where
      step =
         MS.StateT $ \weights -> do
            -- choose a state with the weights of the current vector
            st <-
               randomItemProp
                  (zip
                     (Shape.indices (StorableArray.shape weights))
                     (Vector.toList weights))
            emi <- Distr.generate (distribution hmm) st
            -- the column of the chosen state supplies
            -- the weights for the next draw
            return ((st, emi), Matrix.takeColumn (transition hmm) st)
-- | Collect training data from a non-empty sequence of emissions
-- that are labeled with their hidden states.
--
-- * 'trainedInitial' is zero everywhere except for a one
--   at the state of the first element.
--
-- * 'trainedTransition' accumulates with @(+)@ over all pairs
--   of adjacent states, starting from a zero matrix,
--   and is transposed afterwards.
--
-- * 'trainedDistribution' groups the emissions by state
--   and passes them to 'Distr.accumulateEmissions'.
--
-- NOTE(review): 'attachOnes' presumably pairs every item with the weight 1 -
-- confirm in "Math.HiddenMarkovModel.Utility".
trainSupervised ::
   (Distr.StateShape distr ~ sh, Shape.Index sh ~ state,
    Distr.Estimate tdistr distr,
    Distr.Probability distr ~ prob, Distr.Emission distr ~ emission) =>
   sh -> NonEmpty.T [] (state, emission) -> Trained tdistr sh prob
trainSupervised sh xs =
   Trained {
      trainedInitial = startCounts,
      trainedTransition = transitionCounts,
      trainedDistribution = emissionCounts
   }
   where
      states = fmap fst xs

      startCounts =
         StorableArray.fromAssociations sh 0 [(NonEmpty.head states, 1)]

      transitionCounts =
         Matrix.transpose
            (StorableArray.accumulate (+) (squareConstant sh 0)
               (attachOnes (NonEmpty.mapAdjacent (,) states)))

      -- one initially empty bucket of emissions per state
      emptyBuckets = Array.fromList sh (replicate (Shape.size sh) [])

      emissionCounts =
         Distr.accumulateEmissions
            (Array.map attachOnes
               (Array.accumulate
                  (\bucket emi -> emi : bucket)
                  emptyBuckets
                  (NonEmpty.flatten xs)))
-- | Turn accumulated training data into a model.
--
-- The initial vector is normalized with 'normalizeProb',
-- the transition matrix column-wise with 'normalizeProbColumns'
-- and the emission distribution with 'Distr.normalize'.
finishTraining ::
   (Shape.C sh, Eq sh,
    Distr.Estimate tdistr distr, Distr.Probability distr ~ prob) =>
   Trained tdistr sh prob -> T distr sh prob
finishTraining trained =
   Cons {
      initial = normalizeProb (trainedInitial trained),
      transition = normalizeProbColumns (trainedTransition trained),
      distribution = Distr.normalize (trainedDistribution trained)
   }
-- | Scale every column of a square matrix
-- by the reciprocal of its column sum.
--
-- There is no special treatment of a column sum of zero:
-- such a column is scaled by @recip 0@.
normalizeProbColumns ::
   (Shape.C sh, Eq sh, Class.Real a) => SquareMatrix sh a -> SquareMatrix sh a
normalizeProbColumns matrix =
   Matrix.scaleColumns inverseSums matrix
   where
      inverseSums = StorableArray.map recip $ Matrix.columnSums matrix
-- | Train on a non-empty collection of training data sets.
--
-- Every data set is converted with the given training function,
-- the results are combined with 'mergeTrained' in a left fold
-- and the merged result is passed to 'finishTraining'.
trainMany ::
   (Shape.C sh, Eq sh,
    Distr.Estimate tdistr distr, Distr.Probability distr ~ prob,
    Foldable f) =>
   (trainingData -> Trained tdistr sh prob) ->
   NonEmpty.T f trainingData -> T distr sh prob
trainMany train samples =
   finishTraining (NonEmpty.foldl1Map mergeTrained train samples)
-- | Measure how much two models differ.
--
-- The result is the larger of two numbers:
-- the deviation of the initial vectors
-- and the deviation of the transition matrices,
-- both computed by 'deviationVec'.
-- The emission distributions are not compared.
deviation ::
   (Shape.InvIndexed sh, Eq sh, Class.Real prob, Ord prob) =>
   T distr sh prob -> T distr sh prob -> prob
deviation hmm0 hmm1 =
   max initialDeviation transitionDeviation
   where
      initialDeviation = deviationVec (initial hmm0) (initial hmm1)
      transitionDeviation = deviationVec (transition hmm0) (transition hmm1)
-- | Deviation between two arrays of the same shape.
--
-- 'Class.switchReal' selects 'deviationVecAux' for the concrete real type;
-- the same definition is supplied for both alternatives.
deviationVec ::
   (Shape.InvIndexed sh, Eq sh, Class.Real a) =>
   StorableArray.Array sh a -> StorableArray.Array sh a -> a
deviationVec xs ys =
   getDeviation (Class.switchReal deviationVecAux deviationVecAux) xs ys
-- | Wrapper around a binary function from two containers @f a@ to a scalar @a@.
-- It lets 'Class.switchReal' in 'deviationVec' choose
-- the implementation by element type.
newtype Deviation f a = Deviation {getDeviation :: f a -> f a -> a}
-- | Implementation of 'deviationVec' for element types
-- that coincide with their own real type.
--
-- It subtracts the arrays element-wise, applies 'Vector.argAbsMaximum'
-- to the difference, drops the index part of the result
-- and returns the absolute value of the remaining element.
deviationVecAux ::
   (Shape.InvIndexed sh, Eq sh, Ord a, Class.Real a, Scalar.RealOf a ~ a) =>
   Deviation (StorableArray.Array sh) a
deviationVecAux =
   Deviation
      (\xs ys ->
         (Scalar.absolute . snd . Vector.argAbsMaximum) (Vector.sub xs ys))
-- | Serialize a model to CSV text.
--
-- The model is converted to a table of cells by 'toCells',
-- the table is passed through 'HMMCSV.padTable' with the empty string
-- and then formatted by the CSV pretty printer.
-- NOTE(review): 'HMMCSV.padTable' presumably fills short rows
-- with the given cell - confirm in "Math.HiddenMarkovModel.CSV".
toCSV ::
   (Distr.ToCSV distr, Shape.Indexed sh, Class.Real prob, Show prob) =>
   T distr sh prob -> String
toCSV =
   CSV.ppCSVTable . snd . CSV.toCSVTable . HMMCSV.padTable "" . toCells
-- | Parse a model from CSV text.
--
-- The text is split into rows by 'CSV.parseCSV',
-- every row is passed through 'HMMCSV.fixShortRow'
-- and the resulting rows are consumed by 'parseCSV',
-- which receives the function that makes a state shape from an 'Int'.
-- The result is wrapped in 'ME.Exceptional' with a 'String' as exception type.
fromCSV ::
   (Distr.FromCSV distr, Distr.StateShape distr ~ stateSh,
    Shape.Indexed stateSh, Shape.Index stateSh ~ state,
    Class.Real prob, Read prob) =>
   (Int -> stateSh) -> String -> ME.Exceptional String (T distr stateSh prob)
fromCSV makeShape text =
   MS.evalStateT (parseCSV makeShape) rows
   where
      rows = map HMMCSV.fixShortRow (CSV.parseCSV text)