Bayes-by-Backprop Tutorial¶
This version uses deterministic warmup + fixed ELBO fine-tuning to improve fit quality while keeping ELBO tracking stable.
Training Strategy¶
- Warmup stage: optimize MSE with posterior means (sample=False) for good initialization.
- VI stage: fixed objective with kl_weight = 1e-4 and canonical KL scaling num_batches = len(train_loader).
- ELBO evaluation uses MC averaging for lower variance (6 samples per epoch during training, 8 for the final test evaluation, fewer in quick mode).
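The "canonical KL scaling" in the list above can be illustrated with a small arithmetic sketch. It assumes the per-batch objective has the form NLL + kl_weight × KL / num_batches (the usual Bayes-by-Backprop minibatch weighting); the KL and NLL values below are made up for illustration, while the split size (800) and batch size (128) come from this notebook.

```python
import math

# Sizes taken from the notebook: 800 training points, batch_size=128.
n_train, batch_size = 800, 128
num_batches = math.ceil(n_train / batch_size)  # same as len(train_loader) without drop_last

kl_weight = 1e-4
kl_total = 70.0                    # hypothetical total KL(q || p)
batch_nlls = [0.5] * num_batches   # hypothetical per-batch data-fit terms


def minibatch_objective(nll: float) -> float:
    """Assumed per-batch objective: data fit plus an equal share of the KL."""
    return nll + kl_weight * kl_total / num_batches


# Summing the per-batch objectives over one epoch counts the KL exactly once.
epoch_sum = sum(minibatch_objective(nll) for nll in batch_nlls)
full_objective = sum(batch_nlls) + kl_weight * kl_total

print(num_batches)
print(math.isclose(epoch_sum, full_objective))
```

Dividing by num_batches keeps the KL penalty from being applied once per batch, which would otherwise over-regularize by a factor of num_batches per epoch.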
Notebook Roadmap¶
This tutorial is organized as:
- In-domain fit: train and evaluate on data in [-3, 3].
The focus is to show both predictive accuracy and uncertainty behavior using Bayes-by-Backprop.
# ----------------------------------------------------------------------------
# Notebook bootstrap cell
# - Adds the local `src/` directory to Python path.
# - This makes sure we import the in-repo `deepuq` package, not an older
# globally installed version.
# - Keep this cell first so all later imports are resolved correctly.
# ----------------------------------------------------------------------------
import os
import sys
from pathlib import Path
PROJECT_ROOT = Path(os.getcwd())
if not (PROJECT_ROOT / 'src').exists():
    PROJECT_ROOT = PROJECT_ROOT.parent
SRC_PATH = str(PROJECT_ROOT / 'src')
if SRC_PATH not in sys.path:
    sys.path.insert(0, SRC_PATH)
1) Environment Setup¶
This cell ensures the notebook imports the local project package from src/.
Why it matters:
- You run the latest local code changes immediately.
- You avoid version mismatches with any globally installed deepuq package.
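One way to confirm that an import will resolve to the local tree is to inspect the module spec before importing. The sketch below is an illustration only: resolves_under and demo_local_pkg are hypothetical names, and a throwaway package stands in for deepuq so the example runs anywhere.

```python
import importlib.util
import sys
import tempfile
from pathlib import Path


def resolves_under(module_name: str, root: Path) -> bool:
    """Return True if `module_name` would be imported from somewhere under `root`."""
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        return False
    origin = Path(spec.origin).resolve()
    return root.resolve() in origin.parents


# Build a temporary src/ layout with a tiny package (hypothetical name).
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "src"
    pkg = src / "demo_local_pkg"
    pkg.mkdir(parents=True)
    (pkg / "__init__.py").write_text("VERSION = 'local'\n")

    sys.path.insert(0, str(src))  # same trick as the bootstrap cell
    try:
        found_locally = resolves_under("demo_local_pkg", src)
    finally:
        sys.path.remove(str(src))

print(found_locally)
```

In this notebook the analogous check would be resolves_under('deepuq', PROJECT_ROOT / 'src'), run after the bootstrap cell.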
# ----------------------------------------------------------------------------
# Core imports + compatibility utilities
# - We import Bayes-by-Backprop primitives from the deepuq library:
# * `deepuq.methods.vi_elbo_step`
# * `deepuq.methods.vi.BayesianLinear`
# - `vi_elbo_step_compat` lets this notebook run against both old and new
# deepuq APIs (`n_batches` vs `num_batches`, optional `mc_samples`).
# ----------------------------------------------------------------------------
import random
import inspect
import copy
import math
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
# DeepUQ VI API (primary objective function for Bayes-by-Backprop training)
from deepuq.methods import vi_elbo_step
# DeepUQ Bayesian layer used to build the custom model in this notebook
from deepuq.methods.vi import BayesianLinear
VI_ELBO_SIG = inspect.signature(vi_elbo_step)
print(f"vi_elbo_step signature: {VI_ELBO_SIG}")
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def set_seed(seed: int) -> None:
    """Set RNG seeds for reproducibility across numpy/torch/cuda."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def vi_elbo_step_compat(
    model,
    x,
    y,
    *,
    num_batches,
    criterion,
    kl_weight,
    mc_samples=1,
):
    """
    Compatibility wrapper around deepuq `vi_elbo_step`.
    Why this exists:
    - Some versions expect `num_batches`.
    - Older versions expect `n_batches` and may not expose `mc_samples`.
    """
    params = VI_ELBO_SIG.parameters
    # New API supports `num_batches` directly.
    if 'num_batches' in params:
        kwargs = {
            'num_batches': num_batches,
            'criterion': criterion,
            'kl_weight': kl_weight,
        }
        if 'mc_samples' in params:
            kwargs['mc_samples'] = mc_samples
        return vi_elbo_step(model, x, y, **kwargs)
    # Legacy API path: if mc_samples == 1, pass through once.
    if mc_samples == 1:
        return vi_elbo_step(
            model,
            x,
            y,
            n_batches=num_batches,
            criterion=criterion,
            kl_weight=kl_weight,
        )
    # Legacy API path with manual Monte-Carlo averaging.
    loss_acc = 0.0
    nll_acc = 0.0
    kl_acc = 0.0
    for _ in range(mc_samples):
        loss, nll, kl = vi_elbo_step(
            model,
            x,
            y,
            n_batches=num_batches,
            criterion=criterion,
            kl_weight=kl_weight,
        )
        loss_acc = loss_acc + loss
        nll_acc = nll_acc + nll
        kl_acc = kl_acc + kl
    loss = loss_acc / float(mc_samples)
    nll = nll_acc / float(mc_samples)
    kl = kl_acc / float(mc_samples)
    return loss, nll, kl
NOTEBOOK_QUICK_MODE = os.environ.get('DEEPUQ_NOTEBOOK_QUICK', '0') == '1'
print(f"Notebook quick mode: {NOTEBOOK_QUICK_MODE}")
vi_elbo_step signature: (model, x, y, num_batches: Optional[int] = None, n_batches: Optional[int] = None, criterion=None, kl_weight: float = 1.0, mc_samples: int = 1) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
2) DeepUQ VI APIs Used Here¶
Key library methods used in this notebook:
- deepuq.methods.vi_elbo_step: computes one ELBO step (NLL + KL term).
- deepuq.methods.vi.BayesianLinear: Bayesian linear layer with learned posterior.
vi_elbo_step_compat is a local wrapper so the notebook remains compatible with both old and newer deepuq argument names (n_batches vs num_batches).
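The dispatch idea behind the wrapper can be shown in isolation. The sketch below uses two toy functions, step_new and step_old (hypothetical stand-ins, not deepuq functions), and picks keyword names from the callee's signature.

```python
import inspect


def step_new(x, *, num_batches, mc_samples=1):
    """Stand-in for a newer-style API (hypothetical)."""
    return ("new", num_batches, mc_samples)


def step_old(x, *, n_batches):
    """Stand-in for a legacy-style API (hypothetical)."""
    return ("old", n_batches)


def call_compat(fn, x, *, num_batches, mc_samples=1):
    """Choose keyword names based on what `fn` actually accepts."""
    params = inspect.signature(fn).parameters
    key = "num_batches" if "num_batches" in params else "n_batches"
    kwargs = {key: num_batches}
    if "mc_samples" in params:
        kwargs["mc_samples"] = mc_samples
    return fn(x, **kwargs)


print(call_compat(step_new, 0.0, num_batches=7, mc_samples=4))  # ('new', 7, 4)
print(call_compat(step_old, 0.0, num_batches=7, mc_samples=4))  # ('old', 7)
```

Unlike this simplified version, which drops mc_samples for the legacy callee, the notebook's wrapper loops mc_samples times and averages the results when the underlying function lacks that argument.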
# ----------------------------------------------------------------------------
# Part 1 data generation (in-domain training and testing on [-3, 3])
# - We build a nonlinear 1D regression problem with controlled Gaussian noise.
# - Inputs are transformed with Fourier features to improve representation.
# - Targets are standardized for stable optimization.
# ----------------------------------------------------------------------------
set_seed(13)
gen = torch.Generator().manual_seed(13)
def target_fn(x: torch.Tensor) -> torch.Tensor:
    """Ground-truth function to be learned by the Bayesian model."""
    return 0.45 * x + torch.sin(1.7 * x) + 0.2 * torch.cos(3.2 * x)

def fourier_features(x_raw: torch.Tensor) -> torch.Tensor:
    """
    Shared feature map for Part 1 and Part 2.
    We enrich a scalar x into a multi-frequency basis.
    """
    return torch.cat(
        [
            x_raw,
            torch.sin(0.8 * x_raw),
            torch.cos(0.8 * x_raw),
            torch.sin(1.7 * x_raw),
            torch.cos(1.7 * x_raw),
            torch.sin(2.6 * x_raw),
            torch.cos(2.6 * x_raw),
            torch.sin(3.5 * x_raw),
            torch.cos(3.5 * x_raw),
        ],
        dim=1,
    )
# Synthetic dataset sizes and observation noise level
num_total = 1200
noise_std = 0.08
# Uniform sampling in the in-domain region
x_all_raw = torch.empty(num_total, 1).uniform_(-3.0, 3.0, generator=gen)
y_true_all = target_fn(x_all_raw)
y_all_raw = y_true_all + noise_std * torch.randn(num_total, 1, generator=gen)
# Random split: train / validation / test
perm = torch.randperm(num_total, generator=gen)
idx_train = perm[:800]
idx_val = perm[800:1000]
idx_test = perm[1000:]
x_train_raw, y_train_raw = x_all_raw[idx_train], y_all_raw[idx_train]
x_val_raw, y_val_raw = x_all_raw[idx_val], y_all_raw[idx_val]
x_test_raw, y_test_raw = x_all_raw[idx_test], y_all_raw[idx_test]
y_test_true = y_true_all[idx_test]
# Compute feature/target normalization statistics from training set only
feat_train = fourier_features(x_train_raw)
feat_mean = feat_train.mean(dim=0, keepdim=True)
feat_std = feat_train.std(dim=0, keepdim=True, unbiased=False).clamp_min(1e-6)
y_mean = y_train_raw.mean(dim=0, keepdim=True)
y_std = y_train_raw.std(dim=0, keepdim=True, unbiased=False).clamp_min(1e-6)
def scale_x(x_raw: torch.Tensor) -> torch.Tensor:
    """Apply train-derived normalization to Fourier features."""
    return (fourier_features(x_raw) - feat_mean) / feat_std

def scale_y(y_raw: torch.Tensor) -> torch.Tensor:
    """Standardize targets using train-derived mean/std."""
    return (y_raw - y_mean) / y_std
# Prepare normalized tensors used by the model and loss
x_train = scale_x(x_train_raw)
y_train = scale_y(y_train_raw)
x_val = scale_x(x_val_raw)
y_val = scale_y(y_val_raw)
x_test = scale_x(x_test_raw)
y_test = scale_y(y_test_raw)
# DataLoaders feed mini-batches into warmup and VI phases
train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=128, shuffle=True)
val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=128, shuffle=False)
test_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=128, shuffle=False)
# Cached tensors on DEVICE for inverse-scaling during RMSE/prediction utilities
y_mean_device = y_mean.to(DEVICE)
y_std_device = y_std.to(DEVICE)
3) Synthetic Regression Data (Part 1)¶
This section builds a noisy nonlinear regression dataset and then standardizes it.
Important design choices:
- Fourier feature expansion improves expressivity for periodic components.
- Scaling statistics are computed from training split only.
- Validation/test use the same training normalization for fair evaluation.
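The train-only normalization rule can be sketched with plain Python lists. The numbers are made up; the population standard deviation (pstdev) mirrors the unbiased=False setting used in the data cell.

```python
from statistics import fmean, pstdev

train = [1.0, 2.0, 3.0, 4.0, 5.0]  # illustrative values
val = [0.0, 6.0]

# Statistics come from the training split only.
mu = fmean(train)
sigma = max(pstdev(train), 1e-6)  # floor mirrors clamp_min(1e-6)


def standardize(values):
    """Apply the train-derived shift and scale."""
    return [(v - mu) / sigma for v in values]


def unstandardize(values):
    """Map standardized values back to original units."""
    return [v * sigma + mu for v in values]


train_z = standardize(train)
val_z = standardize(val)  # reuses train statistics; no peeking at val

print(round(fmean(train_z), 6), round(pstdev(train_z), 6))
print([round(v, 4) for v in val_z])
```

Validation values land outside roughly [-1.41, 1.41] here because they fall outside the training range; that is expected and is exactly the information a fair evaluation should preserve.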
# ----------------------------------------------------------------------------
# Quick data sanity plot
# - Shows noisy train/val samples plus the true underlying function.
# - Use this to verify the dataset shape before training.
# ----------------------------------------------------------------------------
plt.figure(figsize=(7, 4))
plt.scatter(x_train_raw.squeeze().numpy(), y_train_raw.squeeze().numpy(), s=12, alpha=0.45, label='Train noisy')
plt.scatter(x_val_raw.squeeze().numpy(), y_val_raw.squeeze().numpy(), s=12, alpha=0.35, label='Val noisy')
x_vis = torch.linspace(-3.0, 3.0, 300).unsqueeze(-1)
y_vis = target_fn(x_vis)
plt.plot(x_vis.squeeze().numpy(), y_vis.squeeze().numpy(), color='black', linewidth=2, label='True function')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Synthetic regression data')
plt.legend(frameon=False)
plt.tight_layout()
plt.show()
4) Data Sanity Plot¶
Before training, verify:
- The sampled points cover the intended input range.
- Noise level looks reasonable.
- The true target trend is visible underneath noisy observations.
# ----------------------------------------------------------------------------
# Bayesian model definition
# - Built from `deepuq.methods.vi.BayesianLinear` blocks.
# - Residual architecture improves optimization stability in deeper nets.
# - Each BayesianLinear contributes a KL term for ELBO training.
# ----------------------------------------------------------------------------
class BayesianResidualBlock(nn.Module):
    """Residual block with two BayesianLinear layers."""
    def __init__(self, width: int, prior_sigma: float):
        super().__init__()
        self.fc1 = BayesianLinear(width, width, prior_sigma=prior_sigma)
        self.fc2 = BayesianLinear(width, width, prior_sigma=prior_sigma)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor, sample: bool = True) -> torch.Tensor:
        # Residual branch: x + f(x); scale branch by 0.5 to keep updates smooth.
        h = self.act(self.fc1(x, sample=sample))
        h = self.fc2(h, sample=sample)
        return self.act(x + 0.5 * h)

    def kl(self) -> torch.Tensor:
        # KL is additive across Bayesian layers.
        return self.fc1.kl() + self.fc2.kl()

class ComplexBayesianRegressor(nn.Module):
    """Main Bayesian regression model used in both parts of the notebook."""
    def __init__(self, input_dim: int, width: int = 72, depth: int = 3, prior_sigma: float = 0.9):
        super().__init__()
        self.input_layer = BayesianLinear(input_dim, width, prior_sigma=prior_sigma)
        self.blocks = nn.ModuleList([BayesianResidualBlock(width, prior_sigma) for _ in range(depth)])
        self.output_layer = BayesianLinear(width, 1, prior_sigma=prior_sigma)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor, sample: bool = True) -> torch.Tensor:
        h = self.act(self.input_layer(x, sample=sample))
        for block in self.blocks:
            h = block(h, sample=sample)
        return self.output_layer(h, sample=sample)

    def kl(self) -> torch.Tensor:
        total = self.input_layer.kl() + self.output_layer.kl()
        for block in self.blocks:
            total = total + block.kl()
        return total

def initialize_bayesian_posteriors(model: nn.Module, rho_init: float = -4.5) -> None:
    """
    Custom initialization for BayesianLinear parameters.
    Why:
    - Zero posterior means in deeper residual Bayesian networks can slow start.
    - Xavier-like mean init gives a stronger deterministic warmup phase.
    - `rho_init` controls initial posterior std through softplus(rho).
    """
    for module in model.modules():
        if isinstance(module, BayesianLinear):
            bound = 1.0 / math.sqrt(module.in_features)
            nn.init.uniform_(module.weight_posterior.mu, -bound, bound)
            nn.init.uniform_(module.bias_posterior.mu, -bound, bound)
            nn.init.constant_(module.weight_posterior.rho, rho_init)
            nn.init.constant_(module.bias_posterior.rho, rho_init)
# Shared architecture hyperparameters (reused in Part 2).
MODEL_WIDTH = 72
MODEL_DEPTH = 3
MODEL_PRIOR_SIGMA = 0.9
5) Bayesian Model Architecture¶
ComplexBayesianRegressor is built from residual blocks using BayesianLinear.
Why this design is used:
- Residual connections improve optimization stability.
- Bayesian layers provide uncertainty over weights.
- initialize_bayesian_posteriors(...) gives a stronger warm start than all-zero posterior means.
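To get a feel for the initialization scales, the numbers can be worked out by hand. This sketch assumes the posterior std is softplus(rho), as the initializer's docstring states, and uses the layer widths from this notebook (9 Fourier features in, 72 hidden units).

```python
import math


def softplus(rho: float) -> float:
    """softplus(rho) = log(1 + exp(rho)); maps any real rho to a positive std."""
    return math.log1p(math.exp(rho))


rho_init = -4.5
sigma_init = softplus(rho_init)

# Uniform bound for posterior means: 1 / sqrt(in_features).
bound_input = 1.0 / math.sqrt(9)    # input layer: 9 Fourier features
bound_hidden = 1.0 / math.sqrt(72)  # hidden layers: width 72

print(f"initial posterior std: {sigma_init:.5f}")
print(f"mean-init bound (input layer):  {bound_input:.4f}")
print(f"mean-init bound (hidden layer): {bound_hidden:.4f}")
```

The initial std (about 0.011) is roughly a tenth of the hidden-layer mean bound (about 0.118), so early stochastic forward passes stay close to the deterministic warmup solution.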
# ----------------------------------------------------------------------------
# Evaluation helpers
# - `evaluate_elbo`: reports ELBO/NLL/KL over a full loader with MC averaging.
# - `evaluate_rmse`: computes RMSE in original y-units (after inverse scaling).
# - `posterior_predictive`: returns predictive mean/std on arbitrary x points.
# ----------------------------------------------------------------------------
def ema(values, alpha: float = 0.2):
    """Exponential moving average for smoother trend visualization."""
    smoothed = []
    for value in values:
        if not smoothed:
            smoothed.append(float(value))
        else:
            smoothed.append(alpha * float(value) + (1.0 - alpha) * smoothed[-1])
    return smoothed

def evaluate_elbo(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    kl_weight: float,
    num_batches: int,
    mc_samples: int = 8,
):
    """Compute dataset-average ELBO components using `vi_elbo_step_compat`."""
    model.eval()
    total_loss = 0.0
    total_nll = 0.0
    total_kl = 0.0
    total_items = 0
    with torch.no_grad():
        for x_batch, y_batch in loader:
            x_batch = x_batch.to(DEVICE)
            y_batch = y_batch.to(DEVICE)
            # This is where we call the DeepUQ VI objective helper.
            loss, nll, kl = vi_elbo_step_compat(
                model,
                x_batch,
                y_batch,
                num_batches=num_batches,
                criterion=criterion,
                kl_weight=kl_weight,
                mc_samples=mc_samples,
            )
            items = y_batch.numel()
            total_loss += loss.item() * items
            total_nll += nll.item() * items
            total_kl += kl.item() * items
            total_items += items
    return (
        total_loss / total_items,
        total_nll / total_items,
        total_kl / total_items,
    )

@torch.no_grad()
def evaluate_rmse(model: nn.Module, loader: DataLoader, mc_samples: int = 120):
    """RMSE in original target scale using MC predictive mean."""
    model.eval()
    total_sq = 0.0
    total_items = 0
    for x_batch, y_batch in loader:
        x_batch = x_batch.to(DEVICE)
        y_batch = y_batch.to(DEVICE)
        # Multiple stochastic forward passes capture epistemic uncertainty.
        preds = []
        for _ in range(mc_samples):
            preds.append(model(x_batch, sample=True))
        pred_mean_scaled = torch.stack(preds, dim=0).mean(dim=0)
        # Convert from normalized target space back to physical units.
        pred_mean = pred_mean_scaled * y_std_device + y_mean_device
        y_true = y_batch * y_std_device + y_mean_device
        total_sq += ((pred_mean - y_true) ** 2).sum().item()
        total_items += y_true.numel()
    return (total_sq / total_items) ** 0.5

@torch.no_grad()
def posterior_predictive(model: nn.Module, x_scaled: torch.Tensor, n_samples: int = 300):
    """Return posterior predictive mean/std at query points."""
    model.eval()
    preds = []
    for _ in range(n_samples):
        preds.append(model(x_scaled, sample=True))
    preds = torch.stack(preds, dim=0)
    mean_scaled = preds.mean(dim=0)
    std_scaled = preds.var(dim=0, unbiased=False).sqrt()
    mean = mean_scaled * y_std_device + y_mean_device
    std = std_scaled * y_std_device
    return mean.cpu(), std.cpu()
6) Evaluation Utilities¶
These helpers keep training readable and consistent:
- evaluate_elbo(...): reports averaged ELBO/NLL/KL.
- evaluate_rmse(...): computes RMSE in original target units.
- posterior_predictive(...): returns predictive mean and epistemic std.
All utilities rely on Monte Carlo forward passes for stable uncertainty estimates.
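The Monte Carlo pattern shared by these helpers can be sketched without PyTorch. Here stochastic_model is a toy stand-in for one sampled forward pass (a random slope times x), and Y_MEAN and Y_STD are hypothetical target statistics used only to show the inverse scaling.

```python
import random
from statistics import fmean, pstdev

random.seed(0)
Y_MEAN, Y_STD = 0.3, 1.2  # hypothetical train-set target statistics


def stochastic_model(x: float) -> float:
    """Toy forward pass: each call draws a fresh 'weight' from its posterior."""
    w = random.gauss(2.0, 0.1)
    return w * x


def predictive(x: float, n_samples: int = 200):
    """MC mean and std over repeated stochastic passes, mapped to original units."""
    draws = [stochastic_model(x) for _ in range(n_samples)]
    mean_scaled = fmean(draws)
    std_scaled = pstdev(draws)  # population std, like unbiased=False
    # The mean is shifted and scaled; the std is only scaled.
    return mean_scaled * Y_STD + Y_MEAN, std_scaled * Y_STD


mean_near, std_near = predictive(0.1)
mean_far, std_far = predictive(3.0)
print(f"x=0.1: mean={mean_near:.3f}, std={std_near:.3f}")
print(f"x=3.0: mean={mean_far:.3f}, std={std_far:.3f}")
```

Note that the standard deviation is multiplied by the target scale but not shifted by the target mean, which is the same convention posterior_predictive follows.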
# ----------------------------------------------------------------------------
# Part 1 training loop
# Stage 1: deterministic warmup (`sample=False`) for stable fit initialization.
# Stage 2: stochastic VI optimization with ELBO from deepuq `vi_elbo_step`.
# We run 1000 VI epochs (80 in quick mode) and restore the best val-RMSE checkpoint.
# ----------------------------------------------------------------------------
set_seed(13)
model = ComplexBayesianRegressor(
    input_dim=x_train.shape[1],
    width=MODEL_WIDTH,
    depth=MODEL_DEPTH,
    prior_sigma=MODEL_PRIOR_SIGMA,
).to(DEVICE)
initialize_bayesian_posteriors(model, rho_init=-4.5)
# MSE handles regression data-fit (NLL proxy in normalized target space).
criterion = nn.MSELoss(reduction='mean')
# AdamW with mild weight decay and gradient clipping improves stability.
optimizer = optim.AdamW(model.parameters(), lr=4e-4, weight_decay=1e-5)
# Training policy
if NOTEBOOK_QUICK_MODE:
    warmup_epochs = 10
    vi_epochs = 80
    eval_mc_elbo = 2
    eval_mc_rmse = 8
else:
    warmup_epochs = 60
    vi_epochs = 1000
    eval_mc_elbo = 6
    eval_mc_rmse = 32
kl_weight = 1e-4
num_batches = len(train_loader)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=vi_epochs, eta_min=1e-4)
# History dictionary used by downstream plots/diagnostics.
history = {
    'epoch': [],
    'stage': [],
    'train_elbo': [],
    'val_elbo': [],
    'train_nll': [],
    'val_nll': [],
    'train_kl': [],
    'val_kl': [],
    'train_rmse': [],
    'val_rmse': [],
}
# Track best validation checkpoint; final model is restored to this state.
best_val_rmse = float('inf')
best_epoch = 0
best_state = copy.deepcopy(model.state_dict())
# ------------------------------------------------------------------------
# Stage 1: deterministic warmup (no weight sampling in forward pass)
# ------------------------------------------------------------------------
for epoch_idx in range(warmup_epochs):
    model.train()
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(DEVICE)
        y_batch = y_batch.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)
        pred = model(x_batch, sample=False)
        loss = criterion(pred, y_batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    # Evaluate after each epoch using MC-averaged metrics.
    train_elbo, train_nll, train_kl = evaluate_elbo(model, train_loader, criterion, kl_weight, num_batches, mc_samples=eval_mc_elbo)
    val_elbo, val_nll, val_kl = evaluate_elbo(model, val_loader, criterion, kl_weight, num_batches, mc_samples=eval_mc_elbo)
    train_rmse = evaluate_rmse(model, train_loader, mc_samples=eval_mc_rmse)
    val_rmse = evaluate_rmse(model, val_loader, mc_samples=eval_mc_rmse)
    if val_rmse < best_val_rmse:
        best_val_rmse = val_rmse
        best_epoch = epoch_idx + 1
        best_state = copy.deepcopy(model.state_dict())
    history['epoch'].append(epoch_idx + 1)
    history['stage'].append('warmup')
    history['train_elbo'].append(train_elbo)
    history['val_elbo'].append(val_elbo)
    history['train_nll'].append(train_nll)
    history['val_nll'].append(val_nll)
    history['train_kl'].append(train_kl)
    history['val_kl'].append(val_kl)
    history['train_rmse'].append(train_rmse)
    history['val_rmse'].append(val_rmse)
    if (epoch_idx + 1) % 20 == 0:
        print(
            f"Warmup {epoch_idx + 1:03d}/{warmup_epochs} | "
            f"val_rmse={val_rmse:.4f} | best_val_rmse={best_val_rmse:.4f} | "
            f"train_elbo={train_elbo:.4f}"
        )
# ------------------------------------------------------------------------
# Stage 2: variational training with stochastic forward passes + ELBO
# ------------------------------------------------------------------------
for vi_idx in range(vi_epochs):
    model.train()
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(DEVICE)
        y_batch = y_batch.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)
        # Main Bayes-by-Backprop objective call via DeepUQ helper.
        loss, _, _ = vi_elbo_step_compat(
            model,
            x_batch,
            y_batch,
            num_batches=num_batches,
            criterion=criterion,
            kl_weight=kl_weight,
            mc_samples=1,
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    # Step the cosine schedule once per VI epoch (T_max=vi_epochs).
    scheduler.step()
    train_elbo, train_nll, train_kl = evaluate_elbo(model, train_loader, criterion, kl_weight, num_batches, mc_samples=eval_mc_elbo)
    val_elbo, val_nll, val_kl = evaluate_elbo(model, val_loader, criterion, kl_weight, num_batches, mc_samples=eval_mc_elbo)
    train_rmse = evaluate_rmse(model, train_loader, mc_samples=eval_mc_rmse)
    val_rmse = evaluate_rmse(model, val_loader, mc_samples=eval_mc_rmse)
    epoch_number = warmup_epochs + vi_idx + 1
    if val_rmse < best_val_rmse:
        best_val_rmse = val_rmse
        best_epoch = epoch_number
        best_state = copy.deepcopy(model.state_dict())
    history['epoch'].append(epoch_number)
    history['stage'].append('vi')
    history['train_elbo'].append(train_elbo)
    history['val_elbo'].append(val_elbo)
    history['train_nll'].append(train_nll)
    history['val_nll'].append(val_nll)
    history['train_kl'].append(train_kl)
    history['val_kl'].append(val_kl)
    history['train_rmse'].append(train_rmse)
    history['val_rmse'].append(val_rmse)
    if (vi_idx + 1) % 50 == 0:
        print(
            f"VI {vi_idx + 1:04d}/{vi_epochs} | beta={kl_weight:.5f} | "
            f"lr={optimizer.param_groups[0]['lr']:.6f} | "
            f"val_rmse={val_rmse:.4f} | best_val_rmse={best_val_rmse:.4f} | "
            f"train_elbo={train_elbo:.4f} val_elbo={val_elbo:.4f}"
        )
# Restore the best checkpoint so downstream plots/results use best generalization.
model.load_state_dict(best_state)
print(f"Restored best checkpoint from epoch {best_epoch} with val RMSE={best_val_rmse:.4f}")
test_elbo, test_nll, test_kl = evaluate_elbo(model, test_loader, criterion, kl_weight, num_batches, mc_samples=8)
test_rmse = evaluate_rmse(model, test_loader, mc_samples=64)
print(f"Test ELBO={test_elbo:.4f} | Test NLL={test_nll:.4f} | Test KL={test_kl:.4f} | Test RMSE={test_rmse:.4f}")
Warmup 020/60 | val_rmse=0.0934 | best_val_rmse=0.0934 | train_elbo=1.8133
Warmup 040/60 | val_rmse=0.0894 | best_val_rmse=0.0878 | train_elbo=1.8122
Warmup 060/60 | val_rmse=0.0894 | best_val_rmse=0.0878 | train_elbo=1.8118
VI 0050/1000 | beta=0.00010 | lr=0.000398 | val_rmse=0.0970 | best_val_rmse=0.0878 | train_elbo=1.7521 val_elbo=1.7537
VI 0100/1000 | beta=0.00010 | lr=0.000393 | val_rmse=0.0880 | best_val_rmse=0.0878 | train_elbo=1.6895 val_elbo=1.6918
VI 0150/1000 | beta=0.00010 | lr=0.000384 | val_rmse=0.0899 | best_val_rmse=0.0878 | train_elbo=1.6305 val_elbo=1.6322
VI 0200/1000 | beta=0.00010 | lr=0.000371 | val_rmse=0.0891 | best_val_rmse=0.0878 | train_elbo=1.5735 val_elbo=1.5749
VI 0250/1000 | beta=0.00010 | lr=0.000356 | val_rmse=0.0917 | best_val_rmse=0.0878 | train_elbo=1.5194 val_elbo=1.5211
VI 0300/1000 | beta=0.00010 | lr=0.000338 | val_rmse=0.0912 | best_val_rmse=0.0878 | train_elbo=1.4679 val_elbo=1.4696
VI 0350/1000 | beta=0.00010 | lr=0.000318 | val_rmse=0.0907 | best_val_rmse=0.0878 | train_elbo=1.4201 val_elbo=1.4217
VI 0400/1000 | beta=0.00010 | lr=0.000296 | val_rmse=0.0950 | best_val_rmse=0.0878 | train_elbo=1.3762 val_elbo=1.3788
VI 0450/1000 | beta=0.00010 | lr=0.000273 | val_rmse=0.0907 | best_val_rmse=0.0874 | train_elbo=1.3345 val_elbo=1.3360
VI 0500/1000 | beta=0.00010 | lr=0.000250 | val_rmse=0.0984 | best_val_rmse=0.0874 | train_elbo=1.2981 val_elbo=1.2991
VI 0550/1000 | beta=0.00010 | lr=0.000227 | val_rmse=0.0956 | best_val_rmse=0.0874 | train_elbo=1.2654 val_elbo=1.2669
VI 0600/1000 | beta=0.00010 | lr=0.000204 | val_rmse=0.0895 | best_val_rmse=0.0874 | train_elbo=1.2349 val_elbo=1.2357
VI 0650/1000 | beta=0.00010 | lr=0.000182 | val_rmse=0.0888 | best_val_rmse=0.0874 | train_elbo=1.2086 val_elbo=1.2101
VI 0700/1000 | beta=0.00010 | lr=0.000162 | val_rmse=0.0894 | best_val_rmse=0.0874 | train_elbo=1.1859 val_elbo=1.1879
VI 0750/1000 | beta=0.00010 | lr=0.000144 | val_rmse=0.0893 | best_val_rmse=0.0874 | train_elbo=1.1662 val_elbo=1.1672
VI 0800/1000 | beta=0.00010 | lr=0.000129 | val_rmse=0.0905 | best_val_rmse=0.0874 | train_elbo=1.1492 val_elbo=1.1507
VI 0850/1000 | beta=0.00010 | lr=0.000116 | val_rmse=0.0893 | best_val_rmse=0.0874 | train_elbo=1.1339 val_elbo=1.1342
VI 0900/1000 | beta=0.00010 | lr=0.000107 | val_rmse=0.0886 | best_val_rmse=0.0874 | train_elbo=1.1193 val_elbo=1.1209
VI 0950/1000 | beta=0.00010 | lr=0.000102 | val_rmse=0.0892 | best_val_rmse=0.0874 | train_elbo=1.1058 val_elbo=1.1098
VI 1000/1000 | beta=0.00010 | lr=0.000100 | val_rmse=0.0890 | best_val_rmse=0.0874 | train_elbo=1.0938 val_elbo=1.0953
Restored best checkpoint from epoch 463 with val RMSE=0.0874
Test ELBO=1.3737 | Test NLL=0.0106 | Test KL=13631.0137 | Test RMSE=0.0840
7) Two-Stage Training Logic¶
Training uses a staged strategy:
- Warmup stage: deterministic (sample=False) to quickly learn a coarse fit.
- VI stage: stochastic sampling with the ELBO objective from deepuq.
The loop also tracks and restores the best validation RMSE checkpoint so final results reflect best generalization, not just the last epoch.
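The checkpoint-tracking pattern can be reduced to a few lines. In this sketch a plain dict stands in for a model state that is updated in place each epoch, and the validation scores are illustrative values; the point is that the best state must be copied at the moment it is observed.

```python
import copy

state = {"w": 0.0}                       # stand-in for in-place model parameters
val_by_epoch = [0.0934, 0.0878, 0.0894]  # illustrative validation RMSE values

best_val = float("inf")
best_epoch = 0
best_state = copy.deepcopy(state)

for epoch, val in enumerate(val_by_epoch, start=1):
    state["w"] += 1.0  # stand-in for an optimizer update (mutates in place)
    if val < best_val:
        best_val, best_epoch = val, epoch
        # Snapshot now; later in-place updates must not alter the saved copy.
        best_state = copy.deepcopy(state)

# Restore, analogous to model.load_state_dict(best_state).
state.clear()
state.update(best_state)

print(best_epoch, best_val, state)
```

Without the deepcopy, best_state would alias the live parameters and silently end up equal to the last epoch's values.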
# ----------------------------------------------------------------------------
# Part 1 training curves
# Subplot 1: ELBO (raw + EMA)
# Subplot 2: RMSE (raw + EMA)
# Subplot 3: ELBO components (NLL and KL)
# Vertical dashed line = switch from warmup to VI stage.
# ----------------------------------------------------------------------------
epochs = history['epoch']
train_ema = ema(history['train_elbo'], alpha=0.2)
val_ema = ema(history['val_elbo'], alpha=0.2)
train_rmse_ema = ema(history['train_rmse'], alpha=0.2)
val_rmse_ema = ema(history['val_rmse'], alpha=0.2)
fig, axes = plt.subplots(1, 3, figsize=(17, 4))
axes[0].plot(epochs, history['train_elbo'], color='tab:blue', alpha=0.25, label='Train ELBO (raw)')
axes[0].plot(epochs, history['val_elbo'], color='tab:orange', alpha=0.25, label='Val ELBO (raw)')
axes[0].plot(epochs, train_ema, color='tab:blue', label='Train ELBO (EMA)')
axes[0].plot(epochs, val_ema, color='tab:orange', label='Val ELBO (EMA)')
axes[0].axvline(warmup_epochs, color='tab:gray', linestyle='--', linewidth=1.0)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('ELBO')
axes[0].set_title('ELBO trend')
axes[0].grid(True, linestyle='--', linewidth=0.5, alpha=0.4)
axes[0].legend(frameon=False)
axes[1].plot(epochs, history['train_rmse'], color='tab:green', alpha=0.25, label='Train RMSE (raw)')
axes[1].plot(epochs, history['val_rmse'], color='tab:red', alpha=0.25, label='Val RMSE (raw)')
axes[1].plot(epochs, train_rmse_ema, color='tab:green', label='Train RMSE (EMA)')
axes[1].plot(epochs, val_rmse_ema, color='tab:red', label='Val RMSE (EMA)')
axes[1].axvline(warmup_epochs, color='tab:gray', linestyle='--', linewidth=1.0)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('RMSE')
axes[1].set_title('Fit quality')
axes[1].grid(True, linestyle='--', linewidth=0.5, alpha=0.4)
axes[1].legend(frameon=False)
axes[2].plot(epochs, history['train_nll'], color='tab:purple', label='Train NLL')
axes[2].plot(epochs, history['val_nll'], color='tab:brown', label='Val NLL')
axes[2].plot(epochs, history['train_kl'], color='tab:olive', label='Train KL / num_batches')
axes[2].axvline(warmup_epochs, color='tab:gray', linestyle='--', linewidth=1.0)
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Value')
axes[2].set_title('ELBO components')
axes[2].grid(True, linestyle='--', linewidth=0.5, alpha=0.4)
axes[2].legend(frameon=False)
plt.tight_layout()
plt.show()
8) How to Read the Curves¶
- The ELBO curve shows the minimized objective (NLL plus the weighted KL term), so it typically trends downward, with some local noise.
- RMSE is easier to interpret for fit quality in original units.
- KL curve shows regularization pressure from Bayesian posterior complexity.
The dashed line marks the transition from warmup to VI training.
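The logged test metrics allow a quick consistency check on how the components combine. The sketch below assumes the reported objective equals NLL + kl_weight × (reported KL term); that relationship is inferred from the printed numbers in this notebook, not taken from the deepuq source.

```python
# Values copied from the test-set line of the training log.
kl_weight = 1e-4
test_nll = 0.0106
test_kl = 13631.0137   # reported KL term (the plot labels it "KL / num_batches")
test_elbo = 1.3737

# Assumed composition of the objective.
reconstructed = test_nll + kl_weight * test_kl
kl_share = kl_weight * test_kl / reconstructed

print(f"reconstructed objective: {reconstructed:.4f}")  # 1.3737
print(f"share from KL term:      {kl_share:.1%}")       # 99.2%
```

Under this reading, almost all of the objective's magnitude comes from the KL term, which would explain why the ELBO curve keeps falling through the VI stage while validation RMSE stays nearly flat.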
# ----------------------------------------------------------------------------
# Part 1 diagnostics summary
# - Checks whether VI-stage EMA ELBO trends downward.
# - Reports RMSE EMA delta and best validation checkpoint stats.
# ----------------------------------------------------------------------------
vi_start = warmup_epochs
ema_train_elbo = ema(history['train_elbo'][vi_start:], alpha=0.2)
delta = ema_train_elbo[-1] - ema_train_elbo[0]
if len(ema_train_elbo) > 1:
    decreasing_ratio = float((np.diff(np.array(ema_train_elbo)) <= 0.0).mean())
else:
    decreasing_ratio = 1.0
delta_ok = delta < 0.0
ratio_ok = decreasing_ratio >= 0.65
ema_val_rmse = ema(history['val_rmse'], alpha=0.2)
rmse_delta = ema_val_rmse[-1] - ema_val_rmse[0]
print(f"VI-stage EMA train ELBO delta (last-first): {delta:.6f}")
print(f"VI-stage EMA train ELBO non-increase ratio: {decreasing_ratio:.3f}")
print(f"Criterion delta < 0.0: {'PASS' if delta_ok else 'FAIL'}")
print(f"Criterion ratio >= 0.65: {'PASS' if ratio_ok else 'FAIL'}")
print(f"Val RMSE EMA delta (last-first): {rmse_delta:.6f}")
print(f"Best val RMSE during training: {best_val_rmse:.4f} at epoch {best_epoch}")
print(f"Last logged val RMSE: {history['val_rmse'][-1]:.4f}")
VI-stage EMA train ELBO delta (last-first): -0.715862
VI-stage EMA train ELBO non-increase ratio: 1.000
Criterion delta < 0.0: PASS
Criterion ratio >= 0.65: PASS
Val RMSE EMA delta (last-first): -0.880727
Best val RMSE during training: 0.0874 at epoch 463
Last logged val RMSE: 0.0890
9) Diagnostic Checks¶
This diagnostic cell provides quick numeric checks on trend quality:
- ELBO EMA delta from start to end of VI stage.
- Fraction of VI epochs where EMA did not increase.
- RMSE trend summary plus best checkpoint info.
Use this as a sanity check when changing architecture or training hyperparameters.
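The two trend checks can be tried on a tiny hand-made series. The sketch below reuses the same EMA recurrence as the notebook's ema helper; non_increase_ratio is a hypothetical pure-Python equivalent of the np.diff-based ratio, and the input values are made up.

```python
def ema(values, alpha=0.2):
    """Exponential moving average: s_t = alpha * v_t + (1 - alpha) * s_{t-1}."""
    smoothed = []
    for v in values:
        if not smoothed:
            smoothed.append(float(v))
        else:
            smoothed.append(alpha * float(v) + (1.0 - alpha) * smoothed[-1])
    return smoothed


def non_increase_ratio(series):
    """Fraction of consecutive steps where the series does not go up."""
    steps = [b - a for a, b in zip(series, series[1:])]
    if not steps:
        return 1.0
    return sum(step <= 0.0 for step in steps) / len(steps)


raw = [4.0, 3.0, 3.5, 2.0, 1.0]  # noisy but decreasing overall
smooth = ema(raw)

print([round(s, 4) for s in smooth])        # [4.0, 3.8, 3.74, 3.392, 2.9136]
print(non_increase_ratio(raw))              # 0.75
print(non_increase_ratio(smooth))           # 1.0
print(round(smooth[-1] - smooth[0], 4))     # -1.0864
```

The raw series has one upward step, but the smoothed series does not, which is why the diagnostics apply the ratio test to the EMA rather than to the raw per-epoch values.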
Predictive Distribution¶
Posterior predictive mean and 95% epistemic interval on a dense input grid.
# ----------------------------------------------------------------------------
# Part 1 posterior predictive plot
# - Draw MC predictions at a dense x-grid.
# - Plot mean and 95% interval from epistemic uncertainty.
# - Compare against noisy samples and true function.
# ----------------------------------------------------------------------------
x_grid_raw = torch.linspace(-3.0, 3.0, 300).unsqueeze(-1)
y_grid_true = target_fn(x_grid_raw)
x_grid_scaled = scale_x(x_grid_raw).to(DEVICE)
pred_mean, pred_std = posterior_predictive(model, x_grid_scaled, n_samples=300)
lower = pred_mean - 1.96 * pred_std
upper = pred_mean + 1.96 * pred_std
plt.figure(figsize=(7, 4))
plt.scatter(x_train_raw.squeeze().numpy(), y_train_raw.squeeze().numpy(), s=10, alpha=0.3, label='Train noisy')
plt.scatter(x_test_raw.squeeze().numpy(), y_test_raw.squeeze().numpy(), s=10, alpha=0.25, label='Test noisy')
plt.plot(x_grid_raw.squeeze().numpy(), y_grid_true.squeeze().numpy(), color='black', linewidth=2, label='True function')
plt.plot(x_grid_raw.squeeze().numpy(), pred_mean.squeeze().numpy(), color='tab:blue', label='Predictive mean')
plt.fill_between(
    x_grid_raw.squeeze().numpy(),
    lower.squeeze().numpy(),
    upper.squeeze().numpy(),
    color='tab:blue',
    alpha=0.2,
    label='95% epistemic interval',
)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Bayes-by-Backprop posterior predictive')
plt.legend(frameon=False)
plt.tight_layout()
plt.show()
10) Posterior Predictive Interpretation¶
In the predictive plot:
- Mean line represents expected prediction.
- Shaded interval (±1.96σ) represents the epistemic uncertainty band.
- Compare how the interval behaves where data density is high vs low.
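The interval arithmetic can be sketched on a few made-up points. The epistemic band covers uncertainty about the function only; the total_std helper is an assumption beyond what the notebook plots, showing one common way to widen the band for noisy observations by adding the observation-noise variance (noise_std = 0.08 is the value used in the data cell).

```python
import math


def interval(mean, std, z=1.96):
    """Symmetric Gaussian-style band: mean ± z * std."""
    return mean - z * std, mean + z * std


def coverage(y_obs, means, stds, z=1.96):
    """Fraction of observations that fall inside their band."""
    hits = 0
    for y, m, s in zip(y_obs, means, stds):
        lo, hi = interval(m, s, z)
        hits += lo <= y <= hi
    return hits / len(y_obs)


def total_std(epistemic_std, noise_std):
    """Assumed combination: independent variances add."""
    return math.sqrt(epistemic_std ** 2 + noise_std ** 2)


# Illustrative numbers, not notebook outputs.
y_obs = [0.50, 0.60, 0.45]
means = [0.50, 0.50, 0.50]
epi = [0.02, 0.02, 0.02]
tot = [total_std(s, 0.08) for s in epi]

print(coverage(y_obs, means, epi))  # epistemic-only band
print(coverage(y_obs, means, tot))  # band widened by observation noise
```

A narrow epistemic band that misses many noisy test points is therefore not necessarily a sign of miscalibration; it may simply reflect that the band describes the function rather than individual observations.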