Laplace Approximations in JAX.
The laplax package provides a performant, minimal, and practical implementation of Laplace approximation techniques in JAX. It is designed to support a wide range of scientific libraries, initially focusing on compatibility with popular neural network libraries such as equinox, flax.linen, and flax.nnx. Our goal is a flexible tool for both practical applications and research that enables rapid iteration and comparison of new approaches.
Installation
Use pip install laplax.
Minimal example
The following tiny example shows how to use laplax to perform a linearized Laplace approximation on a two-parameter ReLU network \(f(x,\theta)=\theta_2\,\text{ReLU}(\theta_1 x-1)\) with training data \(\mathcal{D}=\{(1,1),(-1,-1)\}\) and visualize the weight-space uncertainty in the loss landscape.
Gray contours: energy with square loss; black dot: optimum \(\theta^*\); green ellipses: \(1\sigma\) and \(2\sigma\) levels of the Laplace approximation.
from jax.nn import relu
from jax.numpy import array
from laplax import laplace
from plotting import plot_figure_1
# You need a model...
def model_fn(input, params):
    return relu(params["theta1"] * input - 1) * params["theta2"]
params = {  # optimized weights,
    "theta1": array(1.6556547), "theta2": array(1.0420421)
}
data = {  # and training data.
    "input": array([1., -1.]), "target": array([1., -1.])
}
# ... then apply laplax ...
posterior_fn, _ = laplace(
    model_fn, params, data, loss_fn="mse", curv_type="full",
)
curv = posterior_fn({"prior_prec": 0.2}).state['scale']
# ... to get Figure 1.
plot_figure_1(model_fn, params, curv)
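To make concrete what a linearized Laplace approximation computes for this two-parameter model, the following is a minimal sketch in plain JAX that does not use laplax at all. It builds the generalized Gauss-Newton (GGN) matrix from per-example Jacobians, adds the prior precision, and inverts the result to obtain a Gaussian covariance in weight space. The names flat_model, sigma_ellipse, ggn, precision, and covariance are illustrative and not part of the laplax API. The loss scaling (0.5 times the squared error, summed over data points) is an assumption, so the numbers need not match what laplace returns. The ellipse helper uses one common convention for the 1σ and 2σ levels, namely points at Mahalanobis distance 1 and 2 from the optimum.

```python
import jax
import jax.numpy as jnp


# Same model as in the example above, written on a flat parameter vector
# theta = (theta1, theta2). Illustrative helper, not part of laplax.
def flat_model(theta, x):
    return jax.nn.relu(theta[0] * x - 1.0) * theta[1]


theta_star = jnp.array([1.6556547, 1.0420421])  # optimized weights from above
inputs = jnp.array([1.0, -1.0])                 # training inputs from above
prior_prec = 0.2                                # prior precision from above

# Per-example Jacobians of the model output w.r.t. the parameters,
# stacked into shape (num_data, num_params).
jac = jax.vmap(jax.jacobian(flat_model), in_axes=(None, 0))(theta_star, inputs)

# ASSUMPTION: the loss is 0.5 * (f - y)^2 summed over the data, so its
# Hessian w.r.t. the model output is 1 and the GGN reduces to J^T J.
ggn = jac.T @ jac

# Posterior precision = GGN + isotropic prior precision; covariance = inverse.
precision = ggn + prior_prec * jnp.eye(2)
covariance = jnp.linalg.inv(precision)


def sigma_ellipse(mean, cov, k, num_points=100):
    """Points at Mahalanobis distance k from mean, shape (num_points, 2)."""
    t = jnp.linspace(0.0, 2.0 * jnp.pi, num_points)
    circle = jnp.stack([jnp.cos(t), jnp.sin(t)])  # unit circle, (2, num_points)
    chol = jnp.linalg.cholesky(cov)               # cov = chol @ chol.T
    return mean + k * (chol @ circle).T


ellipse_1 = sigma_ellipse(theta_star, covariance, 1.0)
ellipse_2 = sigma_ellipse(theta_star, covariance, 2.0)
```

In this sketch the ReLU is inactive for the input -1, so that data point contributes a zero Jacobian and the GGN has rank one. The prior precision is what makes the posterior precision invertible here.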
Overview
We provide a high-level interface for performing Laplace approximation on a model and also expose the underlying low-level building blocks. As working examples, we include tutorials covering both approaches:
- Tiny example (cf. plot above)
- Laplax for regression
- Laplax on MNIST
Both APIs and all available options are documented in the Manual. For each submodule, we provide a short overview as well as a comprehensive list of all available functions.
For an introduction to Laplace approximations and our notation, see the Background section. Most of the documentation follows our recent workshop paper.
Citation
If you use laplax in your research, please cite the following for now:
@software{laplax,
  author = {Tobias Weber and Bálint Mucsányi and Lenard Rommel and Thomas Christie and Lars Kasüschke and Marvin Pförtner and Philipp Hennig},
  title = {Laplax: Laplace Approximations in JAX},
  year = {2025},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/laplax-org/laplax}},
}