Bayesian neural networks without Bayes' rule

Posted on Oct 27, 2025
tl;dr: A gentle introduction to Martingale Posterior Neural Networks (NeurIPS 2025): a predictive-first alternative to Bayesian neural networks for fast sequential learning and Bayesian-style decision-making, without priors or posteriors.

tags: bayes, bayesian-deep-learning, neural-networks, uncertainty, sequential-decision-making

Introduction

Deep neural networks (DNNs) are great at fitting data, but terrible at knowing when they don't know the right answer. Tackling this challenge (often called uncertainty quantification) is at the core of various problems in AI/ML: for example, reducing hallucinations in large language models, 1 or the classical balancing of exploration versus exploitation in reinforcement learning, neural bandits, and Bayesian optimization.

A popular way to aim for both reliable predictions and calibrated uncertainty quantification is through Bayesian neural networks (BNNs). In theory they are elegant: instead of fixing the parameters of the DNN, you treat them as random variables. Then, conditioned on observed data, Bayes’ rule gives you a posterior over the parameters. This, in principle, should allow for model averaging, uncertainty quantification, and sequential decision-making via algorithms like Thompson sampling.

However, BNNs can be slower than non-Bayesian approaches, 2 memory-hungry, 3 unstable, 4 and misspecified, 5 which can make them weak competitors against alternative methods in many of the settings where they should shine.

In our NeurIPS 2025 paper, Martingale Posterior Neural Networks for Fast Sequential Decision Making, we propose a new way to tackle many of the limitations of BNNs. The main method we introduce, HiLoFi, learns neural network parameters in a non-Bayesian way while remaining Bayesian about predictions. In particular, we build on the Martingale posterior 6 idea, which argues that being Bayesian is not about building priors, likelihoods, or posteriors, but is instead about building a posterior predictive: the distribution of future (unseen) data given past observations.

As we show, by skipping the Bayesian prior-likelihood-posterior (PLP) formulation, and instead focusing on the posterior predictive as the main object of a BNN, we retain Bayesian uncertainty quantification without many of the challenges associated with posterior estimation (or approximation) and posterior sampling.

In our experiments, we find that this approach is 10–100× faster than Bayesian PLP approaches on sequential decision-making tasks, while matching or outperforming rival methods. This is valuable in settings such as streaming reinforcement learning and neural contextual bandits, where decisions must be made fast and the agent must learn continually.

In this post, we illustrate how to get a posterior-free BNN using a simple feed-forward neural network. The code is in this notebook.

The dataset

We consider a dataset $\data_t = (x_t, y_t)$ where $x_t \in \reals$, $y_t = f(x_t) \in \reals$, and $f$ is the 1D Ackley function. We denote $\data_{1:t} = (\data_1, \ldots, \data_t)$ and $y_{1:t} = (y_1, \ldots, y_t)$.

import jax
import jax.numpy as jnp

def ackley_1d(x, y=0):
    # 1D slice of the Ackley function (y fixed at 0)
    out = (-20*jnp.exp(-0.2*jnp.sqrt(0.5*(x**2 + y**2)))
           - jnp.exp(0.5*(jnp.cos(2*jnp.pi*x) + jnp.cos(2*jnp.pi*y)))
           + jnp.e + 20)
    return out

To generate the training data, we sample $10$ points $x_t \sim \mathcal{U}[-4, 4]$ and evaluate $y_t = f(x_t) + e_t$, with $p(e_t) = \normal(e_t \mid 0,\,0.1^2)$. The sample points used throughout are shown in the Figure below.

Sample of the Ackley function

In our experiments, the agent observes each datapoint one at a time. Our aim is to build a good approximation for $f$ while also producing a sensible measure of uncertainty. Naturally, we expect higher uncertainty in regions where no data is available and lower uncertainty in regions where data is available.

Using jax, the dataset is generated as follows

n_obs = 10
key = jax.random.PRNGKey(314)
key_x, key_noise = jax.random.split(key)

noise = jax.random.normal(key_noise, (n_obs,1)) * 0.1
x_train = jax.random.uniform(key_x, shape=(n_obs,1), minval=-4, maxval=4)
y_train = ackley_1d(x_train) + noise

x_test = jnp.linspace(-4, 4, 1000)
y_test = ackley_1d(x_test)

A Gaussian process benchmark

We consider Gaussian process (GP) regression with a Gaussian kernel as a benchmark.

Fitting the data and obtaining an estimate of uncertainty is relatively easy with a GP: given training data with observations $y_{1:t}$ and inputs $x_{1:t}$, the predictive model of a GP at a point $x \in \reals$ is 7
$$
p(y \mid x,\,\data_{1:t}) = \normal(y \mid \vK_x\,y_{1:t},\; k(x, x) - \vK_x\,\vS\,\vK_x^\intercal),
$$
where $\normal(z \mid m, s^2)$ is a Gaussian density, with mean $m$ and variance $s^2$, evaluated at $z$,
$$
\begin{aligned}
\vK_x &= \vk_x\,\vS^{-1},\\
\vk_x &= \begin{bmatrix} k(x, x_1) & \ldots & k(x, x_t) \end{bmatrix},\\
\vS &= \begin{bmatrix} k(x_1, x_1) & \ldots & k(x_1, x_t) \\ \vdots & \ddots & \vdots \\ k(x_t, x_1) & \ldots & k(x_t, x_t) \end{bmatrix} + r^2\,\vI,
\end{aligned}
$$
and $r^2$ is the observation-noise variance (matching the code below). Here,
$$
k(u, v) = \sigma^2 \exp \left(-\frac{ \left\Vert u - v \right\Vert^2}{2\ell^2}\right)
$$
is the Gaussian kernel 8 with amplitude $\sigma^2$ and lengthscale $\ell$. In particular, we take $\sigma^2 = 1.5$ and $\ell^2 = 0.2$.

In code, predictions made using a GP can be written as follows

from functools import partial

@partial(jax.vmap, in_axes=(0, None, None, None))
@partial(jax.vmap, in_axes=(None, 0, None, None))
def kgauss(u, v, ell2, amplitude):
    # Gaussian (RBF) kernel evaluated over all pairs (u_i, v_j)
    return amplitude * jnp.exp(-(u - v) ** 2 / (2 * ell2))

ell2 = 0.2       # squared lengthscale ℓ²
amplitude = 1.5  # amplitude σ²
var_noise = 0.1 ** 2

# Flatten the training inputs so the vmapped kernel returns (n, m) matrices
x_tr = x_train.ravel()

I_train = jnp.eye(len(x_tr))
I_test = jnp.eye(len(x_test))

# Estimate variances and covariances
var_train = kgauss(x_tr, x_tr, ell2, amplitude) + var_noise * I_train
var_test = kgauss(x_test, x_test, ell2, amplitude) + var_noise * I_test
cov_test_train = kgauss(x_test, x_tr, ell2, amplitude)

# Make predictions
K = jnp.linalg.solve(var_train, cov_test_train.T).T
mu_pred = K @ y_train  # posterior predictive mean
Sigma_pred = var_test - K @ var_train @ K.T  # posterior predictive covariance

The figure below shows the GP posterior predictive mean with a band of two standard deviations (its uncertainty), i.e., $$ \begin{aligned} \E[y \mid x,\,\data_{1:t}] &= \vK_x\,y_{1:t},\\ \var[y \mid x, \data_{1:t}] &= k(x, x) - \vK_x\,\vS\,\vK_x^\intercal. \end{aligned} $$

Gaussian process posterior predictive

As expected, when fitting a GP with small observation noise, the predictive variance around observed points is small. Conversely, the predictive variance increases in unobserved regions of space. This desirable property of GPs makes them, in many cases, the gold standard in Bayesian uncertainty quantification.

The neural network

One of the goals of Bayesian neural networks (BNNs) is to retain the interpolating and uncertainty-aware properties of GPs, while leveraging the flexibility of neural networks to work in high-dimensional spaces. Thus, we look for ways to train neural networks that have the qualitative properties of GPs described above.

To illustrate BNNs, we consider the following feed-forward neural network architecture which we implement in flax:

import flax.linen as nn

class MLP(nn.Module):
    n_hidden: int = 20

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.n_hidden)(x)
        x = jnp.sin(x)
        x = nn.Dense(self.n_hidden)(x)
        x = nn.elu(x)
        x = nn.Dense(self.n_hidden)(x)
        x = nn.elu(x)
        x = nn.Dense(1, name="last_layer")(x)
        return x

key = jax.random.PRNGKey(0)
model = MLP()  # instance of the model
params_init = model.init(key, x_train)  # θ(0|0) — initial estimate of model parameters

The first activation function is a sine function, which acts as a soft bias to help capture the strong non-linearities in the Ackley function.

Formally, we write $h(\vtheta, x)$ as the neural network parameterized by $\vtheta \in \reals^D$, with inputs $x \in \reals$. Furthermore, we denote the initial estimate of model parameters by $\vtheta_{0|0} \equiv$ params_init. Here, $D=901$.

Linearized probabilistic neural networks for sequential learning

Starting from the initial estimate $\vtheta_{0|0} \in \reals^D$, our goal is to obtain a sequence of parameter estimates $\vtheta_{t|t} \in \reals^D$ for $t = 1, \ldots, T$, updated every time a new datapoint $\data_t$ is presented. We can think of this as doing mini-batch updates of size one.

The following diagram illustrates this idea: having an estimate $\vtheta_{t-1|t-1}$, we combine the estimate with a new datapoint $\data_t$, to obtain $\vtheta_{t|t}$. We can then repeat this process for the next datapoint. $$ \begin{array}{c c c c c c c c c} \vtheta_{0|0} & \to & \vtheta_{1|1} & \to & \vtheta_{2|2} & \to & \ldots & \to & \vtheta_{T|T} \\ & & \uparrow & & \uparrow & & & & \uparrow \\ & & \data_1 & & \data_2 & & & & \data_T \\ \end{array} $$

To derive updates for $\vtheta_{t|t}$, we model the observations $y_t$ using a first-order approximation of the neural network around its previous estimate $\vtheta_{t-1|t-1}$ plus Gaussian noise. That is,

$$ \tag{1} y_t \approx h(\vtheta_{t-1|t-1},\,x_{t}) + \vJ_t\,(\vtheta -\vtheta_{t-1|t-1}) + e_t, $$ with $p(e_t) = \normal(e_t \mid 0, r^2)$ a zero-mean Gaussian and $\vJ_t = \nabla_\vtheta h(\vtheta_{t-1|t-1},\,x_t) \in \reals^{1\times D}$ the Jacobian of the neural network.
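To make the linearization concrete in jax, we can flatten the flax parameters into a single vector and differentiate the network with jax.jacrev. This is a minimal sketch: the function h and the flattening glue are our own helpers (we reuse them in the code below), not part of the paper's API.

from jax.flatten_util import ravel_pytree

theta_init, unflatten = ravel_pytree(params_init)  # θ(0|0) as a flat vector of size D

def h(theta, x):
    # Evaluate the network with a flat parameter vector.
    return model.apply(unflatten(theta), x.reshape(1, -1)).ravel()

# First-order expansion of h around the previous estimate, as in (1):
# y_t ≈ h(theta_prev, x_t) + J_t (theta - theta_prev) + e_t
x_t = x_train[0]
J_t = jax.jacrev(h)(theta_init, x_t)  # Jacobian, shape (1, D)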

Next, we discuss two ways to find the sequence of updates for $\vtheta_{t|t}$: a Bayesian formulation and a frequentist formulation. Both yield identical update rules and predictive models; they differ only in how each perspective interprets the quantity being updated.

Recursive Bayesian learning

The classical prior-likelihood-posterior (PLP) Bayesian view starts with a prior over the model parameters $$ p(\vtheta) = \normal(\vtheta \mid \vtheta_{0|0}, \vSigma_0), $$ where $\vtheta_{0|0}$ is the prior mean and $\vSigma_0$ is the prior covariance. An easy way to initialize the prior covariance is as a constant times the identity (isotropic): $$ \vSigma_0 = \alpha\,\vI, $$ with $\alpha > 0$.

Next, the PLP Bayesian view specifies a likelihood for each new observation. In our setting, this is given by $$ p_t(y_t \mid \vtheta, x_t) = \normal(y_t \mid h(\vtheta_{t-1|t-1},\,x_{t}) + \vJ_t\,(\vtheta -\vtheta_{t-1|t-1}),\,r^2). $$

Finally, updates to the prior given the likelihood are done following Bayes’ rule. This yields a posterior: $$ p(\vtheta \mid \data_{1:t}) \propto p(\vtheta)\,p(\data_{1:t} \mid \vtheta) = p(\vtheta)\,\prod_{\tau=1}^t p_\tau(y_\tau \mid \vtheta,\,x_\tau) = \normal(\vtheta \mid \vtheta_{t|t}, \vSigma_t). $$ In our online setting, we adopt the recursive view, justified by the assumption that observations are conditionally independent given the parameters $\vtheta$. In this case, $$ \begin{aligned} p(\vtheta \mid \data_{1:t}) &\propto p(\vtheta \mid \data_{1:t-1})\,p(\data_t \mid \vtheta)\\ &\propto \normal(\vtheta \mid \vtheta_{t-1|t-1}, \vSigma_{t-1})\,\normal(y_t \mid \hat{h}(\vtheta, x_t), r^2)\\ &= \normal(\vtheta \mid \vtheta_{t|t}, \vSigma_t), \end{aligned} $$ where $\hat{h}(\vtheta, x_t) = h(\vtheta_{t-1|t-1},\,x_{t}) + \vJ_t\,(\vtheta -\vtheta_{t-1|t-1})$.

The Bayesian recursive update can be visualized as follows

$$ \begin{array}{c c c c c c c} p(\vtheta) & \to & p(\vtheta \mid \data_1) & \to & \ldots & \to & p(\vtheta \mid \data_{1:T}) \\ & & \uparrow & & & & \uparrow \\ & & \data_1 & & \ldots & & \data_T \\ \end{array} $$

The moments of the posterior density, $(\vtheta_{t|t}, \vSigma_t)$, are given by the Kalman filter equations shown below.

Recursive frequentist learning

The frequentist view begins not with a prior, but with an initial estimate of the model parameters $\vtheta_{0|0}$ and an initial error variance-covariance (EVC) matrix $\vSigma_0$.

At each step, the estimate $\vtheta_{t|t}$ is chosen so as to minimize the expected squared error relative to the true (but unknown) $\vtheta$: $$ \vtheta_{t|t} = \argmin_{\vv} \E[\|\vtheta - \vv\|_2^2 \mid \data_{1:t}]. $$ Rather than maintaining a posterior distribution, the uncertainty is expressed through the variance-covariance of the estimation error: $$ \vSigma_t = \var\left(\vtheta - \vtheta_{t|t}\right). $$ Thus, this approach has two parts: (i) solve the optimization problem above, and (ii) quantify the error of the estimate relative to the true but unknown model parameters.

The frequentist recursive update can be visualized as follows $$ \begin{array}{c c c c c c c} (\vtheta_{0|0}, \vSigma_0) & \to & (\vtheta_{1|1}, \vSigma_1) & \to & \ldots & \to & (\vtheta_{t|t}, \vSigma_t)\\ & & \uparrow & & & & \uparrow \\ & & \data_1 & & \ldots & & \data_t \\ \end{array} $$

Under this approach, $\vtheta$ does not have a probability distribution. We only work with statistics of the estimate for $\vtheta$. The updates for the estimate $\vtheta_{t|t}$ and its EVC matrix are done following the Kalman filter equations shown below.

The predictive model for uncertainty quantification

Ultimately, what matters for prediction and uncertainty quantification is not our estimate of the parameters, but the predictive distribution of the next unseen observation. Under the linearized model $(1)$, both approaches above yield the same predictive model: $$ \tag{2} p(y_{t+1} \mid x_{t+1},\,\data_{1:t}) = \normal\left( y_{t+1} \mid h(\vtheta_{t|t},\,x_{t+1}),\;\; \underbrace{\vJ_{t+1}\,\vSigma_{t}\,\vJ_{t+1}^\intercal}_{\text{epistemic}} + \underbrace{r^2}_{\text{aleatoric}} \right). $$

This expression separates epistemic uncertainty (due to limited knowledge of the parameters) from aleatoric uncertainty (inherent noise in the data).

From a Bayesian point of view, this is the posterior predictive density. From a frequentist point of view, it is the distribution of the uncentered prediction error (the residual). 9 In either case, we can treat this object as Bayesian in spirit: the unknown quantity is the next outcome $y_{t+1}$ given the features $x_{t+1}$ and past data $\data_{1:t}$.

The learning algorithm

The updates for the model parameters $\vtheta_{t|t}$ and covariance $\vSigma_{t}$ are $$ \tag{3} \begin{aligned} \vtheta_{t|t} &= \vtheta_{t-1|t-1} + \vK_t\,\bm\epsilon_t,\\ \vSigma_{t} &= (\vI - \vK_t\,\vJ_t)\,\vSigma_{t-1}\,(\vI - \vK_t\,\vJ_t)^\intercal + r^2\,\vK_t\,\vK_t^\intercal. \end{aligned} $$ Here, the update depends on the innovation (or residual) $\bm\epsilon_t$ and the gain matrix $\vK_t$: $$ \begin{aligned} \bm\epsilon_t &= y_t - h(\vtheta_{t-1|t-1},\,x_t),\\ \vK_t &= \frac{\vSigma_{t-1}\,\vJ_t^\intercal}{\vJ_t\vSigma_{t-1}\,\vJ_t^\intercal + r^2}. \end{aligned} $$

These equations mirror the structure of the Kalman filter 10 applied to a linearized model. The key difference is how we interpret the updates: either as moments of a posterior distribution (Bayesian case), or as a point estimate with an error estimate (frequentist case). Because the frequentist perspective only works with estimators of the data (and does not require the specification of a distribution), it has the practical advantage of easily handling non-Gaussian estimators, e.g., a low-rank (degenerate / positive semi-definite) matrix $\vSigma_t$. This forms the basis of our LRKF and HiLoFi methods presented below, which yield faster updates than working with the well-specified (positive definite) Gaussian covariance matrices required in the Bayesian-Gaussian case.

In pseudocode, the update equations $({\rm 3})$ for our setup are given by

from flax import struct

@struct.dataclass
class BeliefState:
    # a simple container; flax.struct gives us .replace
    mean: jnp.ndarray  # θ(t|t), flat parameter vector
    cov: jnp.ndarray   # Σ(t), error variance-covariance matrix

R = jnp.array([[0.1]]) ** 2  # observation-noise variance r²

def update(bel, y, x):
    yhat = h(bel.mean, x)
    I = jnp.eye(len(bel.mean))

    err = y - yhat                   # innovation ε(t)
    Jt = jax.jacrev(h)(bel.mean, x)  # Jacobian, shape (1, D)
    St = Jt @ bel.cov @ Jt.T + R     # innovation variance
    Kt = jnp.linalg.solve(St, Jt @ bel.cov).T  # gain, shape (D, 1)

    # (theta(t|t), Sigma(t)) — Joseph-form covariance update
    mean_update = bel.mean + Kt @ err
    cov_update = (I - Kt @ Jt) @ bel.cov @ (I - Kt @ Jt).T + Kt @ R @ Kt.T

    bel = bel.replace(mean=mean_update, cov=cov_update)
    return bel

Running the update above with input $(\vtheta_{t-1|t-1}, \vSigma_{t-1})$ yields $(\vtheta_{t|t}, \vSigma_t)$. Two full passes over the data (two epochs) are illustrated in the following pseudocode:

alpha = 1.0  # isotropic scale from the initialization above; the value is an arbitrary choice
Sigma_init = alpha * jnp.eye(len(theta_init))

bel = BeliefState(mean=theta_init, cov=Sigma_init)
for epoch in range(2):
    # re-create the iterator each epoch; a zip object is exhausted after one pass
    for x, y in zip(x_train, y_train):
        bel = update(bel, y, x)

Full-covariance estimation (FC)

We begin our tour of various sequential learning methods from the PLP Bayesian point of view. Here, we approximate the posterior over the parameters as multivariate Gaussian $$ p(\vtheta \mid \data_{1:t}) \approx \normal(\vtheta \mid \vtheta_{t|t}, \vSigma_t). $$ As we will see, different approximations to $\vSigma_t$ yield different algorithms.

In the first example, we consider the full posterior covariance given by the update in $({\rm 3})$. Once we run the experiment for two passes over the data, we obtain both a fit of the data and a measure of uncertainty. We follow two approaches: posterior sampling and the posterior predictive.

Posterior sampling for uncertainty quantification

In the PLP Bayesian formulation, the posterior predictive can be written as a marginalization over the posterior distribution of the parameters: $$ \begin{aligned} p(y_{t+1} \mid \data_{1:t},\,x_{t+1}) &= \int p(\vtheta,\,y_{t+1} \mid \data_{1:t},\,x_{t+1}) \d\vtheta\\ &= \int p(y_{t+1} \mid \vtheta,\,x_{t+1})\,p(\vtheta \mid \data_{1:t}) \d\vtheta\\ &= \int p(y_{t+1} \mid \vtheta,\,x_{t+1})\,\normal(\vtheta \mid \vtheta_{t|t}, \vSigma_t) \d\vtheta. \end{aligned} $$ This integral is often intractable, so we approximate it by sampling from the posterior: $$ p(\vtheta \mid \data_{1:t}) \approx \frac{1}{S}\sum_{s=1}^S \delta\left(\vtheta - \vtheta^{(s)}\right), $$ with $\vtheta^{(s)} \sim \normal(\cdot \mid \vtheta_{t|t},\,\vSigma_t)$.

This sampling scheme allows us to approximate expectations under the posterior. For instance, the posterior predictive mean can be estimated as: $$ \begin{aligned} \E[y_{t+1} \mid x_{t+1},\,\data_{1:t}] &= \E[h(\vtheta, x_{t+1}) \mid \data_{1:t}]\\ &= \int h(\vtheta, x_{t+1})\, p(\vtheta \mid \data_{1:t}) \d\vtheta\\ &\approx \frac{1}{S}\sum_{s=1}^S h\left(\vtheta^{(s)}, x_{t+1}\right). \end{aligned} $$
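A minimal sketch of this Monte Carlo approximation, reusing the flat parameterization h and the belief state from above (the helper name is ours):

def sample_posterior_predictions(key, bel, x, n_samples=100):
    # Draw θ(s) ~ N(θ(t|t), Σ(t)) and push each sample through the network.
    thetas = jax.random.multivariate_normal(
        key, bel.mean, bel.cov, shape=(n_samples,)
    )
    return jax.vmap(lambda theta: h(theta, x))(thetas)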

The result of this procedure is shown in the Figure below. We draw 100 samples from the posterior and evaluate $h(\vtheta^{(s)}, x)$ for all samples $\vtheta^{(s)}$ and varying $x$.

full covariance posterior sampling

We see that the algorithm captures, to some degree, the global shape of the Ackley function $f$. Furthermore, the uncertainty remains large throughout the test points. This pathology is known in the literature and is typically mitigated by tempering (cooling) the posterior, a phenomenon known as the cold posterior effect. 11

One can argue that this poor result for both fitting the data and uncertainty quantification is because our posterior is not good enough — maybe we needed more datapoints or more epochs. However, as we show below, it is more likely due to a misspecified posterior.

Predictive sampling for uncertainty quantification

Under the linearized model $({\rm 1})$, we can draw samples directly from the posterior predictive $({\rm 2})$, which lives in observation space rather than in parameter space. This is what we actually care about: distributions over predictions, not over parameters. Instead of asking what the distribution over the parameters is, we ask what we could observe next.

A sample-based approximation of the posterior predictive at a test point $x_{t+1}$ is $$ p(y \mid x_{t+1},\,\data_{1:t}) \approx \frac{1}{S}\sum_{s=1}^S \delta\left(y - y^{(s)}\right), $$ with $ y^{(s)} \sim p(\cdot \mid x_{t+1},\,\data_{1:t}) $. This is helpful in settings such as Bayesian optimization and bandits, where we take samples from the posterior predictive to guide actions.

For the linearized model $({\rm 1})$, the posterior predictive at a point $x_{t+1}$ takes the analytical form of a Gaussian with mean and variance given by $$ \begin{aligned} \E[y_{t+1} \mid x_{t+1},\,\data_{1:t}] &= h(\vtheta_{t|t}, x_{t+1}),\\ \var[y_{t+1} \mid x_{t+1}, \data_{1:t}] &= \vJ_{t+1}\,\vSigma_{t}\,\vJ_{t+1}^\intercal + r^2. \end{aligned} $$
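In code, these two moments are cheap to evaluate once we have the belief state; a sketch (the function name predict is ours):

def predict(bel, x):
    # Posterior predictive moments under the linearized model (2).
    mean = h(bel.mean, x)                    # h(θ(t|t), x)
    J = jax.jacrev(h)(bel.mean, x)           # Jacobian at θ(t|t), shape (1, D)
    var = (J @ bel.cov @ J.T + R).squeeze()  # epistemic + aleatoric
    return mean.squeeze(), var

Predictive samples are then one-dimensional Gaussian draws, mean + jnp.sqrt(var) * jax.random.normal(key, (n_samples,)), with no need to sample in parameter space.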

Below, we show the posterior predictive mean with a two-standard-deviation band across varying values of $x$.

full covariance posterior predictive sampling

Compared to posterior sampling, the predictive density in this example produces less overall variance and a tighter fit around the observed values. Interestingly, the values of $\vtheta_{t|t}$ and $\vSigma_{t}$ are the same in both the posterior-first and predictive-first cases.

In our experiments, we find that sampling from the posterior predictive $p(y_{t+1} \mid x_{t+1}, \data_{1:t})$ yields better performance than sampling from the posterior $p(\vtheta \mid \data_{1:t})$ in sequential decision making tasks using neural networks.

Diagonal approximation (D)

A major drawback of the method described above is that the covariance $\vSigma_t$ is a dense, full-rank matrix. Maintaining such an approximation requires $O(D^2)$ memory, which becomes infeasible even for moderately sized neural networks. For example, our neural network has only $D=901$ parameters, yet storing the covariance matrix requires an additional $D^2 = 901^2 = 811{,}801$ entries.

From a PLP Bayesian perspective, one way to reduce the memory and time cost is to approximate the posterior with a diagonal covariance: $$ \vSigma_t \approx \rm{diag}(\sigma_1^2, \ldots, \sigma_D^2) = \Upsilon_t, $$ so that $$ p(\vtheta \mid \data_{1:t}) \approx \normal(\vtheta \mid \vtheta_{t|t}, \Upsilon_t). $$

A popular example of this idea in a non-sequential (offline) setting is Bayes by Backprop (BBB), 12 which does a diagonal approximation of the posterior covariance.
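To make the memory saving concrete, here is a rough sketch of one way to run the Kalman-style recursion from $({\rm 3})$ while storing only the $D$ diagonal entries of the covariance. This illustrates the idea only; it is not necessarily the exact update of any published method.

def update_diag(bel, y, x):
    # bel.cov is now the D-vector of diagonal variances Υ(t).
    yhat = h(bel.mean, x)
    J = jax.jacrev(h)(bel.mean, x).ravel()  # (D,)
    s = J @ (bel.cov * J) + R.squeeze()     # innovation variance (scalar)
    K = bel.cov * J / s                     # Kalman gain, (D,)
    mean = bel.mean + K * (y - yhat).squeeze()
    # Diagonal of the Joseph-form covariance update.
    diag = bel.cov - 2 * K * J * bel.cov + K**2 * (J @ (bel.cov * J)) + R.squeeze() * K**2
    return bel.replace(mean=mean, cov=diag)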

Here, however, we consider the diagonal approximation following Chang et al. (2022), 13 which also linearizes the neural network and allows for closed-form updates and an explicit posterior predictive. The plot below shows the evaluation of the posterior predictive model $({\rm 2})$ after two epochs with $\vSigma_t = \Upsilon_t$.

diagonal predictive sampling

This method is faster than FC, but not as data efficient: we need more datapoints or more epochs to fit the data well. We quantify this with a variance-weighted error below.

Diagonal + low-rank approximation (D+LR)

A Bayesian middle ground between a full covariance and a diagonal covariance is to use a diagonal plus low-rank (D+LR) structure. In particular, the LoFi method 14 approximates the precision matrix (the inverse covariance) in a way that enables efficient updates to the model parameters while still capturing more structure than a purely diagonal approximation. In this case, the (inverse) covariance takes the form $$ \vSigma_t^{-1} = \vW_t\,\vW_t^\intercal + \Upsilon_t, $$ where $\Upsilon_t$ is diagonal, $\vW_t \in \reals^{D\times d}$, and $d \ll D$. Hence, $$ p(\vtheta \mid \data_{1:t}) \approx \normal(\vtheta \mid \vtheta_{t|t},\,(\vW_t\,\vW_t^\intercal + \Upsilon_t)^{-1}). $$ This formulation allows the use of linear-algebra identities to keep updates efficient in both computation and memory.
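For instance, applying the covariance $(\vW_t\,\vW_t^\intercal + \Upsilon_t)^{-1}$ to a vector reduces, via the Woodbury identity, to a small $d \times d$ solve; a sketch (the helper name is ours):

def dlr_precision_solve(W, upsilon, b):
    # Compute (W Wᵀ + diag(upsilon))⁻¹ b in O(D d²) with the Woodbury
    # identity, without materializing the D x D matrix.
    Ub = b / upsilon                        # Υ⁻¹ b
    UW = W / upsilon[:, None]               # Υ⁻¹ W
    small = jnp.eye(W.shape[1]) + W.T @ UW  # (d, d)
    return Ub - UW @ jnp.linalg.solve(small, W.T @ Ub)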

Similar to the FC approach, this D+LR approach also linearizes the neural network, so the posterior predictive can also be modelled explicitly. The figure below shows the predictive mean and variance under the linearized posterior predictive model, after two epochs and taking $d=20$:

LoFi predictive sampling

Compared to the full-covariance (FC) and diagonal (D) approaches, LoFi (D+LR) provides a compromise: it is more time-efficient than maintaining a dense covariance while capturing richer structure than a purely diagonal approximation. This balance allows for a more efficient update of the model parameters, without incurring the prohibitive cost of full-covariance updates. We quantify this in the benchmarks below, where we measure a variance-weighted error term and time to run for the various methods.

Low-rank approximation (LR): non-Bayesian updates

So far, we have seen approximations to the posterior covariance that range from computing the full covariance (FC), to approximating only its diagonal (D), to approximating the diagonal plus a low-rank structure of the precision (D+LR) for more efficient updates.

Building on the approximations above, one may wonder whether we can drop the diagonal term entirely and keep only a low-rank covariance. Doing so would reduce the memory and increase the speed of the algorithm.

The problem is that a strictly low-rank covariance is singular (not invertible), so it cannot define a valid Gaussian posterior. To stay Bayesian in the PLP sense, one could approximate the posterior with variational Bayes over a family of low-rank densities, but this adds mathematical and implementation overhead.

However, if we take the frequentist view, we can track a low-rank covariance matrix and still obtain a valid posterior predictive density. This is because $\vSigma_t$ is only a statistic, not a Gaussian posterior covariance.

In particular, consider the low-rank approximation $$ \vW_t = \argmin_{\vW \in \reals^{D\times d}}\|\vSigma_t - \vW\,\vW^\intercal\|_{\rm Fro}, $$ with $d \ll D$ and $\|\cdot\|_{\rm Fro}$ the Frobenius norm. 15 The approximate EVC matrix at time $t$ then takes the form $$ \hat{\vSigma}_t = \vW_t\,\vW_t^\intercal. $$
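By the Eckart–Young theorem, this minimizer comes from a truncated eigendecomposition of $\vSigma_t$. A quick sketch of the projection step (the helper name is ours; the actual LRKF maintains the factor recursively rather than re-projecting a dense matrix):

def low_rank_factor(Sigma, d):
    # Best rank-d approximation of Sigma in Frobenius norm (Eckart–Young).
    U, s, _ = jnp.linalg.svd(Sigma, hermitian=True)
    return U[:, :d] * jnp.sqrt(s[:d])  # W of shape (D, d), with W @ W.T ≈ Sigma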

We call the resulting approximation and the corresponding efficient update equations the low-rank Kalman filter, or LRKF (see Appendix F.2 in the paper for details).

The figure below illustrates the predictive model $({\rm 2})$ using the low-rank updates derived from LRKF after two epochs and taking $d=100$.

LRKF predictive sampling

We observe that the result is comparable to that of LoFi (D+LR), except at the tails, where it does not fit the data as well. Nonetheless, the updates are faster than D+LR, even with a higher rank (see the time comparisons below).

HiLoFi: Full-rank low-rank approximation (FL-LR)

As we saw above, LRKF is fast, trains neural networks recursively, and produces a posterior predictive model suitable for uncertainty quantification and sequential decision making. Yet its results are not as expressive as what one might expect from a Gaussian process (GP). The next step is to keep the speed of the low-rank setting while improving predictive performance and uncertainty quantification. This motivates our main method, HiLoFi.

One of the main motivations for HiLoFi comes from the last-layer BNN literature. In a nutshell, hidden layers mainly learn representations, while the last layer directly maps to the target variable, so we treat them separately.

Formally, we write our neural network as: $$ h(\vtheta, x) = \bm\ell^\intercal\,\phi(\vh, x), $$ where $\bm\ell$ are the last-layer parameters and $\vh$ are the hidden-layer parameters. Following the frequentist perspective, our quantities of interest are: $$ \begin{aligned} \vtheta_{t|t} &= (\bm\ell_{t|t}, \vh_{t|t}),\\ \hat{\vSigma}_{t} &= \begin{bmatrix} \vSigma_{\bm\ell,t} & \vO\\ \vO & \vSigma_{\vh,t}\\ \end{bmatrix}, \end{aligned} $$ with

  • $ \vSigma_{\vh,t} = \vC_t^\intercal\,\vC_t $ a low-rank covariance matrix for the hidden-layer parameters ($\vC_t \in \reals^{d_h\times D_h}$) and
  • $\vSigma_{\bm\ell,t} \in \reals^{D_\ell\,\times D_\ell}$ a full-rank covariance matrix for the last-layer parameters.

Here, $D = D_\ell + D_h$ and $d_h \ll D_h$.
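A sketch of the corresponding belief state, and of how the epistemic predictive variance splits across the two blocks (the names are ours; the full update equations are in the paper):

from typing import NamedTuple

class HiLoFiBel(NamedTuple):
    mean_last: jnp.ndarray    # last-layer parameters, shape (D_ell,)
    mean_hidden: jnp.ndarray  # hidden-layer parameters, shape (D_h,)
    Sigma_last: jnp.ndarray   # full-rank block, shape (D_ell, D_ell)
    C_hidden: jnp.ndarray     # low-rank factor, shape (d_h, D_h); Σ_h = CᵀC

def epistemic_variance(bel, J_last, J_hidden):
    # J Σ Jᵀ splits across the blocks because Σ is block-diagonal;
    # the hidden block never requires the D_h x D_h matrix.
    v_last = (J_last @ bel.Sigma_last @ J_last.T).squeeze()
    v_hidden = jnp.sum((bel.C_hidden @ J_hidden.ravel()) ** 2)
    return v_last + v_hidden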

HiLoFi is designed for both computational and statistical efficiency: we model the last-layer parameters with a full-rank covariance block and the hidden-layer parameters with a low-rank block.

This setup also suggests how to set the initial conditions. Since overparameterized networks tend to stay close to their initialization, 16 we give the hidden layers a concentrated initial covariance, while the last layer (which directly controls the predictions) is initialized with a larger covariance to allow more flexibility.

The plot below shows the resulting predictive model after running HiLoFi for two epochs and taking $d_h = 20$.

HiLoFi predictive sampling

HiLoFi displays the desired properties for this dataset:

  • uncertainty grows in regions with no data,
  • interpolation is reasonably accurate at observed points, and
  • convergence is achieved in just two passes.

Taken together, HiLoFi shows that we can keep the Bayesian flavour in the predictive distribution, while using frequentist updates to gain speed and scalability.

HiLoFi keeps a full-rank block on the last layer (where predictions live) and a low-rank block on hidden layers (where representation lives). This yields GP-like uncertainty, with LRKF-like speed.

A low-rank low-rank (LR-LR) approximation

An alternative LR-LR approximation is possible whenever the output layer is high-dimensional. We call the resulting method LoLoFi.

Benchmarks

Variance-weighted error

To evaluate how well each method balances uncertainty and predictive accuracy, we consider the variance-weighted error. Given a predictive model evaluated over a set of inputs ${\cal X}$, with predictive means $m_x$ and variances $s_x$, we define

$$ E_{\cal X} = \sum_{x\in {\cal X}} w_x (y_x - m_x)^2, $$ where $y_x$ is the true target at $x$ and the weights are inversely proportional to the predictive variance: $$ \begin{aligned} \hat{w}_x &= \frac{1}{s_x}, & w_x &= \frac{\hat{w}_x}{\sum_{x' \in {\cal X}} \hat{w}_{x'}}. \end{aligned} $$ For example, after $t$ observations, the FC case takes the form $$ \begin{aligned} m_x &= h(\vtheta_{t|t}, x),\\ s_x &= \vJ_{x}\,\vSigma_{t}\,\vJ_{x}^\intercal + r^2, \end{aligned} $$ with $\vJ_x = \nabla_\vtheta h(\vtheta_{t|t}, x)$.
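In code, the metric is a direct transcription of the formula above (the function name is ours):

def variance_weighted_error(y_true, means, variances):
    # Weights inversely proportional to the predictive variance,
    # normalised to sum to one over the test inputs.
    w = 1.0 / variances
    w = w / w.sum()
    return jnp.sum(w * (y_true - means) ** 2)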

Below, we show the variance-weighted error for all methods discussed above.

weighted error comparison

We observe that FC, HiLoFi, and GP are the three methods with lowest variance-weighted error, with HiLoFi being the method with lowest weighted error overall. As expected, in terms of error FC < D+LR < D, suggesting that more structured covariance matrices lead to better uncertainty estimation.

Time comparison

Next, we compare training runtimes for the different covariance approximations shown above. From left to right, runtime decreases: full covariance is by far the slowest, followed by diagonal plus low-rank; as expected, the three fastest methods are the diagonal (D), HiLoFi (FL-LR), and LRKF (LR).

time comparison (in seconds)

This small experiment is consistent with the broader results in our paper: HiLoFi strikes a favourable balance, remaining fast while still producing well-calibrated predictive uncertainty. In downstream tasks, this translates into lower regret in bandits and better speed-performance tradeoff in Bayesian optimization.

Posterior predictive correlation matrix

One way to understand the peculiar form of HiLoFi relative to the other BNN approaches is through the predictive correlation matrix. To obtain it, we consider a vector of test datapoints $\vx = (x_1, \ldots, x_N)$ for which we seek a joint posterior predictive. In this case, the predictive model is a multivariate Gaussian of the form

$$ p(\vy \mid \vx,\,\data_{1:t}) = \normal\left( \vy \mid \vh ,\;\; \vJ\,\vSigma_{t}\,\vJ^\intercal + r^2\,\vI_N \right), $$ with $$ \begin{aligned} \vh &= \begin{bmatrix} h(\vtheta_{t|t},\,x_1) & \ldots & h(\vtheta_{t|t},\,x_N) \end{bmatrix}^\intercal \in \reals^N,\\ \vJ &= \begin{bmatrix} \vJ_1 \\ \vdots \\ \vJ_N \end{bmatrix} \in \reals^{N\times D}. \end{aligned} $$

Let $$ \vS = \vJ\,\vSigma_{t}\,\vJ^\intercal + r^2\,\vI_N \in \reals^{N\times N}. $$ Then, the posterior predictive correlation matrix is $$ {\cal C}_{i,j} = \frac{S_{i,j}}{\sqrt{S_{i,i}}\,\sqrt{S_{j,j}}}. $$
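Computing ${\cal C}$ from the joint predictive covariance $\vS$ takes a couple of lines (the helper name is ours):

def predictive_correlation(S):
    # Normalise the joint predictive covariance into a correlation matrix.
    std = jnp.sqrt(jnp.diag(S))
    return S / jnp.outer(std, std)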

We show the posterior predictive correlation matrix for all methods below.

Posterior predictive correlation matrix

We see that, around the diagonal, HiLoFi approximates the GP. However, it extrapolates the correlation structure beyond neighbouring points. We leave understanding and controlling this behaviour for future work.

Bayesian optimization experiments: the effect of the low-rank

Finally, we compare the performance of HiLoFi and LRKF as a function of the rank in a Bayesian optimization setting (for details, see Section G.4.1 in the Appendix).

The goal is to find the maximum of a black-box function with a 200-dimensional input and a scalar output. In this experiment, we must balance two things: exploring regions where uncertainty is high (hoping there are higher values to find) and exploiting regions where high values have already been found (hoping that even higher values lie nearby).
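As a rough sketch of how the predictive model drives this tradeoff, here is a Thompson-sampling-style step reusing the predict helper from earlier; this is an illustration in that spirit, not the exact acquisition procedure of the experiment:

def thompson_step(key, bel, candidates):
    # One predictive draw per candidate input; act greedily on the draws.
    # High predictive variance lets under-explored regions win occasionally.
    def draw(key, x):
        mean, var = predict(bel, x)
        return mean + jnp.sqrt(var) * jax.random.normal(key)
    keys = jax.random.split(key, len(candidates))
    draws = jax.vmap(draw)(keys, candidates)
    return candidates[jnp.argmax(draws)]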

The plots below compare the performance of HiLoFi and LRKF as a function of their rank. They highlight two things:

  1. the importance of the rank: higher rank tends to yield better results and
  2. the importance of modelling the last layer explicitly: for fixed rank, HiLoFi outperforms LRKF. For fixed running time, the best values found are comparable, but LRKF requires a much higher rank than HiLoFi (compare, e.g., rank 100 for HiLoFi with rank 200 for LRKF), which in practice means higher memory requirements.
LRKF performance and HiLoFi performance as a function of rank

Conclusion

The predictive-first view offers a way to design efficient methods for modeling uncertainty in prediction space, rather than focusing on approximating parameter uncertainty, which is of less interest in most ML applications.

Citation

@article{duran2025scalable,
  title={Martingale Posterior Neural Networks for Fast Sequential Decision Making},
  author={Duran-Martin, Gerardo and S{\'a}nchez-Betancourt, Leandro and Cartea, {\'A}lvaro and Murphy, Kevin},
  journal={Advances in Neural Information Processing Systems},
  volume={38},
  year={2025}
}

  1. Tomani, Christian, et al. “Uncertainty-based abstention in llms improves safety and reduces hallucinations.” arXiv preprint arXiv:2404.10960 (2024). ↩︎

  2. Lakshminarayanan, Balaji, Alexander Pritzel, and Charles Blundell. “Simple and scalable predictive uncertainty estimation using deep ensembles.” Advances in neural information processing systems 30 (2017). ↩︎

  3. Ferianc, Martin, et al. “On the effects of quantisation on model uncertainty in bayesian neural networks.” Uncertainty in Artificial Intelligence. PMLR, 2021. ↩︎

  4. Wenzel, Florian, et al. “How good is the bayes posterior in deep neural networks really?.” arXiv preprint arXiv:2002.02405 (2020). ↩︎

  5. Knoblauch, Jeremias, Jack Jewson, and Theodoros Damoulas. “An optimization-centric view on Bayes’ rule: Reviewing and generalizing variational inference.” Journal of Machine Learning Research 23.132 (2022): 1-109. ↩︎

  6. Fong, Edwin, Chris Holmes, and Stephen G. Walker. “Martingale posterior distributions.” Journal of the Royal Statistical Society Series B: Statistical Methodology 85.5 (2023): 1357-1391. ↩︎

  7. See e.g., slide 4 in https://mlg.eng.cam.ac.uk/teaching/4f13/2122/gp%20and%20data.pdf ↩︎

  8. https://peterroelants.github.io/posts/gaussian-process-kernels/#Exponentiated-quadratic-kernel ↩︎

  9. Under the assumption of Gaussian noise, i.e., $p(e_t) = \normal(e_t \mid 0, r^2)$, the innovation $\varepsilon_t = y_t - h(\vtheta_{t-1|t-1}, x_t)$ is a zero-mean Gaussian with variance $S_{t} = \vJ_{t}\,\vSigma_{t-1}\,\vJ_{t}^\intercal + r^2$. Thus, $p(\varepsilon_t \mid x_{t},\,\data_{1:t-1}) = \normal(\varepsilon_t \mid 0, S_t)$. Because $h(\vtheta_{t-1|t-1}, x_t)$ is fixed at time $t$, this is equivalent to writing $p(y_t \mid x_{t},\,\data_{1:t-1}) = \normal(y_t \mid h(\vtheta_{t-1|t-1}, x_t), S_t)$. ↩︎

  10. See e.g., filtering-notes-iii ↩︎

  11. Wenzel, Florian, et al. “How good is the bayes posterior in deep neural networks really?.” arXiv preprint arXiv:2002.02405 (2020). ↩︎

  12. Blundell, Charles, et al. “Weight uncertainty in neural network.” International conference on machine learning. PMLR, 2015. ↩︎

  13. Chang, Peter G., Kevin Patrick Murphy, and Matt Jones. “On diagonal approximations to the extended Kalman filter for online training of Bayesian neural networks.” Continual Lifelong Learning Workshop at ACML 2022. 2022. ↩︎

  14. Chang, Peter G., et al. “Low-rank extended Kalman filtering for online learning of neural networks from streaming data.” arXiv preprint arXiv:2305.19535 (2023). ↩︎

  15. See e.g., https://mathworld.wolfram.com/FrobeniusNorm.html ↩︎

  16. Du, Simon S., et al. “Gradient descent provably optimizes over-parameterized neural networks.” arXiv preprint arXiv:1810.02054 (2018). ↩︎