API Reference
Welcome to the Laplax API reference.
Package Overview
The laplax package contains two APIs. The high-level API is designed to be used out of the box. The modular low-level API exposes all essential building blocks of the high-level API and can be used for rapid experimentation. The low-level API is organized into the following modules:
laplax.curv: Tools for computing and approximating curvature information.
laplax.eval: Evaluation metrics and utilities for assessing predictive uncertainty.
laplax.util: Utilities for working with PyTrees, DataLoaders, and other common tasks.
Main design decisions
Model function signature
laplax operates on an arbitrary model_fn with the keyword signature model_fn(input, params). This makes it compatible with a wide range of JAX-based neural network libraries. For flax.nnx and equinox, this would look like: