Variational Inference API¶
This page documents the public variational-inference classes and helpers behind Bayes by Backprop in deepuq. Use this page for the concrete API contract: constructor arguments, tensor shapes, return semantics, and workflow requirements. Use the method guide for the mathematical derivations and broader tradeoffs.
Public objects¶
- GaussianPosterior
- GaussianPrior
- BayesianLinear
- BayesByBackpropMLP
- BayesByBackpropRegressor
- HeteroscedasticBayesByBackpropRegressor
- MultiOutputBayesByBackpropRegressor
- HeteroscedasticMultiOutputBayesByBackpropRegressor
- LastLayerVariationalInference
- vi_elbo_step
- predict_vi_uq
Workflow summary¶
The VI workflow in deepuq is consistent across all public VI models:
- Instantiate a VI model.
- Train it with vi_elbo_step(...) inside the optimizer loop.
- Draw Monte Carlo predictive samples with predict_vi_uq(...).
- Read predictive moments through the returned UQResult.
The important distinction between the model variants is the structure of the raw model output:
- plain regression returns a mean only
- heteroscedastic regression returns mean and raw log-variance channels
- classification returns logits
- multi-output regression returns one mean per output dimension
- last-layer VI wraps a deterministic feature extractor and Bayesian linear head
Parameter and variable conventions¶
| Name | Meaning | Where it appears |
|---|---|---|
prior_mu | Mean of the Gaussian prior placed on Bayesian weights. | Bayesian layers and VI model constructors |
prior_sigma | Standard deviation of the Gaussian prior. Lower values tighten the posterior around zero. | Bayesian layers and VI model constructors |
input_dim | Size of the final feature dimension expected by dense VI models. | BayesByBackpropMLP and regression variants |
hidden_dims | Width of each hidden Bayesian MLP layer. | Dense VI models |
output_dim | Number of regression outputs or classes. | Multi-output models and last-layer VI |
activation | Hidden-layer nonlinearity. Supported values are documented by the source docstrings. | Dense VI models |
feature_extractor | Deterministic backbone whose final features feed a Bayesian head. | LastLayerVariationalInference |
feature_dim | Size of the feature dimension produced by feature_extractor. | LastLayerVariationalInference |
task | Either "regression" or "classification". Controls likelihood semantics and prediction formatting. | LastLayerVariationalInference |
heteroscedastic | Whether the regression head also predicts observation noise. | LastLayerVariationalInference |
num_batches | Number of optimizer steps per epoch. Used to scale the KL term in vi_elbo_step(...). | vi_elbo_step(...) |
n_batches | Deprecated alias for num_batches. | vi_elbo_step(...) |
kl_weight | Multiplicative factor applied to the scaled KL term. Often written as $\beta$ in the ELBO. | vi_elbo_step(...) |
mc_samples | Number of stochastic forward passes used to estimate the training ELBO for one minibatch. | vi_elbo_step(...) |
n_samples | Number of stochastic predictive forward passes used for Monte Carlo uncertainty estimates. | predict_vi_uq(...) |
apply_softmax | Interpret raw outputs as class logits and average probabilities instead of raw outputs. | predict_vi_uq(...) |
aleatoric_var | Optional additive variance term passed into predict_vi_uq(...) for plain regression models. | predict_vi_uq(...) |
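The interaction between num_batches and kl_weight can be written out directly. This is an illustrative formula matching the parameter descriptions above, not the library's implementation:

```python
def elbo_loss(nll, kl, num_batches, kl_weight=1.0):
    # The full-dataset KL term is amortized across the optimizer steps
    # in one epoch, then weighted by kl_weight (the beta factor).
    return nll + kl_weight * kl / num_batches
```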
Input and output shapes¶
Dense Bayes-by-Backprop models¶
- BayesianLinear.forward(x, sample=True) expects the trailing feature dimension of x to equal in_features.
- BayesByBackpropMLP and the dense regression variants typically consume x with shape [batch, input_dim] and return:
  - plain regression: [batch, output_dim]
  - heteroscedastic regression: [batch, 2 * output_dim]
  - classification: [batch, num_classes]
Last-layer VI¶
LastLayerVariationalInference is more general. Its deterministic feature extractor may emit either:
- [batch, feature_dim] for dense features, or
- [batch, ..., feature_dim] for spatial or sequence-style features.
The Bayesian head is applied over the last dimension and preserves the leading dimensions. Examples:
- dense regression: [batch, feature_dim] -> [batch, output_dim]
- dense classification: [batch, feature_dim] -> [batch, num_classes]
- spatial regression head: [batch, height, width, feature_dim] -> [batch, height, width, output_dim]
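The shape-preservation rule can be checked with a plain nn.Linear standing in for the Bayesian head; the sizes below are illustrative:

```python
import torch
import torch.nn as nn

head = nn.Linear(16, 3)                # stand-in for the Bayesian head
features = torch.randn(2, 8, 8, 16)    # [batch, height, width, feature_dim]
out = head(features)                   # applied over the last dimension only
print(tuple(out.shape))                # leading dims preserved: (2, 8, 8, 3)
```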
Training helper¶
vi_elbo_step(model, x, y, ...) expects x and y to be compatible with one stochastic forward pass of model. If the model implements nll(prediction, target), that method is used. Otherwise the helper falls back to the provided criterion, or to a default likelihood based on model.task_type.
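The fallback order can be sketched in plain Python; the helper name and the dictionary of defaults here are illustrative, not the actual implementation:

```python
def select_data_fit(model, criterion=None):
    # 1. Prefer the model's own likelihood when it defines nll(...).
    if hasattr(model, "nll"):
        return model.nll
    # 2. Otherwise use the user-supplied criterion.
    if criterion is not None:
        return criterion
    # 3. Finally fall back to a default keyed on model.task_type.
    defaults = {"regression": "gaussian_nll", "classification": "cross_entropy"}
    return defaults[model.task_type]
```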
Prediction helper¶
predict_vi_uq(model, x, n_samples=...) stacks n_samples stochastic forward passes and returns moments with the same leading shape as one model output. Examples:
- scalar regression: mean.shape == [batch, 1]
- multi-output regression: mean.shape == [batch, output_dim]
- classification: probs.shape == [batch, num_classes]
- spatial last-layer VI: mean.shape == [batch, height, width, output_dim]
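The probability averaging behind apply_softmax=True can be reproduced from a stack of sampled logits; the tensor sizes here are illustrative:

```python
import torch

logits = torch.randn(16, 4, 3)               # [n_samples, batch, num_classes]
probs = torch.softmax(logits, dim=-1)        # softmax per stochastic draw
mean_probs = probs.mean(dim=0)               # average probabilities, not logits
probs_var = probs.var(dim=0, unbiased=False)
```

Averaging probabilities rather than logits keeps the result a valid distribution: each row of mean_probs still sums to one.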
UQResult mapping¶
predict_vi_uq(...) returns a UQResult with semantics determined by the model family:
| Variant | Populated fields |
|---|---|
| Plain regression | mean, epistemic_var, total_var, metadata |
| Heteroscedastic regression | mean, epistemic_var, aleatoric_var, total_var, metadata |
| Multi-output regression | same as plain or heteroscedastic regression, but with an extra output dimension |
| Classification | mean, probs, probs_var, epistemic_var, metadata |
| Last-layer VI | follows the same rules as its configured task and heteroscedastic setting |
Interpretation details:
- mean is always the Monte Carlo predictive mean.
- epistemic_var is the variance across stochastic weight samples.
- aleatoric_var is populated only when the model predicts observation noise, or when an explicit aleatoric_var= tensor is supplied for plain regression.
- total_var is epistemic_var + aleatoric_var when aleatoric variance exists; otherwise it equals epistemic_var.
- probs and probs_var are classification-specific convenience views.
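This variance bookkeeping can be reproduced by hand from stacked Monte Carlo draws; all sizes and the constant noise term below are illustrative:

```python
import torch

samples = torch.randn(32, 16, 1)                      # [n_samples, batch, output_dim]
mean = samples.mean(dim=0)                            # Monte Carlo predictive mean
epistemic_var = samples.var(dim=0, unbiased=False)    # variance across weight samples
aleatoric_var = torch.full_like(epistemic_var, 0.05)  # e.g. supplied via aleatoric_var=
total_var = epistemic_var + aleatoric_var             # total = epistemic + aleatoric
```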
For field-by-field details, see Types API.
Common preconditions and failure modes¶
- num_batches must be a positive integer. Omitting it or setting it to zero raises ValueError.
- n_batches is accepted only as a deprecated alias for backward compatibility.
- mc_samples and n_samples must both be positive.
- VI models used with vi_elbo_step(...) must expose kl().
- predict_vi_uq(...) assumes forward(sample=True) is supported.
- Classification tasks require logits and integer class labels compatible with cross-entropy.
- Heteroscedastic regression assumes the output can be split into mean and raw variance channels.
- LastLayerVariationalInference requires the deterministic backbone to return a trailing feature dimension equal to feature_dim.
Minimal examples¶
Plain Bayes by Backprop regression¶
```python
import torch
from deepuq.methods import BayesByBackpropMLP, predict_vi_uq, vi_elbo_step

model = BayesByBackpropMLP(input_dim=8, hidden_dims=(32, 32), output_dim=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for xb, yb in train_loader:
    optimizer.zero_grad(set_to_none=True)
    loss, nll, kl = vi_elbo_step(
        model,
        xb,
        yb,
        num_batches=len(train_loader),
        criterion=torch.nn.MSELoss(),
        kl_weight=0.01,
        mc_samples=4,
    )
    loss.backward()
    optimizer.step()

uq = predict_vi_uq(model, x_test, n_samples=32)
print(uq.mean.shape, uq.total_var.shape)
```
Heteroscedastic regression¶
```python
import torch
from deepuq.methods import (
    HeteroscedasticBayesByBackpropRegressor,
    predict_vi_uq,
    vi_elbo_step,
)

model = HeteroscedasticBayesByBackpropRegressor(
    input_dim=6,
    hidden_dims=(64, 64),
    prior_sigma=0.2,
)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

for xb, yb in train_loader:
    optimizer.zero_grad(set_to_none=True)
    loss, _, _ = vi_elbo_step(model, xb, yb, num_batches=len(train_loader))
    loss.backward()
    optimizer.step()

uq = predict_vi_uq(model, x_test, n_samples=64)
print(uq.mean.shape, uq.epistemic_var.shape, uq.aleatoric_var.shape)
```
Last-layer VI classification¶
```python
import torch
import torch.nn as nn
from deepuq.methods import LastLayerVariationalInference, predict_vi_uq, vi_elbo_step

backbone = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU())
model = LastLayerVariationalInference(
    feature_extractor=backbone,
    feature_dim=16,
    output_dim=3,
    task="classification",
    prior_sigma=0.1,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for xb, yb in train_loader:
    optimizer.zero_grad(set_to_none=True)
    loss, _, _ = vi_elbo_step(model, xb, yb, num_batches=len(train_loader))
    loss.backward()
    optimizer.step()

uq = predict_vi_uq(model, x_test, n_samples=32, apply_softmax=True)
print(uq.probs.shape, uq.probs_var.shape)
```
Related docs¶
- UQ API Conventions
- Types API
- Variational Inference method guide
- Bayes by Backprop tutorial
- Heteroscedastic Bayes by Backprop + ADR1D
- Multi-Output Bayes by Backprop + Elastic Bar
- Heteroscedastic Multi-Output Bayes by Backprop + Transport2D
- Last-Layer VI + Heat2D Classification
deepuq.methods.vi ¶
Variational-inference primitives and wrappers for Deep-UQ.
This module centers on mean-field Bayes-by-Backprop layers and extends them to three additional regression variants plus a scalable last-layer VI wrapper. The public surface is designed to stay small:
- Bayesian layers are built from BayesianLinear.
- vi_elbo_step(...) remains the shared training helper.
- predict_vi_uq(...) remains the shared Monte Carlo predictive helper.
The model classes make task-specific behavior explicit through public attributes such as task_type, heteroscedastic, and output_dim. That lets the shared training and prediction helpers handle every VI variant without inventing parallel APIs.
BayesByBackpropMLP ¶
Bases: _BayesianMLPBase
Convenience MLP composed from BayesianLinear layers.
This is the original mean-field Bayes-by-Backprop baseline in Deep-UQ. It remains intentionally generic and can be used for regression or classification depending on the chosen output dimension and criterion.
BayesianLinear ¶
Bases: Module
Fully-connected layer with Bayesian weights and biases.
During sample=True forward passes, weights are sampled from the posterior. During sample=False passes, posterior means are used.
forward ¶
Apply the Bayesian affine transform.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input tensor whose trailing feature dimension equals in_features. | required |
sample | bool | If True, sample weights from the variational posterior; if False, use the posterior means. | True |
GaussianPosterior ¶
Bases: Module
Diagonal Gaussian variational posterior over one parameter tensor.
The posterior is parameterized by mu and rho. We transform rho with softplus to obtain sigma and guarantee strictly positive standard deviations.
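The softplus reparameterization can be sketched in a few lines; the shapes and the rho value here are illustrative:

```python
import torch
import torch.nn.functional as F

mu = torch.zeros(4)
rho = torch.full((4,), -3.0)     # unconstrained posterior parameter
sigma = F.softplus(rho)          # softplus keeps sigma strictly positive
eps = torch.randn_like(mu)
w = mu + sigma * eps             # reparameterized weight sample
```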
GaussianPrior ¶
Isotropic Gaussian prior used by Bayesian layers.
HeteroscedasticBayesByBackpropRegressor ¶
Bases: _BayesianMLPBase
Mean-field Bayesian regressor with input-dependent noise.
The network predicts two values per target dimension: a predictive mean and an unconstrained variance parameter. The variance parameter is transformed with softplus so that the Gaussian likelihood remains valid.
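The split-and-transform step can be checked with a dummy output tensor. The channel ordering (mean channels first, then raw variance channels) is illustrative here, and the values are random:

```python
import torch
import torch.nn.functional as F

raw = torch.randn(8, 2)               # [batch, 2 * output_dim] with output_dim = 1
mean, raw_var = raw.chunk(2, dim=-1)  # split into mean and raw variance channels
var = F.softplus(raw_var)             # positive variance for a valid Gaussian likelihood
```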
HeteroscedasticMultiOutputBayesByBackpropRegressor ¶
Bases: _BayesianMLPBase
Multi-output Bayesian regressor with per-output noise prediction.
LastLayerVariationalInference ¶
Bases: Module
Deterministic feature extractor with a Bayesian linear output head.
This wrapper is the scalable VI path for larger backbones. The feature extractor stays deterministic and only the final affine map is treated with a variational posterior.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
feature_extractor | Module | Deterministic module that returns a tensor whose trailing dimension is feature_dim. | required |
feature_dim | int | Size of the final feature dimension produced by | required |
output_dim | int | Number of regression outputs or classes. | required |
task | str | Either "regression" or "classification". Controls likelihood semantics and prediction formatting. | 'regression' |
heteroscedastic | bool | If True, the regression head also predicts observation noise. | False |
prior_mu | float | Mean of the Gaussian prior on the Bayesian head's weights. | 0.0 |
prior_sigma | float | Standard deviation of the Gaussian prior on the Bayesian head's weights. | 0.0 |
forward ¶
Apply the deterministic backbone and Bayesian head.
The wrapper preserves all leading dimensions returned by the feature extractor and applies the Bayesian linear head over the last dimension.
nll ¶
Return the heteroscedastic Gaussian NLL for regression wrappers.
predictive_variance ¶
Convert unconstrained regression noise output into positive variance.
split_prediction ¶
Split raw regression output into mean and raw variance tensors.
MultiOutputBayesByBackpropRegressor ¶
Bases: _BayesianMLPBase
Mean-field Bayesian regressor for vector-valued targets.
predict_vi_uq ¶
```python
predict_vi_uq(
    model: Module,
    x: Tensor,
    n_samples: int = 50,
    apply_softmax: bool = False,
    aleatoric_var: Tensor | None = None,
) -> UQResult
```
Monte Carlo predictive summary for Bayes-by-Backprop models.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | Module | Bayesian model supporting stochastic forward passes via forward(x, sample=True). | required |
x | Tensor | Inputs. | required |
n_samples | int | Number of stochastic weight samples. | 50 |
apply_softmax | bool | If True, treat raw outputs as class logits and average softmax probabilities across samples. | False |
aleatoric_var | Tensor | None | Optional additive aleatoric variance term for plain regression models. | None |
Returns:
| Type | Description |
|---|---|
UQResult | Regression calls populate mean, epistemic_var, and total_var (plus aleatoric_var when the model predicts observation noise); classification calls populate mean, probs, probs_var, and epistemic_var. |
vi_elbo_step ¶
```python
vi_elbo_step(
    model: Module,
    x: Tensor,
    y: Tensor,
    num_batches: int | None = None,
    n_batches: int | None = None,
    criterion: Module | None = None,
    kl_weight: float = 1.0,
    mc_samples: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
```
Compute one Bayes-by-Backprop ELBO step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | Module | Bayesian model exposing kl() and a stochastic forward pass. | required |
x | Tensor | Minibatch inputs. | required |
y | Tensor | Minibatch targets. | required |
num_batches | int | None | Canonical number of optimizer steps per epoch, usually len(train_loader). | None |
n_batches | int | None | Deprecated alias for num_batches. | None |
criterion | Module | None | Data-fit loss used when the model does not implement nll(prediction, target). | None |
kl_weight | float | Multiplicative weight for the scaled KL term. | 1.0 |
mc_samples | int | Number of stochastic forward passes used to Monte Carlo-average NLL and KL for a lower-variance ELBO estimate. | 1 |
Returns:
| Type | Description |
|---|---|
tuple[torch.Tensor, torch.Tensor, torch.Tensor] | (loss, nll, kl): the total ELBO objective, the data-fit NLL term, and the scaled, weighted KL term. |