[I was spending way too long on this post so I’m releasing it in bullet point form. Maybe I’ll pass it through an LLM ghostwriter.]

Intro

  • I was recently asked in an interview how you might get a confidence score from a neural network
  • It was very much a blind spot for me, despite having studied Bayesian inference and Gaussian processes at uni
  • I want to turn a weakness into a strength, so this post will be a refresher on some topics
  • There’s a lot of literature on this so I’m just going to cover some basic methods for regression, and maybe revisit in the future
  • Even the stuff I do mention here can be explored in much more depth - I’ll leave links

Contents

  • We’re going to focus on the simple task of predicting a noisy tanh function y = tanh(x) + ε, ε \sim N(0, 0.01)
  • NN point estimate (see the first sketch after this list)
  • BNNs, Bayesian neural networks (https://arxiv.org/pdf/2211.11865)
    • Put a prior on the NN weights
    • Assume the data is normally distributed around a mean function (the NN output) to get a likelihood
    • This gives a posterior, but how do we sample from it?
    • MCMC constructs a Markov chain which converges to the posterior by proposing samples from a conditional proposal dist and accepting/rejecting them (Metropolis-Hastings)
    • This can be applied to the NN to sample weight vectors θ, which we use to produce samples from the posterior predictive dist (see the MH sketch after this list)
  • MC Dropout (Gal https://arxiv.org/pdf/1506.02142, https://www.cs.ox.ac.uk/people/yarin.gal/website/blog_3d801aa532c1ce.html)
    • This is in some sense a special case of a BNN where each row of a weight matrix is either zeroed out or left unaltered, according to a Bernoulli dist
    • The Gal paper shows that training with dropout approximates variational inference in a deep Gaussian process
    • One advantage is that we train the model as normal using gradient descent and dropout, so it’s more computationally tractable - we get the uncertainty essentially for free (see the MC Dropout sketch after this list)
    • However it’s not clear to me what the output distribution actually represents - need to dig into the maths more
  • For comparison, here’s the GP with a quadratic kernel (GPs are very cool, very Cambridge) - see the last sketch after this list
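
Here are some minimal sketches of the above on the toy problem (PyTorch/NumPy; all architectures and hyperparameters are arbitrary choices of mine, not tuned). First, the plain NN point estimate - note it gives a single number per input and no uncertainty at all:

```python
import torch
import torch.nn as nn

# Toy data: y = tanh(x) + eps, eps ~ N(0, 0.01), i.e. sigma = 0.1
torch.manual_seed(0)
x = torch.linspace(-3, 3, 200).unsqueeze(1)
y = torch.tanh(x) + 0.1 * torch.randn_like(x)

# Small MLP regressor (layer sizes, lr, steps are arbitrary)
model = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for step in range(2000):
    opt.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    opt.step()

# Point estimate: one prediction per input, no notion of confidence
x_test = torch.linspace(-4, 4, 100).unsqueeze(1)
with torch.no_grad():
    y_hat = model(x_test)
```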
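Next, the BNN idea with a toy random-walk Metropolis-Hastings sampler over the weights of a tiny one-hidden-layer net. This is a sketch of the accept/reject mechanics only, not a serious sampler - in practice you’d reach for HMC/NUTS, and the proposal scale, burn-in and thinning below are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 50)
y = np.tanh(x) + 0.1 * rng.standard_normal(50)

H = 8  # hidden units; theta packs [w1 (H), b1 (H), w2 (H), b2 (1)]

def unpack(theta):
    return theta[:H], theta[H:2*H], theta[2*H:3*H], theta[-1]

def predict(theta, x):
    w1, b1, w2, b2 = unpack(theta)
    return np.tanh(np.outer(x, w1) + b1) @ w2 + b2

def log_post(theta):
    # N(0, 1) prior on each weight + Gaussian likelihood with sigma = 0.1
    log_prior = -0.5 * theta @ theta
    resid = y - predict(theta, x)
    log_lik = -0.5 * resid @ resid / 0.1**2
    return log_prior + log_lik

theta = 0.1 * rng.standard_normal(3 * H + 1)
lp = log_post(theta)
samples = []
for step in range(20000):
    prop = theta + 0.02 * rng.standard_normal(theta.shape)  # random-walk proposal
    lp_prop = log_post(prop)
    if np.log(rng.uniform()) < lp_prop - lp:  # Metropolis accept/reject
        theta, lp = prop, lp_prop
    if step > 10000 and step % 100 == 0:  # thin after burn-in
        samples.append(theta.copy())

# Posterior predictive: push each sampled theta through the net
x_test = np.linspace(-4, 4, 100)
preds = np.stack([predict(t, x_test) for t in samples])
mean, std = preds.mean(0), preds.std(0)
```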
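MC Dropout: train with dropout as usual, then keep dropout switched on at test time and average over stochastic forward passes. A sketch - the dropout rate and number of passes are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.linspace(-3, 3, 200).unsqueeze(1)
y = torch.tanh(x) + 0.1 * torch.randn_like(x)

model = nn.Sequential(
    nn.Linear(1, 64), nn.ReLU(), nn.Dropout(p=0.1),
    nn.Linear(64, 64), nn.ReLU(), nn.Dropout(p=0.1),
    nn.Linear(64, 1),
)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
for step in range(2000):
    opt.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    opt.step()

# model.train() keeps dropout active at test time; each forward pass
# samples a different sub-network, i.e. a different set of weights
model.train()
x_test = torch.linspace(-4, 4, 100).unsqueeze(1)
with torch.no_grad():
    preds = torch.stack([model(x_test) for _ in range(100)])
mean, std = preds.mean(0), preds.std(0)  # predictive mean and spread
```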
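And the GP for comparison. I’m reading “quadratic kernel” as the degree-2 polynomial kernel k(x, x') = (x x' + c)^2, which makes this Bayesian quadratic regression (an RBF or rational quadratic kernel would fit tanh better). The posterior is closed-form:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 50)
y = np.tanh(x) + 0.1 * rng.standard_normal(50)

def kernel(a, b, c=1.0):
    # Quadratic (degree-2 polynomial) kernel: k(x, x') = (x x' + c)^2
    return (np.outer(a, b) + c) ** 2

sigma2 = 0.1**2  # observation noise variance, matching the toy data
K = kernel(x, x) + sigma2 * np.eye(len(x))
x_test = np.linspace(-4, 4, 100)
K_s = kernel(x_test, x)
K_ss = kernel(x_test, x_test)

# Standard GP posterior mean and covariance
mean = K_s @ np.linalg.solve(K, y)
cov = K_ss - K_s @ np.linalg.solve(K, K_s.T)
std = np.sqrt(np.clip(np.diag(cov), 0, None))
```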

Uses

  • Optimisation and signal combination, e.g. in quant trading: what’s your delta?
  • Interpretability
  • Safety (avoiding models that are confidently incorrect)
  • Applications in science, e.g. AlphaFold’s per-residue confidence scores
  • Robustness against adversarial or out-of-sample data

Theory

  • Epistemic (model) vs aleatoric (data noise) uncertainty - see the decomposition below
  • Maffs maffs maffs
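
One piece of the maths worth stating up front: the posterior predictive variance splits via the law of total variance, and the two terms correspond (roughly) to the two kinds of uncertainty:

  Var[y* | x*, D] = E_{\theta|D}[Var[y* | x*, \theta]] + Var_{\theta|D}[E[y* | x*, \theta]]

The first term is the expected noise under the model (aleatoric); the second is the spread of the mean prediction across posterior weights (epistemic).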

Further reading

Code