Implementing a Hidden Markov Model in Rust
Over the past few months, I’ve implemented hmmm, a Rust library for Hidden Markov Models (HMMs).
HMMs are a well-established statistical machine learning technique for modeling sequences of data. They have been applied to problems like speech recognition and bioinformatics. They are called “hidden” because each discrete time step is associated with a hidden state. Since the hidden state is never observed in the data collection process, it must be inferred from the data that is observed. This is kind of hard. This post won’t explain HMMs in too much detail; if you are interested, the resources I used were:
- Wikipedia: HMMs and Baum-Welch Algorithm
- Machine Learning: a Probabilistic Perspective by Kevin Murphy, 2012 — sections 17.3 and 17.4
- Pattern Recognition and Machine Learning by Christopher Bishop, 2006 — section 13.2
Below, I’ll briefly discuss some challenges that I ran into while implementing this library.
When unit testing an algorithm, it’s always nice to include some empty inputs. Normally, this is easy because your algorithm doesn’t need to do anything! However, empty training data caused me problems because it resulted in divide-by-zero errors. To fix this, I wrote some special-cased code to fall back to uniform priors in the case of missing data.
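A fallback of this kind could look roughly like the sketch below. The helper name `normalize_or_uniform` is hypothetical and is not taken from hmmm; it only illustrates turning raw counts into a distribution without dividing by zero when there is no data.

```rust
/// Converts raw counts into a probability distribution.
/// Hypothetical helper for illustration, not hmmm's actual API.
/// If every count is zero (for example, the training data was empty),
/// it falls back to a uniform distribution instead of dividing by zero.
fn normalize_or_uniform(counts: &[f64]) -> Vec<f64> {
    let total: f64 = counts.iter().sum();
    if total > 0.0 {
        counts.iter().map(|c| c / total).collect()
    } else {
        // No data: treat every outcome as equally likely.
        let n = counts.len() as f64;
        counts.iter().map(|_| 1.0 / n).collect()
    }
}
```

Calling it with all-zero counts for two outcomes gives equal weight to both, while non-zero counts are normalized as usual.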
Testing machine learning algorithms can be difficult. I used two testing strategies: painstakingly hand-computing probabilities and doing inefficient automated sampling.
For the “painstaking hand computation” testing strategy, I constructed a couple of simple HMMs by hand and then hand-calculated all the possible paths through each HMM for a few time steps. Although this took a long time, I found that it helped me to develop my intuitions about HMMs.
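The same path-by-path calculation can be written as a brute-force reference function. This is a sketch under assumptions: the `Hmm` struct layout and the name `brute_force_likelihood` are made up for illustration and are not hmmm's types. It sums the probability of the observations over every possible hidden path, which is exponential in the sequence length and so only practical for a few time steps.

```rust
/// A tiny discrete HMM. Hypothetical layout, for illustration only.
struct Hmm {
    init: Vec<f64>,       // init[i] = P(first state = i)
    trans: Vec<Vec<f64>>, // trans[i][j] = P(next state = j | current state = i)
    emit: Vec<Vec<f64>>,  // emit[i][o] = P(observation = o | state = i)
}

/// Probability of `obs`, computed by summing over every hidden path.
/// Assumes the HMM has at least one state.
fn brute_force_likelihood(hmm: &Hmm, obs: &[usize]) -> f64 {
    let n = hmm.init.len();
    let t_len = obs.len();
    if t_len == 0 {
        return 1.0; // the empty sequence is certain
    }
    let mut total = 0.0;
    let mut path = vec![0usize; t_len];
    loop {
        // Joint probability of this particular path and the observations.
        let mut p = hmm.init[path[0]] * hmm.emit[path[0]][obs[0]];
        for t in 1..t_len {
            p *= hmm.trans[path[t - 1]][path[t]] * hmm.emit[path[t]][obs[t]];
        }
        total += p;

        // Advance to the next path, treating `path` as a base-n counter.
        let mut pos = 0;
        loop {
            if pos == t_len {
                return total; // every path has been visited
            }
            path[pos] += 1;
            if path[pos] < n {
                break;
            }
            path[pos] = 0;
            pos += 1;
        }
    }
}
```

For a two-state HMM and two observations this visits four paths, which is small enough to check against a calculation on paper.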
While painstakingly hand-computing values for HMM::smooth, I decided that it was just too painstaking. So, I wrote a sampling algorithm to compute approximations of the quantities that I was interested in.
HMM::smooth returns the probability that the HMM will be in each state at each time step. To approximate this answer, I just ran the HMM many times and counted how many times the HMM was in each state at each time. Although this doesn’t return the expectation exactly, an approximate answer is still a valuable tool for checking correctness.
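A minimal sketch of the counting idea follows. Everything in it is an assumption made for illustration: the tiny xorshift64* generator exists only to avoid external crates (real code would more likely use the rand crate), and `sampled_marginals` simulates the hidden chain without conditioning on any observations. Matching a particular observation sequence would additionally require sampling emissions and discarding runs that disagree with it, which is omitted here.

```rust
/// Minimal xorshift64* generator, used only to keep this sketch
/// dependency-free. The seed must be non-zero.
struct Rng(u64);

impl Rng {
    /// Returns a pseudo-random f64 in [0, 1).
    fn next_f64(&mut self) -> f64 {
        self.0 ^= self.0 >> 12;
        self.0 ^= self.0 << 25;
        self.0 ^= self.0 >> 27;
        let bits = self.0.wrapping_mul(0x2545F4914F6CDD1D);
        (bits >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Draws an index from the discrete distribution `probs`.
fn sample(probs: &[f64], rng: &mut Rng) -> usize {
    let mut u = rng.next_f64();
    for (i, p) in probs.iter().enumerate() {
        if u < *p {
            return i;
        }
        u -= p;
    }
    probs.len() - 1 // guard against rounding error
}

/// Estimates P(state at time t = i) by simulating the chain `runs` times
/// and counting visits. Returns `estimate[t][i]`.
fn sampled_marginals(
    init: &[f64],
    trans: &[Vec<f64>],
    steps: usize,
    runs: usize,
    rng: &mut Rng,
) -> Vec<Vec<f64>> {
    let mut counts = vec![vec![0.0; init.len()]; steps];
    for _ in 0..runs {
        let mut state = sample(init, rng);
        for row in counts.iter_mut() {
            row[state] += 1.0;
            state = sample(&trans[state], rng);
        }
    }
    // Turn visit counts into relative frequencies.
    for row in counts.iter_mut() {
        for c in row.iter_mut() {
            *c /= runs as f64;
        }
    }
    counts
}
```

In a test, the estimates would be compared against the exact values with a tolerance chosen to be several standard errors wide, so that the check is robust to sampling noise.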
I set up a simple training benchmark with 1,000 observations: 0101010101.... Unfortunately, my implementation couldn’t even handle this relatively short input: some of my learned values were 0, implying that the sequence of observations was impossible. I realized that this was due to problems with numerical precision: the likelihood of a sequence is a product of one probability per time step, and a long enough product of values less than 1 underflows to 0.0. So, storing the likelihood of a sequence of length 1,000 can be impossible using 64-bit floats.
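The underflow is easy to reproduce in isolation. In the sketch below, the per-step probability of 0.1 is an arbitrary value chosen for illustration, and both function names are hypothetical. Multiplying 1,000 such factors gives a true value of 1e-1000, far below the smallest positive f64 (roughly 4.9e-324), so the product collapses to zero, while the logarithm of the same quantity is an ordinary number.

```rust
/// Multiplies `p` by itself `n` times, the way a naive likelihood
/// computation multiplies one factor per time step.
fn naive_product(p: f64, n: usize) -> f64 {
    (0..n).fold(1.0, |acc, _| acc * p)
}

/// The same quantity in the log domain: ln(p^n) = n * ln(p).
fn log_product(p: f64, n: usize) -> f64 {
    n as f64 * p.ln()
}
```

With these definitions, naive_product(0.1, 1000) is exactly 0.0, whereas log_product(0.1, 1000) is a finite negative number.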
In many machine learning applications, this can be solved by working entirely in the log domain. Unfortunately, this does not work well for Baum-Welch inference because probabilities must be added together, which is not a simple operation in the log domain. Bishop’s PRML suggests a solution to this, which is to renormalize the probabilities at each time step to sum to 1. I found that this technique fixed my underflow problems.
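The renormalization idea can be sketched for the forward pass as follows. This is an illustration under assumptions, not hmmm's code: the function name and the matrix layout (`trans[i][j]` for transitions, `emit[i][o]` for emissions) are made up here. The forward vector is rescaled to sum to 1 after every observation, and the logarithms of the scale factors are accumulated, so the log-likelihood is still available even though the raw likelihood would underflow.

```rust
/// Forward pass with per-step renormalization.
/// Returns the log-likelihood of `obs`. Assumes `obs` has non-zero
/// probability under the model; otherwise a scale factor is zero.
fn scaled_forward_log_likelihood(
    init: &[f64],
    trans: &[Vec<f64>],
    emit: &[Vec<f64>],
    obs: &[usize],
) -> f64 {
    let n = init.len();
    let mut alpha: Vec<f64> = init.to_vec();
    let mut log_likelihood = 0.0;
    for (t, &o) in obs.iter().enumerate() {
        // Predict: push the previous normalized vector through the transitions.
        if t > 0 {
            let prev = alpha.clone();
            for j in 0..n {
                let s: f64 = (0..n).map(|i| prev[i] * trans[i][j]).sum();
                alpha[j] = s;
            }
        }
        // Update: weight each state by how well it explains the observation.
        for (a, e) in alpha.iter_mut().zip(emit) {
            *a *= e[o];
        }
        // Renormalize so the vector sums to 1. The scale factor is the
        // probability of this observation given all earlier ones.
        let scale: f64 = alpha.iter().sum();
        for a in alpha.iter_mut() {
            *a /= scale;
        }
        log_likelihood += scale.ln();
    }
    log_likelihood
}
```

Because every intermediate vector sums to 1, none of the stored values shrink with the sequence length, and a 1,000-step input yields a finite log-likelihood.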
- KaTeX embedded in documentation
- Performance benchmarks and optimizations
- Maximum a posteriori parameter inference
- Bayesian inference with blocked Gibbs sampling