[Blogtober #6] Uncertainty in ML
[I was spending way too long on this post so I’m releasing it in bullet point form. Maybe I’ll pass it through an LLM ghostwriter.]
Intro
- I was recently asked in an interview how you might get a confidence score from a neural network
- It was very much a blind spot for me, despite having studied Bayesian inference and Gaussian processes at uni
- I want to turn a weakness into a strength, so this post will be a refresher on some topics
- There’s a lot of literature on this so I’m just going to cover some basic methods for regression, and maybe revisit in the future
- Even the stuff I do mention here can be explored in much more depth - I’ll leave links
Contents
- We’re going to focus on the simple task of predicting a noisy tanh function y = tanh(x) + ε, ε \sim N(0, 0.01) (i.e. 1% noise variance) - data-generation snippet below
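A minimal data-generation sketch - the input range [-3, 3] and the sample size are my own choices, not from anywhere official:

```python
import numpy as np

rng = np.random.default_rng(0)

def make_data(n=100):
    # y = tanh(x) + eps, eps ~ N(0, 0.01); x drawn uniformly from an assumed range
    x = rng.uniform(-3.0, 3.0, size=n)
    y = np.tanh(x) + rng.normal(0.0, np.sqrt(0.01), size=n)
    return x, y

x_train, y_train = make_data()
```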
- NN point estimate - the baseline: a standard network trained with MSE gives a single prediction and no notion of confidence (sketch below)
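The baseline as a minimal PyTorch sketch (architecture, learning rate and step count are illustrative guesses). Note it returns one number per input and says nothing about confidence:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same noisy tanh setup as above
x = torch.rand(100, 1) * 6 - 3            # uniform on [-3, 3]
y = torch.tanh(x) + 0.1 * torch.randn_like(x)

net = nn.Sequential(nn.Linear(1, 64), nn.Tanh(), nn.Linear(64, 1))
opt = torch.optim.Adam(net.parameters(), lr=1e-2)
for _ in range(500):
    opt.zero_grad()
    nn.functional.mse_loss(net(x), y).backward()
    opt.step()

y_hat = net(torch.tensor([[0.5]])).item()  # a single point estimate
```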
- BNNs (https://arxiv.org/pdf/2211.11865)
- Put a prior on the NN weights
- Assume the data is normally distributed around the network’s output (the mean function) to get a likelihood
- This gives a posterior, but how do we sample from it?
- MCMC (e.g. Metropolis-Hastings) draws proposals from a transition distribution and accepts or rejects them, constructing a Markov chain which converges to the posterior
- This can be applied to the NN to sample weight vectors θ, which we use to produce samples from the posterior predictive distribution (rough sketch below)
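A rough random-walk Metropolis sketch over the weights of a tiny one-hidden-layer net. The network size, prior scale, noise level and proposal step are all my assumptions - a serious implementation would use something like HMC/NUTS:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, 100)
y = np.tanh(x) + rng.normal(0, 0.1, 100)

H = 8  # hidden units; theta packs [W1 (H), b1 (H), W2 (H), b2 (1)]

def forward(theta, xs):
    W1, b1, W2, b2 = theta[:H], theta[H:2*H], theta[2*H:3*H], theta[-1]
    return np.tanh(np.outer(xs, W1) + b1) @ W2 + b2

def log_post(theta):
    log_prior = -0.5 * np.sum(theta**2)               # N(0, 1) prior on weights
    resid = y - forward(theta, x)
    log_lik = -0.5 * np.sum(resid**2) / 0.1**2        # Gaussian likelihood, sigma = 0.1
    return log_prior + log_lik

theta = rng.normal(0, 0.1, 3*H + 1)
samples, lp = [], log_post(theta)
for step in range(20000):
    prop = theta + rng.normal(0, 0.02, theta.shape)   # random-walk proposal
    lp_prop = log_post(prop)
    if np.log(rng.uniform()) < lp_prop - lp:          # Metropolis accept/reject
        theta, lp = prop, lp_prop
    if step % 100 == 0:
        samples.append(theta.copy())                  # thinned chain

# Posterior predictive at a test point: push each sampled theta through the net
preds = np.array([forward(t, np.array([0.5])) for t in samples])
mean, std = preds.mean(), preds.std()  # std here is the epistemic part
```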
- MC Dropout (Gal & Ghahramani https://arxiv.org/pdf/1506.02142, https://www.cs.ox.ac.uk/people/yarin.gal/website/blog_3d801aa532c1ce.html)
- This is roughly a special case of a BNN where each weight row is either zeroed or left unaltered, following a Bernoulli distribution
- The paper shows that training with dropout approximately performs variational inference in a (deep) Gaussian process
- One advantage is that we train the model as normal using gradient descent and dropout, so it’s far more computationally tractable than MCMC. At test time we keep dropout switched on and average multiple stochastic forward passes, so we get the uncertainty essentially for free (sketch below)
- However it’s not clear to me what the output distribution actually represents - need to dig into the maths more
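A minimal MC Dropout sketch (dropout rate and architecture are my choices): train with dropout as usual, then keep dropout active at test time and average many stochastic forward passes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(100, 1) * 6 - 3
y = torch.tanh(x) + 0.1 * torch.randn_like(x)

net = nn.Sequential(
    nn.Linear(1, 64), nn.ReLU(), nn.Dropout(p=0.1),
    nn.Linear(64, 64), nn.ReLU(), nn.Dropout(p=0.1),
    nn.Linear(64, 1),
)
opt = torch.optim.Adam(net.parameters(), lr=1e-2)
for _ in range(1000):
    opt.zero_grad()
    nn.functional.mse_loss(net(x), y).backward()
    opt.step()

net.train()  # keep dropout active for prediction - this is the whole trick
with torch.no_grad():
    preds = torch.stack([net(torch.tensor([[0.5]])) for _ in range(200)])
mean, std = preds.mean(), preds.std()  # spread across passes = uncertainty
```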
- For comparison, here’s the GP with a quadratic kernel (GPs are very cool, very Cambridge) - sketch below
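A hand-rolled GP regression sketch for the comparison. I’m reading “quadratic kernel” as the rational quadratic kernel (a degree-2 polynomial kernel would also fit the name), and the hyperparameters are guesses:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, 100)
y = np.tanh(x) + rng.normal(0, 0.1, 100)

def rq_kernel(a, b, alpha=1.0, ell=1.0):
    # Rational quadratic: (1 + d^2 / (2 * alpha * ell^2))^(-alpha)
    sq = (a[:, None] - b[None, :]) ** 2
    return (1 + sq / (2 * alpha * ell**2)) ** (-alpha)

noise = 0.01  # assumed known noise variance sigma^2
K = rq_kernel(x, x) + noise * np.eye(len(x))

x_star = np.linspace(-3, 3, 50)
K_star = rq_kernel(x_star, x)
mean = K_star @ np.linalg.solve(K, y)                                    # posterior mean
cov = rq_kernel(x_star, x_star) - K_star @ np.linalg.solve(K, K_star.T)  # posterior cov
std = np.sqrt(np.diag(cov) + noise)                                      # predictive std
```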
Uses
- Optimisation and signal combination - e.g. in quant: what’s your delta?
- Interpretability
- Safety - a model that is confidently incorrect is worse than one that knows it doesn’t know
- Applications in science, e.g. AlphaFold reporting per-residue confidence scores
- Robustness against adversarial or out-of-sample (OOS) data
Theory
- Epistemic uncertainty (reducible - uncertainty about the model/weights) vs aleatoric uncertainty (irreducible noise in the data) - see the decomposition below
- Maffs maffs maffs
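One standard piece of the maths: with a Gaussian likelihood y = f_θ(x) + ε, ε \sim N(0, σ²), the posterior predictive variance splits (law of total variance) into an aleatoric and an epistemic term:

$$
\operatorname{Var}[y^* \mid x^*, \mathcal{D}]
= \underbrace{\mathbb{E}_{p(\theta \mid \mathcal{D})}\left[\sigma^2\right]}_{\text{aleatoric}}
+ \underbrace{\operatorname{Var}_{p(\theta \mid \mathcal{D})}\left[f_\theta(x^*)\right]}_{\text{epistemic}}
$$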
Further reading
- Variational inference, PCA, VAE (https://arxiv.org/pdf/2101.00734)
- Calibration (Guo https://arxiv.org/pdf/1706.04599, Wang https://arxiv.org/pdf/2308.01222)
Code