Laplace API¶
This page documents the public Laplace approximation wrapper used throughout Deep-UQ. The notes below explain the public control variables and workflow; the generated section contains the exact signatures and source docstrings.
Public objects¶
LaplaceWrapper
Parameter and variable conventions¶
| Name | Meaning |
|---|---|
| likelihood | either "regression" or "classification" |
| hessian_structure | curvature backend: diag, fisher_diag, lowrank_diag, block_diag, kron, full |
| subset_of_weights | "last_layer" or "all" |
| lowrank_rank | target rank for lowrank_diag |
| damping | numerical stabilization added to the precision approximation |
| full_max_params | guardrail for dense full-Hessian fitting |
| prior_precision | Gaussian prior precision used during fit(...) |
| predict_kwargs | backend-specific predictive options forwarded unchanged |
For hessian_structure="kron" and "full", Deep-UQ now prefers the legacy laplace-torch backend when that optional dependency is installed. Install it with:
Without that extra, those two structures fall back to the native Deep-UQ implementations.
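The backend preference described above can be pictured as a small dispatch. This is a hedged sketch, not Deep-UQ's actual code: the function names `choose_backend` and `laplace_torch_installed`, the returned labels, and the assumption that laplace-torch is importable as `laplace` are illustrative only.

```python
import importlib.util

# Structures for which this page says the optional laplace-torch backend is preferred.
PREFERS_LAPLACE_TORCH = {"kron", "full"}


def laplace_torch_installed() -> bool:
    """Check whether a top-level module named `laplace` can be found (assumed import name)."""
    return importlib.util.find_spec("laplace") is not None


def choose_backend(hessian_structure: str, have_laplace_torch: bool) -> str:
    """Return a label for the backend that would handle this structure (hypothetical helper)."""
    if hessian_structure in PREFERS_LAPLACE_TORCH and have_laplace_torch:
        return "laplace-torch"
    return "native"


if __name__ == "__main__":
    available = laplace_torch_installed()
    for structure in ("diag", "kron", "full"):
        print(structure, "->", choose_backend(structure, available))
```

Passing the availability flag as an argument keeps the dispatch easy to test without installing the optional dependency.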
Workflow expectations¶
- train the base model to a MAP solution with ordinary optimization
- construct LaplaceWrapper(model, ...)
- call fit(train_loader, prior_precision=...)
- call predict(...) or predict_uq(...)
predict_uq(...) cannot be called before fit(...).
Input and output shapes¶
- fit(...) expects an iterable of (inputs, targets) minibatches compatible with the wrapped model.
- regression predict(...) returns (mean, var) with the same trailing shape as one model prediction.
- classification predict(...) returns (probs, probs_var_or_none) with shape [batch, n_classes].
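To make the fit(...) input contract concrete, here is a sketch of an iterable of (inputs, targets) minibatches built from plain Python lists. The helper `minibatches` is hypothetical and not part of Deep-UQ; with a PyTorch model, one would typically pass a data loader that yields tensors instead.

```python
from typing import Iterator, List, Sequence, Tuple

Batch = Tuple[List[List[float]], List[float]]


def minibatches(
    inputs: Sequence[List[float]], targets: Sequence[float], batch_size: int
) -> Iterator[Batch]:
    """Yield (inputs, targets) pairs of at most `batch_size` examples each."""
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    for start in range(0, len(inputs), batch_size):
        stop = start + batch_size
        yield list(inputs[start:stop]), list(targets[start:stop])


if __name__ == "__main__":
    xs = [[float(i), float(i) ** 2] for i in range(5)]
    ys = [float(2 * i) for i in range(5)]
    for batch_x, batch_y in minibatches(xs, ys, batch_size=2):
        print(len(batch_x), len(batch_y))
```

A generator like this can only be consumed once; materializing it with `list(...)` gives an iterable that can be traversed repeatedly.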
UQResult mapping¶
predict_uq(...) returns:
- regression: mean, epistemic_var, optional aleatoric_var, total_var
- classification: probs, optional probs_var, and metadata describing the chosen backend
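The regression fields can be related by the usual variance decomposition. The sketch below is an illustration under stated assumptions, not the UQResult class itself: `RegressionMoments` is a hypothetical stand-in, and it assumes total_var is the sum of epistemic_var and aleatoric_var when the latter is present.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class RegressionMoments:
    """Hypothetical stand-in for the regression fields listed above."""

    mean: List[float]
    epistemic_var: List[float]
    aleatoric_var: Optional[List[float]] = None

    @property
    def total_var(self) -> List[float]:
        # Assumed convention: total = epistemic + aleatoric,
        # falling back to epistemic only when no noise term is given.
        if self.aleatoric_var is None:
            return list(self.epistemic_var)
        return [e + a for e, a in zip(self.epistemic_var, self.aleatoric_var)]


if __name__ == "__main__":
    moments = RegressionMoments(
        mean=[0.5, 1.5], epistemic_var=[0.25, 0.5], aleatoric_var=[0.5, 0.5]
    )
    print(moments.total_var)
```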
Common preconditions and failure modes¶
- unsupported hessian_structure raises ValueError
- full curvature over subset_of_weights="all" may be rejected if full_max_params is exceeded
- calling predict(...) or predict_uq(...) before fit(...) raises RuntimeError
- regression backends must return predictive variance; otherwise predict_uq(...) raises RuntimeError
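The failure modes above can be mimicked with a toy class. This is a minimal sketch under stated assumptions, not the LaplaceWrapper implementation: `ToyLaplace`, its `n_params` argument, the point at which each check runs, and the use of ValueError for the full_max_params guard are all hypothetical.

```python
SUPPORTED = ("diag", "fisher_diag", "lowrank_diag", "block_diag", "kron", "full")


class ToyLaplace:
    """Toy object reproducing the documented preconditions (not the real wrapper)."""

    def __init__(self, n_params: int, hessian_structure: str = "diag",
                 full_max_params: int = 20000):
        if hessian_structure not in SUPPORTED:
            raise ValueError(f"unsupported hessian_structure: {hessian_structure!r}")
        self.n_params = n_params
        self.hessian_structure = hessian_structure
        self.full_max_params = full_max_params
        self._fitted = False

    def fit(self) -> "ToyLaplace":
        # Assumed guard: a dense Hessian has n_params**2 entries, so cap n_params.
        if self.hessian_structure == "full" and self.n_params > self.full_max_params:
            raise ValueError("too many parameters for a dense full Hessian")
        self._fitted = True
        return self

    def predict(self) -> str:
        if not self._fitted:
            raise RuntimeError("call fit(...) before predict(...)")
        return "ok"


if __name__ == "__main__":
    try:
        ToyLaplace(n_params=10).predict()
    except RuntimeError as err:
        print("RuntimeError:", err)
    print(ToyLaplace(n_params=10).fit().predict())
```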
Minimal example¶
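from deepuq.methods.laplace import LaplaceWrapper

# model, train_loader, and x_test are assumed to exist already
la = LaplaceWrapper(
    model,
    likelihood="regression",
    hessian_structure="block_diag",
    subset_of_weights="last_layer",
)
la.fit(train_loader, prior_precision=10.0)
uq = la.predict_uq(x_test, n_samples=32)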
deepuq.methods.laplace ¶
LaplaceWrapper ¶
Fit a Laplace approximation around a MAP-trained model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | MAP-trained neural network to approximate locally with a Gaussian posterior. | required |
| likelihood | str | Either "regression" or "classification". | 'classification' |
| hessian_structure | str | Curvature backend. Supported values are diag, fisher_diag, lowrank_diag, block_diag, kron, and full. | 'diag' |
| subset_of_weights | str | Either "last_layer" or "all". | 'last_layer' |
| lowrank_rank | int | Target rank for the lowrank_diag structure. | 20 |
| damping | float | Numerical stabilization term added to precision approximations. | 1e-06 |
| full_max_params | int | Safety guard for dense full-Hessian fitting. | 20000 |
Examples:
>>> la = LaplaceWrapper(model, likelihood="classification", hessian_structure="diag")
>>> la.fit(train_loader, prior_precision=1.0)
>>> probs, probs_var = la.predict(x_test)
fit ¶
Fit the selected Laplace backend.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| train_loader | Iterable | Iterable of (inputs, targets) minibatches compatible with the wrapped model. | required |
| prior_precision | float \| None | Isotropic Gaussian prior precision. Higher values keep the approximation closer to the MAP point. | 1.0 |
Returns:
| Type | Description |
|---|---|
| object | The concrete backend instance used internally. |
Raises:
| Type | Description |
|---|---|
| ValueError | If |
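The effect of prior_precision can be seen in a one-parameter toy problem. The sketch assumes a diagonal Laplace posterior whose precision is the data curvature plus prior_precision plus damping; the function `posterior_variance` and the curvature value are illustrative, not Deep-UQ internals.

```python
def posterior_variance(curvature: float, prior_precision: float,
                       damping: float = 1e-6) -> float:
    """Variance of a 1-D Gaussian posterior: 1 / (curvature + prior_precision + damping)."""
    precision = curvature + prior_precision + damping
    if precision <= 0.0:
        raise ValueError("posterior precision must be positive")
    return 1.0 / precision


if __name__ == "__main__":
    # Hypothetical second derivative of the negative log-likelihood at the MAP point.
    data_curvature = 4.0
    for prior_precision in (0.1, 1.0, 10.0):
        var = posterior_variance(data_curvature, prior_precision)
        print(f"prior_precision={prior_precision}: posterior variance={var:.4f}")
```

A larger prior_precision yields a smaller posterior variance, which is the sense in which the approximation stays closer to the MAP point.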
predict ¶
Return the legacy predictive tuple from the fitted backend.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Evaluation inputs. | required |
| **predict_kwargs | | Forwarded to the backend predictive routine. Common options include sample counts for structured backends. | {} |
predict_uq ¶
Return predictive moments in standardized UQResult form.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Evaluation inputs. | required |
| **predict_kwargs | | Forwarded to the backend predictive routine. | {} |
Returns:
| Type | Description |
|---|---|
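| UQResult | For regression, mean, epistemic_var, optional aleatoric_var, and total_var; for classification, probs, optional probs_var, and metadata describing the chosen backend. |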
Raises:
| Type | Description |
|---|---|
| RuntimeError | If called before fit(...), or if a regression backend does not return predictive variance. |
supported_hessian_structures staticmethod ¶
Return the supported Hessian structure names.