Laplace API

This page documents the public Laplace approximation wrapper used throughout Deep-UQ. The notes below explain the public control variables and workflow; the generated section contains the exact signatures and source docstrings.

Public objects

  • LaplaceWrapper

Parameter and variable conventions

Name               Meaning
likelihood         either "regression" or "classification"
hessian_structure  curvature backend: diag, fisher_diag, lowrank_diag, block_diag, kron, full
subset_of_weights  "last_layer" or "all"
lowrank_rank       target rank for lowrank_diag
damping            numerical stabilization added to the precision approximation
full_max_params    guardrail for dense full-Hessian fitting
prior_precision    Gaussian prior precision used during fit(...)
predict_kwargs     backend-specific predictive options forwarded unchanged
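
To give a rough feel for how the hessian_structure choices differ in cost, the sketch below counts the scalars each structure would store for a single bias-free linear layer. These are the textbook counts for diagonal, low-rank-plus-diagonal, Kronecker-factored, and dense approximations. The helper name curvature_storage is hypothetical, and Deep-UQ's backends may store things differently. fisher_diag and block_diag are left out because their exact layout is not specified on this page.

```python
def curvature_storage(n_in: int, n_out: int, rank: int = 20) -> dict:
    """Count stored scalars for one bias-free linear layer (illustrative only)."""
    n = n_in * n_out  # number of weights in the layer
    return {
        "diag": n,                            # one entry per weight
        "lowrank_diag": n + n * rank,         # diagonal plus an n x rank factor
        "kron": n_in * n_in + n_out * n_out,  # two Kronecker factors
        "full": n * n,                        # dense n x n matrix
    }


sizes = curvature_storage(n_in=512, n_out=10)
for name, count in sizes.items():
    print(f"{name:>12}: {count:,} scalars")
```

The gap between diag and full for even a small last layer is the reason full_max_params exists as a guardrail.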

For hessian_structure="kron" and "full", Deep-UQ now prefers the legacy laplace-torch backend when that optional dependency is installed. Install it with:

pip install "uqdeepnn[laplace]"

Without that extra, those two structures fall back to the native Deep-UQ implementations.
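
The fallback rule above can be pictured with a small helper. This is only a sketch of how such a selection might be written, not Deep-UQ's actual code: pick_backend is a hypothetical name, and the check assumes the laplace-torch distribution is importable under the module name laplace.

```python
from importlib.util import find_spec


def pick_backend(hessian_structure: str) -> str:
    """Return 'laplace-torch' or 'native' (hypothetical helper, not Deep-UQ code)."""
    # Assumption: the laplace-torch distribution is importable as `laplace`.
    has_laplace_torch = find_spec("laplace") is not None
    if hessian_structure in ("kron", "full") and has_laplace_torch:
        return "laplace-torch"
    return "native"


print(pick_backend("diag"))  # 'native' in this sketch, whatever is installed
print(pick_backend("kron"))  # depends on whether the extra is installed
```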

Workflow expectations

  1. train the base model to a MAP solution with ordinary optimization
  2. construct LaplaceWrapper(model, ...)
  3. call fit(train_loader, prior_precision=...)
  4. call predict(...) or predict_uq(...)

Neither predict(...) nor predict_uq(...) can be called before fit(...); both raise RuntimeError.
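
To make the workflow concrete without any dependencies, here is a toy one-parameter version worked by hand. It assumes a model y = w * x with unit noise variance and a prior w ~ N(0, 1 / prior_precision). For this model the Laplace approximation is exact, so the numbers can be checked on paper. None of this is Deep-UQ code; it only mirrors the train, fit, predict order.

```python
# Toy data for the model y = w * x with unit noise variance.
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]
prior_precision = 10.0

# Step 1: MAP estimate (closed form here; normally found by an optimizer).
sum_xx = sum(x * x for x in xs)
sum_xy = sum(x * y for x, y in zip(xs, ys))
w_map = sum_xy / (sum_xx + prior_precision)

# Step 3: "fit" means taking the curvature of the negative log posterior at w_map.
posterior_precision = sum_xx + prior_precision
posterior_var = 1.0 / posterior_precision

# Step 4: predictive moments at a new input.
x_new = 2.0
mean = w_map * x_new
epistemic_var = x_new * x_new * posterior_var
total_var = epistemic_var + 1.0  # add the unit noise variance

print(w_map, mean, epistemic_var, total_var)
```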

Input and output shapes

  • fit(...) expects an iterable of (inputs, targets) minibatches compatible with the wrapped model.
  • regression predict(...) returns (mean, var) with the same trailing shape as one model prediction.
  • classification predict(...) returns (probs, probs_var_or_none) with shape [batch, n_classes].

UQResult mapping

predict_uq(...) returns:

  • regression: mean, epistemic_var, optional aleatoric_var, total_var
  • classification: probs, optional probs_var, and metadata describing the chosen backend
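
The field mapping above can be pictured as a small container. The dataclass below is a hypothetical stand-in named ToyUQResult; the real UQResult may use different field names, types, or defaults. It also assumes that total variance is the sum of the epistemic and aleatoric parts when both are available, which is the usual decomposition but is not stated on this page.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyUQResult:
    """Hypothetical stand-in for UQResult; the real class may differ."""
    mean: Optional[float] = None
    epistemic_var: Optional[float] = None
    aleatoric_var: Optional[float] = None
    total_var: Optional[float] = None
    probs: Optional[list] = None
    probs_var: Optional[list] = None
    metadata: dict = field(default_factory=dict)


def regression_result(mean, epistemic_var, aleatoric_var=None):
    # Assumption: total = epistemic + aleatoric (aleatoric treated as 0 if absent).
    total = epistemic_var + (aleatoric_var or 0.0)
    return ToyUQResult(mean=mean, epistemic_var=epistemic_var,
                       aleatoric_var=aleatoric_var, total_var=total)


def classification_result(probs, probs_var=None, backend="diag"):
    # Regression variance fields stay unset (None) for classification.
    return ToyUQResult(probs=probs, probs_var=probs_var,
                       metadata={"backend": backend})


reg = regression_result(mean=2.3, epistemic_var=0.2, aleatoric_var=1.0)
clf = classification_result([0.7, 0.3])
print(reg.total_var, clf.probs, clf.metadata)
```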

Common preconditions and failure modes

  • unsupported hessian_structure raises ValueError
  • hessian_structure="full" with subset_of_weights="all" raises ValueError when the parameter count exceeds full_max_params
  • calling predict(...) or predict_uq(...) before fit(...) raises RuntimeError
  • regression backends must return predictive variance; otherwise predict_uq(...) raises RuntimeError
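
The fit-before-predict precondition is a common guard pattern. The stand-in class below (ToyWrapper, a hypothetical name) reproduces only the documented RuntimeError so that calling code can be tested without the real library. Its internals are assumptions and say nothing about how LaplaceWrapper is implemented.

```python
class ToyWrapper:
    """Hypothetical stand-in that mimics only the documented RuntimeError."""

    def __init__(self):
        self._fitted = False

    def fit(self, train_loader):
        # A real backend would accumulate curvature from the batches here.
        for _inputs, _targets in train_loader:
            pass
        self._fitted = True
        return self

    def predict(self, x):
        if not self._fitted:
            raise RuntimeError("call fit(...) before predict(...)")
        return x, 0.0  # placeholder (mean, var) tuple


wrapper = ToyWrapper()
try:
    wrapper.predict(1.0)
except RuntimeError as err:
    message = str(err)
    print(f"caught: {message}")

wrapper.fit([(1.0, 2.0)])
print(wrapper.predict(1.0))
```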

Minimal example

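# Assumes `model` is already MAP-trained and `train_loader` / `x_test` exist.
from deepuq.methods.laplace import LaplaceWrapper

la = LaplaceWrapper(
    model,
    likelihood="regression",
    hessian_structure="block_diag",
    subset_of_weights="last_layer",
)
# Accumulate curvature around the MAP weights under a Gaussian prior.
la.fit(train_loader, prior_precision=10.0)
# n_samples is forwarded to the backend predictive routine.
uq = la.predict_uq(x_test, n_samples=32)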
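
A typical next step is to turn the regression mean and total variance into an interval. The helper below is a minimal sketch on plain floats, assuming the predictive distribution is treated as Gaussian. gaussian_interval is a hypothetical name, and with tensor outputs the same arithmetic would apply elementwise.

```python
import math


def gaussian_interval(mean: float, total_var: float, z: float = 1.96):
    """Return (lower, upper); z = 1.96 gives an approximate 95% interval."""
    half_width = z * math.sqrt(total_var)
    return mean - half_width, mean + half_width


lower, upper = gaussian_interval(mean=2.0, total_var=0.25)
print(lower, upper)
```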

deepuq.methods.laplace

LaplaceWrapper

Fit a Laplace approximation around a MAP-trained model.

Parameters:

  • model (Module, required): MAP-trained neural network to approximate locally with a Gaussian posterior.
  • likelihood (str, default 'classification'): Either "regression" or "classification". Controls predictive output interpretation.
  • hessian_structure (str, default 'diag'): Curvature backend. Supported values are diag, fisher_diag, lowrank_diag, block_diag, kron, and full.
  • subset_of_weights (str, default 'last_layer'): "last_layer" for a lightweight approximation around the last linear module, or "all" for all trainable parameters when the selected backend supports it.
  • lowrank_rank (int, default 20): Target rank for the lowrank_diag backend.
  • damping (float, default 1e-06): Numerical stabilization term added to precision approximations.
  • full_max_params (int, default 20000): Safety guard for hessian_structure="full" with subset_of_weights="all".

Examples:

>>> la = LaplaceWrapper(model, likelihood="classification", hessian_structure="diag")
>>> la.fit(train_loader, prior_precision=1.0)
>>> probs, probs_var = la.predict(x_test)
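
Once class probabilities are available, a predicted label and a simple uncertainty score can be derived from them. The sketch below works on one plain-Python probability vector and uses predictive entropy in nats as the score. summarize_probs is a hypothetical helper, not part of the documented API.

```python
import math


def summarize_probs(probs):
    """Return (predicted_class, entropy_in_nats) for one probability vector."""
    predicted = max(range(len(probs)), key=lambda i: probs[i])
    # Skip zero entries so that log(0) is never evaluated.
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return predicted, entropy


label, entropy = summarize_probs([0.7, 0.2, 0.1])
uniform_label, uniform_entropy = summarize_probs([0.25, 0.25, 0.25, 0.25])
print(label, entropy, uniform_entropy)
```

Entropy is largest for the uniform vector, so higher values flag inputs on which the model is less decided.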

fit

fit(
    train_loader: Iterable,
    prior_precision: float | None = 1.0,
    **_
) -> object

Fit the selected Laplace backend.

Parameters:

  • train_loader (Iterable, required): Iterable of (inputs, targets) mini-batches used to accumulate curvature statistics around the MAP solution.
  • prior_precision (float | None, default 1.0): Isotropic Gaussian prior precision. Higher values keep the approximation closer to the MAP point.
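
The effect of prior_precision is easiest to see in the scalar case, where the posterior variance is the reciprocal of the total precision. The sketch below assumes the data curvature, prior precision, and damping simply add; how Deep-UQ combines them in each backend is not stated here, so treat this as an illustration only.

```python
def posterior_variance(data_curvature: float, prior_precision: float,
                       damping: float = 1e-6) -> float:
    """Scalar Gaussian posterior variance: 1 / (curvature + prior + damping)."""
    return 1.0 / (data_curvature + prior_precision + damping)


taus = (0.1, 1.0, 10.0, 100.0)
variances = [posterior_variance(14.0, tau) for tau in taus]
for tau, var in zip(taus, variances):
    print(f"prior_precision={tau:>6}: posterior variance {var:.5f}")
```

Larger prior_precision values shrink the variance, which is the sense in which the approximation stays closer to the MAP point.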

Returns:

  • object: The concrete backend instance used internally.

Raises:

  • ValueError: If full curvature is requested over too many parameters.
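
A guard of this kind can be sketched in a few lines. The function below is hypothetical (check_full_budget is not a Deep-UQ name), and it assumes the limit is exceeded only when the count is strictly greater than full_max_params. It also reports the float32 size of the dense matrix to show why the default of 20000 is a sensible ceiling.

```python
def check_full_budget(n_params: int, full_max_params: int = 20000) -> int:
    """Return the float32 byte size of a dense n x n matrix, or raise ValueError."""
    if n_params > full_max_params:
        raise ValueError(
            f"full curvature over {n_params} parameters exceeds "
            f"full_max_params={full_max_params}"
        )
    return n_params * n_params * 4  # 4 bytes per float32 entry


print(check_full_budget(20000))  # 1600000000 bytes, about 1.6 GB
try:
    check_full_budget(50000)
except ValueError as err:
    print(err)
```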

predict

predict(x: Tensor, **predict_kwargs)

Return the legacy predictive tuple from the fitted backend: (mean, var) for regression or (probs, probs_var_or_none) for classification.

Parameters:

  • x (Tensor, required): Evaluation inputs.
  • **predict_kwargs (default {}): Forwarded to the backend predictive routine. Common options include sample counts for structured backends.
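
To illustrate what a sample count controls, the sketch below estimates a predictive probability by Monte Carlo for a one-weight logistic model with a Gaussian posterior over the weight. It is a generic sampling scheme under those assumptions, not the routine any Deep-UQ backend uses. mc_probability and its arguments are hypothetical, with n_samples named after the option in the minimal example.

```python
import math
import random


def mc_probability(x, w_map, w_var, n_samples=32, seed=0):
    """Monte Carlo mean and variance of sigmoid(w * x) under w ~ N(w_map, w_var)."""
    rng = random.Random(seed)  # local generator keeps the result reproducible
    std = math.sqrt(w_var)
    draws = []
    for _ in range(n_samples):
        w = rng.gauss(w_map, std)
        draws.append(1.0 / (1.0 + math.exp(-w * x)))
    mean = sum(draws) / n_samples
    var = sum((p - mean) ** 2 for p in draws) / n_samples
    return mean, var


p_mean, p_var = mc_probability(x=1.5, w_map=0.8, w_var=0.3, n_samples=32)
print(p_mean, p_var)
```

More samples reduce the noise in both estimates at a proportional cost in forward passes.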

predict_uq

predict_uq(x: Tensor, **predict_kwargs) -> UQResult

Return predictive moments in standardized UQResult form.

Parameters:

  • x (Tensor, required): Evaluation inputs.
  • **predict_kwargs (default {}): Forwarded to the backend predictive routine.

Returns:

  • UQResult: For regression, mean plus variance fields. For classification, probs and optional probs_var with regression variance fields left unset.

Raises:

  • RuntimeError: If fit() has not been called or if a regression backend fails to produce predictive variance.

supported_hessian_structures (staticmethod)

supported_hessian_structures() -> tuple[str, ...]

Return the supported Hessian structure names.
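
One use for this tuple is validating a user-supplied setting before constructing the wrapper. The sketch below hard-codes the names listed on this page so that it runs on its own. With the library installed, the tuple would presumably come from LaplaceWrapper.supported_hessian_structures() instead. validate_structure is a hypothetical helper, and the order of the names is an assumption.

```python
# Names copied from this page; the real tuple's order may differ.
SUPPORTED = ("diag", "fisher_diag", "lowrank_diag", "block_diag", "kron", "full")


def validate_structure(name: str, supported: tuple = SUPPORTED) -> str:
    """Return name unchanged if supported, otherwise raise ValueError."""
    if name not in supported:
        raise ValueError(
            f"unsupported hessian_structure {name!r}; choose from {supported}"
        )
    return name


print(validate_structure("kron"))
try:
    validate_structure("kfac")
except ValueError as err:
    print(err)
```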