Tiny illustration of Laplace approximations
This script is a minimal illustration of a Laplace approximation — one where the curvature approximation is tractable and can be easily visualised.
import jax.numpy as jnp
from jax.nn import relu
from plotting import plot_figure_1
from laplax import laplace
# You need optimized parameters,
# (hard-coded here: one scalar per parameter name used by `model_fn`;
# presumably obtained from an earlier optimisation that is not shown.)
best_params = dict(
    theta1=jnp.array(1.6546547),
    theta2=jnp.array(1.0420421),
)
def model_fn(input, params):
    """Evaluate the two-parameter toy model ``relu(theta1 * input - 1) * theta2``.

    Args:
        input: Array of model inputs; the operation is elementwise, so the
            output has the broadcast shape of ``input`` and the parameters.
            (The name shadows the ``input`` builtin but is kept unchanged in
            case callers pass it by keyword — TODO confirm against laplax.)
        params: Mapping with scalar-like entries ``"theta1"`` (input scale)
            and ``"theta2"`` (output scale).

    Returns:
        Array ``relu(theta1 * input - 1) * theta2``.
    """
    # The fixed offset of -1 means inputs with theta1 * input <= 1 yield zero
    # output, so the kink of the ReLU depends on theta1.
    return relu(params["theta1"] * input - 1) * params["theta2"]
# and training data.
# Two training points; inputs and targets are each a (2, 1) column array
# built from the values [1.0, -1.0].
data = {
    "input": jnp.array([1.0, -1.0])[:, None],
    "target": jnp.array([1.0, -1.0])[:, None],
}

# Then apply laplax
# MSE loss with curv_type="full"; with only the two scalar parameters in
# `best_params`, the full curvature is presumably a 2x2 matrix, which is what
# keeps this example easy to visualise. The second return value is unused here.
posterior_fn, _ = laplace(
    model_fn, best_params, data, loss_fn="mse", curv_type="full"
)

# Evaluate the posterior at "prior_prec" = 0.2 (presumably the prior precision).
posterior = posterior_fn({"prior_prec": 0.2})
# NOTE(review): `state["scale"]` is stored under the name `curv` and passed to
# the plotting helper; it is presumably a scale factor of the posterior rather
# than the raw curvature — confirm against the laplax documentation.
curv = posterior.state["scale"]

# to get figure 1.
plot_figure_1(best_params, curv, save_fig=False)
2025-11-18 16:22:50.894 | DEBUG | laplax.api:laplace:669 - Creating curvature MV - factor = 1/1 = 1.0
2025-11-18 16:22:50.895 | DEBUG | laplax.api:_maybe_wrap_loader_or_batch:179 - Using *single batch* curvature evaluation.
2025-11-18 16:22:51.272 | DEBUG | laplax.api:laplace:695 - Curvature estimated: full
2025-11-18 16:22:51.272 | DEBUG | laplax.api:laplace:704 - Posterior callable constructed.
/home/runner/work/laplax/laplax/examples/plotting.py:545: UserWarning: No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
ax.legend()
(<Figure size 325x200.861 with 1 Axes>,
<Axes: xlabel='$\\theta_1$', ylabel='$\\theta_2$'>)
