Laplax API

GGN

Create a GGN matrix-vector product function.

Parameters:

  • model_fn (ModelFn, required): Neural network forward pass.
  • params (Params, required): Network parameters.
  • data (Data | Iterable, required): Training data.
  • loss_fn (LossFn, required): Loss function to use.
  • factor (float, default 1.0): Scaling factor for GGN.
  • vmap_over_data (bool, default True): Whether model expects batch dimension.
  • verbose_logging (bool, default True): Whether to enable verbose logging.
  • transform (Callable | None, default None): Transform to apply to data.

Returns:

  • Callable[[Params], Params]: GGN matrix-vector product function.

Raises:

  • ValueError: If the matrix-vector product, applied once to params as a test, returns leaves whose shapes differ from those of params.

Source code in laplax/api.py
def GGN(
    model_fn: ModelFn,
    params: Params,
    data: Data | Iterable,
    loss_fn: LossFn,
    *,
    factor: float = 1.0,
    vmap_over_data: bool = True,
    verbose_logging: bool = True,
    transform: Callable | None = None,
) -> Callable[[Params], Params]:
    """Create a GGN matrix-vector product function.

    Args:
        model_fn: Neural network forward pass.
        params: Network parameters.
        data: Training data.
        loss_fn: Loss function to use.
        factor: Scaling factor for GGN.
        vmap_over_data: Whether model expects batch dimension.
        verbose_logging: Whether to enable verbose logging.
        transform: Transform to apply to data.

    Returns:
        GGN matrix-vector product function.

    Raises:
        ValueError: If input/output shapes don't match.
    """
    ggn_mv = create_ggn_mv_without_data(  # type: ignore[call-arg]
        model_fn=model_fn,
        params=params,
        loss_fn=loss_fn,
        factor=factor,
        vmap_over_data=vmap_over_data,
    )

    mv_bound = _maybe_wrap_loader_or_batch(
        ggn_mv,
        data,
        transform=transform,
        loader_kwargs={
            "verbose_logging": verbose_logging,
        },
    )

    test = mv_bound(params)
    if not jax.tree.all(
        jax.tree.map(lambda x, y: x.shape == y.shape, test, params),
    ):
        msg = "Setup of GGN-MV failed: input and output shapes do not match."
        raise ValueError(msg)

    return mv_bound
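
To make the idea of a GGN matrix-vector product concrete, here is a minimal sketch in plain JAX. It is not Laplax's implementation: `make_ggn_mv`, `mse_loss`, the toy linear `model_fn`, and its `(input, params)` argument order are assumptions made for illustration. The product is assembled per sample as a Jacobian-vector product, a loss-Hessian product, and a vector-Jacobian product, then summed over the batch and scaled by `factor`. The final lines repeat the shape check performed in the source above.

```python
import jax
import jax.numpy as jnp


def model_fn(input, params):
    # Toy linear model; the (input, params) argument order is an assumption.
    return params["w"] @ input + params["b"]


def mse_loss(pred, target):
    # Hypothetical loss; its Hessian w.r.t. pred is the identity.
    return 0.5 * jnp.sum((pred - target) ** 2)


def make_ggn_mv(model_fn, params, inputs, targets, factor=1.0):
    """Return v -> factor * sum_n J_n^T H_n J_n v for a pytree v shaped like params."""

    def single_mv(x, y, vec):
        f = lambda p: model_fn(x, p)
        # J v: forward-mode product with the model Jacobian.
        pred, jv = jax.jvp(f, (params,), (vec,))
        # H (J v): Hessian of the loss w.r.t. the model output, applied to J v.
        hjv = jax.jvp(jax.grad(lambda z: mse_loss(z, y)), (pred,), (jv,))[1]
        # J^T (H J v): reverse-mode product with the transposed Jacobian.
        _, vjp_fn = jax.vjp(f, params)
        return vjp_fn(hjv)[0]

    def mv(vec):
        # Map over the leading data axis, then sum the per-sample contributions.
        per_sample = jax.vmap(single_mv, in_axes=(0, 0, None))(inputs, targets, vec)
        return jax.tree.map(lambda leaf: factor * leaf.sum(axis=0), per_sample)

    return mv


params = {"w": jnp.array([[1.0, 2.0], [0.5, -1.0]]), "b": jnp.zeros(2)}
inputs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
targets = jnp.zeros((3, 2))

ggn_mv = make_ggn_mv(model_fn, params, inputs, targets)
out = ggn_mv(params)

# Same sanity check as in the source: output leaves must match params in shape.
shapes_match = jax.tree.all(
    jax.tree.map(lambda x, y: x.shape == y.shape, out, params)
)
```

Because the operator maps a parameter-shaped pytree to a parameter-shaped pytree, feeding `params` itself through it is a cheap way to confirm that the setup is consistent.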

laplace

Estimate curvature & obtain a Gaussian weight-space posterior.

This function computes a Laplace approximation to the posterior distribution over neural network weights. It estimates the curvature of the loss landscape and constructs a Gaussian approximation centered at the MAP estimate.

Parameters:

  • model_fn (ModelFn, required): The neural network forward pass function that takes input and parameters.
  • params (Params, required): The MAP estimate of the network parameters.
  • data (Data | Iterable, required): Either a single batch (tuple/dict) or a DataLoader-like iterable containing the training data.
  • loss_fn (LossFn, required): The supervised loss function to use (e.g., "mse" for regression).
  • curv_type (CurvApprox, required): Type of curvature approximation to use (e.g., "ggn", "diag-ggn").
  • num_curv_samples (Int, default 1): Number of Monte Carlo samples used to estimate the GGN.
  • num_total_samples (Int, default 1): Total number of samples in the dataset.
  • vmap_over_data (bool, default True): Whether the model expects a leading batch axis.
  • curv_mv_jit (bool, default False): Whether to jit the curvature matrix-vector product.
  • **curv_kwargs (Kwargs, default {}): Additional arguments forwarded to the curvature estimation function.

Returns:

  • tuple[Callable[[PriorArguments, Float], Posterior], PyTree]: A tuple containing:
      • posterior_fn: Function that generates samples from the posterior given prior arguments.
      • curv_estimate: The estimated curvature in the chosen representation.

Notes

The function supports different curvature approximations:

  • Full GGN: Computes the full Generalized Gauss-Newton matrix
  • Diagonal GGN: Approximates the GGN with its diagonal
  • Low-rank GGN: Uses Lanczos or LOBPCG for efficient approximation
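
The three representations above can all be derived from a matrix-vector product. The following is a minimal sketch on flat vectors, not Laplax's code: `dense_from_mv`, `diag_from_mv`, and `low_rank_from_dense` are hypothetical helpers, the matrix `A` stands in for a curvature operator, and a dense `jnp.linalg.eigh` is swapped in for Lanczos or LOBPCG purely to show what a low-rank factor looks like.

```python
import jax
import jax.numpy as jnp


def dense_from_mv(mv, dim):
    """Materialize a (dim, dim) matrix by applying mv to every basis vector."""
    # Row i holds mv(e_i); for a symmetric operator this equals column i.
    return jax.vmap(mv)(jnp.eye(dim))


def diag_from_mv(mv, dim):
    """Keep only the diagonal of the operator."""
    return jnp.diag(dense_from_mv(mv, dim))


def low_rank_from_dense(matrix, rank):
    """Top-`rank` eigenpairs via a dense solver (stand-in for Lanczos/LOBPCG)."""
    eigvals, eigvecs = jnp.linalg.eigh(matrix)  # eigenvalues in ascending order
    return eigvals[-rank:], eigvecs[:, -rank:]


# Stand-in symmetric positive-definite operator.
A = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]])
mv = lambda v: A @ v

full = dense_from_mv(mv, 3)
diag = diag_from_mv(mv, 3)
eigvals, eigvecs = low_rank_from_dense(full, rank=2)
```

Materializing the full matrix costs one matrix-vector product per parameter, which is why diagonal and low-rank representations matter for larger networks.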
Source code in laplax/api.py
def laplace(
    model_fn: ModelFn,
    params: Params,
    data: Data | Iterable,
    *,
    loss_fn: LossFn,
    curv_type: CurvApprox,
    num_curv_samples: Int = 1,
    num_total_samples: Int = 1,
    vmap_over_data: bool = True,
    curv_mv_jit: bool = False,
    **curv_kwargs: Kwargs,
) -> tuple[Callable[[PriorArguments, Float], Posterior], PyTree]:
    """Estimate curvature & obtain a Gaussian weight-space posterior.

    This function computes a Laplace approximation to the posterior distribution over
    neural network weights. It estimates the curvature of the loss landscape and
    constructs a Gaussian approximation centered at the MAP estimate.

    Args:
        model_fn: The neural network forward pass function that takes input and
            parameters.
        params: The MAP estimate of the network parameters.
        data: Either a single batch (tuple/dict) or a DataLoader-like iterable
            containing the training data.
        loss_fn: The supervised loss function to use (e.g., "mse" for regression).
        curv_type: Type of curvature approximation to use (e.g., "ggn", "diag-ggn").
        num_curv_samples: Number of Monte Carlo samples used to estimate the GGN, by
            default 1.
        num_total_samples: Total number of samples in the dataset, by default 1.
        vmap_over_data: Whether the model expects a leading batch axis, by default True.
        curv_mv_jit: Whether to jit the curvature matrix-vector product, by default
            False.
        **curv_kwargs: Additional arguments forwarded to the curvature estimation
            function.

    Returns:
        A tuple containing:

            - posterior_fn: Function that generates samples from the posterior given
                prior arguments.
            - curv_estimate: The estimated curvature in the chosen representation.

    Notes:
        The function supports different curvature approximations:

        - Full GGN: Computes the full Generalized Gauss-Newton matrix
        - Diagonal GGN: Approximates the GGN with its diagonal
        - Low-rank GGN: Uses Lanczos or LOBPCG for efficient approximation
    """
    # Convert curv_type to enum
    curv_type_enum = _convert_to_enum(CurvApprox, curv_type)

    # Calculate factor
    factor = float(num_curv_samples) / float(num_total_samples)
    logger.debug(
        "Creating curvature MV - factor = {}/{} = {}",
        num_curv_samples,
        num_total_samples,
        factor,
    )

    # Set GGN MV
    ggn_mv = GGN(
        model_fn,
        params,
        data,
        loss_fn=loss_fn,
        factor=factor,
        vmap_over_data=vmap_over_data,
    )
    if curv_mv_jit:
        ggn_mv = jax.jit(ggn_mv)

    # Curvature estimation
    curv_estimate = estimate_curvature(
        curv_type=curv_type_enum,
        mv=ggn_mv,
        layout=params,
        **curv_kwargs,
    )
    logger.debug("Curvature estimated: {}", curv_type_enum)

    # Posterior (Gaussian)
    posterior_fn = set_posterior_fn(
        curv_type=curv_type_enum,
        curv_estimate=curv_estimate,
        layout=params,
        **curv_kwargs,
    )
    logger.debug("Posterior callable constructed.")

    return posterior_fn, curv_estimate
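
As a conceptual illustration of what the returned posterior represents, the sketch below samples from a Gaussian built from a diagonal curvature estimate. It assumes the usual Laplace form in which the posterior precision is the curvature plus an isotropic prior precision; `diag_posterior_sample` and the numbers are hypothetical, and Laplax's own `Posterior` object may be organized differently. Only the `factor` expression is taken from the source above.

```python
import jax
import jax.numpy as jnp


def diag_posterior_sample(key, mean, curv_diag, prior_prec, num_samples):
    """Draw weight samples from N(mean, diag(1 / (curv_diag + prior_prec)))."""
    # Assumed form: posterior precision = curvature diagonal + prior precision.
    precision = curv_diag + prior_prec
    scale = 1.0 / jnp.sqrt(precision)
    noise = jax.random.normal(key, (num_samples, mean.shape[0]))
    return mean + scale * noise


# Scaling factor exactly as computed in laplace().
num_curv_samples, num_total_samples = 32, 128
factor = float(num_curv_samples) / float(num_total_samples)

mean = jnp.array([0.5, -1.0])  # plays the role of the MAP estimate
curv_diag = factor * jnp.array([12.0, 60.0])  # made-up unscaled curvature
samples = diag_posterior_sample(jax.random.key(0), mean, curv_diag, 1.0, 5000)
```

A larger prior precision shrinks every sample toward the MAP estimate, which is the knob that calibration tunes.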

calibration

Calibrate hyperparameters of the Laplace approximation.

This function tunes the prior precision (or similar hyperparameters) of the Laplace approximation by optimizing a specified objective function. It supports different calibration objectives and methods.

Parameters:

  • posterior_fn (Callable[[PriorArguments, Float], Posterior], required): Function that generates samples from the posterior.
  • model_fn (ModelFn, required): The neural network forward pass function.
  • params (Params, required): The MAP estimate of the network parameters.
  • data (Data, required): The validation data used for calibration.
  • loss_fn (LossFn, required): The supervised loss function used for training.
  • curv_estimate (PyTree, required): The estimated curvature from the Laplace approximation.
  • curv_type (CurvApprox, required): Type of curvature approximation used.
  • predictive_type (Predictive | str, default Predictive.NONE): Type of predictive distribution to use.
  • pushforward_type (Pushforward | str, default Pushforward.LINEAR): Type of pushforward approximation to use.
  • pushforward_fns (list[Callable] | None, default None): Custom pushforward functions to use.
  • sample_key (KeyType, default DEFAULT_KEY): PRNG key.
  • num_samples (int, default 30): Number of MC samples for the predictive.
  • calibration_objective (CalibrationObjective | str, default CalibrationObjective.NLL): Objective function to optimize during calibration.
  • calibration_method (CalibrationMethod | str, default CalibrationMethod.GRID_SEARCH): Method to use for calibration.
  • vmap_over_data (bool, default True): Whether the model expects a leading batch axis.
  • objective_jit (bool, default True): Whether to jit the calibration objective.
  • **calibration_kwargs (Kwargs, default {}): Additional arguments for the calibration method. For GRID_SEARCH, the source reads the keys log_prior_prec_min (default -3.0), log_prior_prec_max (default 3.0), grid_size (default 50), and patience (default None).

Returns:

  • tuple[PriorArguments, Callable[[InputArray], dict[str, Array]]]: A tuple containing:
      • prior_arguments (PriorArguments): Dictionary of calibrated hyperparameters.
      • set_prob_predictive (Callable): Function that creates a predictive distribution given prior arguments.

Raises:

  • ValueError: When an unknown calibration method is provided.

Notes

Supported calibration objectives:

  • NLL: Negative log-likelihood
  • CHI_SQUARED: Chi-squared statistic
  • MARGINAL_LOG_LIKELIHOOD: Marginal log-likelihood
  • ECE: Expected Calibration Error

Supported calibration methods:

  • GRID_SEARCH: Grid search over prior precision
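
For regression, two of the objectives listed above have simple textbook forms. The sketch below uses only the standard library and assumes a Gaussian predictive with per-point mean and variance; `gaussian_nll` and `chi_squared` are hypothetical helpers, and Laplax's own definitions may differ in normalization or reduction.

```python
import math


def gaussian_nll(means, variances, targets):
    """Average negative log-likelihood under independent Gaussian predictions."""
    terms = [
        0.5 * (math.log(2.0 * math.pi * var) + (y - mu) ** 2 / var)
        for mu, var, y in zip(means, variances, targets)
    ]
    return sum(terms) / len(terms)


def chi_squared(means, variances, targets):
    """Mean squared standardized residual; near 1 when variances are calibrated."""
    terms = [(y - mu) ** 2 / var for mu, var, y in zip(means, variances, targets)]
    return sum(terms) / len(terms)


# Each residual equals exactly one predicted standard deviation.
means = [0.0, 1.0, 2.0]
variances = [1.0, 4.0, 0.25]
targets = [1.0, 3.0, 2.5]

nll = gaussian_nll(means, variances, targets)
chi2 = chi_squared(means, variances, targets)
```

Predicted variances that are too small push the chi-squared value above 1, while overly wide variances pull it below 1; the NLL penalizes both directions.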
Source code in laplax/api.py
def calibration(
    posterior_fn: Callable[[PriorArguments, Float], Posterior],
    model_fn: ModelFn,
    params: Params,
    data: Data,
    *,
    loss_fn: LossFn,
    curv_estimate: PyTree,
    curv_type: CurvApprox,
    predictive_type: Predictive | str = Predictive.NONE,
    pushforward_type: Pushforward | str = Pushforward.LINEAR,
    pushforward_fns: list[Callable] | None = None,
    sample_key: KeyType = DEFAULT_KEY,
    num_samples: int = 30,
    calibration_objective: CalibrationObjective | str = CalibrationObjective.NLL,
    calibration_method: CalibrationMethod | str = CalibrationMethod.GRID_SEARCH,
    vmap_over_data: bool = True,
    objective_jit: bool = True,
    **calibration_kwargs: Kwargs,
) -> tuple[PriorArguments, Callable[[InputArray], dict[str, Array]]]:
    """Calibrate hyperparameters of the Laplace approximation.

    This function tunes the prior precision (or similar hyperparameters) of the Laplace
    approximation by optimizing a specified objective function. It supports different
    calibration objectives and methods.

    Args:
        posterior_fn: Function that generates samples from the posterior.
        model_fn: The neural network forward pass function.
        params: The MAP estimate of the network parameters.
        data: The validation data used for calibration.
        loss_fn: The supervised loss function used for training.
        curv_estimate: The estimated curvature from the Laplace approximation.
        curv_type: Type of curvature approximation used.
        predictive_type: Type of predictive distribution to use, by default
            Predictive.NONE.
        pushforward_type: Type of pushforward approximation to use, by default
            Pushforward.LINEAR.
        pushforward_fns: Custom pushforward functions to use, by default None.
        sample_key: PRNG key.
        num_samples: Number of MC samples for the predictive.
        calibration_objective: Objective function to optimize during calibration, by
            default CalibrationObjective.NLL.
        calibration_method: Method to use for calibration, by default
            CalibrationMethod.GRID_SEARCH.
        vmap_over_data: Whether the model expects a leading batch axis, by default True.
        objective_jit: Whether to jit the calibration objective, by default True.
        **calibration_kwargs: Additional arguments for the calibration method.

    Returns:
        A tuple containing:

            - prior_arguments : PriorArguments
                Dictionary of calibrated hyperparameters.
            - set_prob_predictive : Callable
                Function that creates a predictive distribution given prior arguments.

    Raises:
        ValueError: When an unknown calibration method is provided.

    Notes:
        Supported calibration objectives:

        - NLL: Negative log-likelihood
        - CHI_SQUARED: Chi-squared statistic
        - MARGINAL_LOG_LIKELIHOOD: Marginal log-likelihood
        - ECE: Expected Calibration Error

        Supported calibration methods:

        - GRID_SEARCH: Grid search over prior precision
    """
    # If task is classification, then no NLL objective
    is_classification = predictive_type != Predictive.NONE

    # Pushforward construction
    set_pushforward, pushforward_fns = _setup_pushforward(
        pushforward_type=pushforward_type,
        predictive_type=predictive_type,
        pushforward_fns=pushforward_fns,
    )

    set_prob_predictive = partial(
        set_pushforward,
        model_fn=model_fn,
        mean_params=params,
        posterior_fn=posterior_fn,
        pushforward_fns=pushforward_fns,
        key=sample_key,
        num_samples=num_samples,
    )

    # Calibration objective & optimisation
    objective_fn = _build_calibration_objective(
        objective_type=calibration_objective,
        set_prob_predictive=set_prob_predictive,
        curv_estimate=curv_estimate,
        model_fn=model_fn,
        params=params,
        loss_fn=loss_fn,
        curv_type=curv_type,
        vmap_over_data=vmap_over_data,
        is_classification=is_classification,
    )

    calibration_method = _convert_to_enum(
        CalibrationMethod, calibration_method, str_default=True
    )

    if calibration_method == CalibrationMethod.GRID_SEARCH:
        # Get default values if not provided
        log_prior_prec_min = calibration_kwargs.get("log_prior_prec_min", -3.0)
        log_prior_prec_max = calibration_kwargs.get("log_prior_prec_max", 3.0)
        grid_size = calibration_kwargs.get("grid_size", 50)
        patience = calibration_kwargs.get("patience")

        # Transform calibration batch to {"input": ..., "target": ...}
        data = _validate_and_get_transform(data)(data)

        logger.debug(
            "Starting calibration with objective {} on grid [{}, {}] ({} pts, pat={})",
            calibration_objective,
            log_prior_prec_min,
            log_prior_prec_max,
            grid_size,
            patience,
        )

        def objective(x):
            return objective_fn(x, data)

        if objective_jit:
            objective = jax.jit(objective)

        prior_prec = calibration_options[calibration_method](
            objective=objective,
            log_prior_prec_min=log_prior_prec_min,
            log_prior_prec_max=log_prior_prec_max,
            grid_size=grid_size,
            patience=patience,
        )
        prior_args = {"prior_prec": prior_prec}

    elif calibration_method in calibration_options:
        data = _validate_and_get_transform(data)(data)

        if objective_jit:
            objective_fn = jax.jit(objective_fn)

        prior_args = calibration_options[calibration_method](
            objective=objective_fn,
            data=data,
            **calibration_kwargs,
        )
    else:
        msg = f"Unknown calibration method: {calibration_method}"
        raise ValueError(msg)
    logger.debug("Calibrated prior args = {}", prior_args)

    return prior_args, set_prob_predictive
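
The grid-search branch above can be pictured with a small standard-library sketch. It reuses the default values read from `calibration_kwargs` in the source, but the rest is assumed: the grid is taken to be logarithmic in base 10, `patience` is interpreted as the number of consecutive non-improving points tolerated before stopping, and `grid_search` and `toy_objective` are hypothetical names rather than Laplax functions.

```python
import math


def grid_search(
    objective,
    log_prior_prec_min=-3.0,
    log_prior_prec_max=3.0,
    grid_size=50,
    patience=None,
):
    """Return the prior precision on a log-spaced grid that minimizes objective."""
    step = (log_prior_prec_max - log_prior_prec_min) / (grid_size - 1)
    # Assumption: exponents are base 10.
    grid = [10.0 ** (log_prior_prec_min + i * step) for i in range(grid_size)]

    best_prec, best_value, since_best = None, math.inf, 0
    for prior_prec in grid:
        value = objective(prior_prec)
        if value < best_value:
            best_prec, best_value, since_best = prior_prec, value, 0
        else:
            since_best += 1
            # Assumed early stopping: give up after `patience` misses in a row.
            if patience is not None and since_best >= patience:
                break
    return best_prec


def toy_objective(prior_prec):
    # Made-up objective with its minimum at prior_prec = 10.
    return (math.log10(prior_prec) - 1.0) ** 2


best = grid_search(toy_objective, patience=5)
best_full = grid_search(toy_objective)
```

With a smooth objective like this one, early stopping returns the same grid point as the exhaustive search while evaluating fewer candidates.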

evaluation

Evaluate the calibrated Laplace approximation.

This function assesses the performance of the calibrated Laplace approximation by computing various metrics on the test data. It supports both regression and classification tasks with different predictive distributions.

Parameters:

  • posterior_fn (Callable[[PriorArguments, Float], Posterior], required): Function that generates samples from the posterior.
  • model_fn (ModelFn, required): The neural network forward pass function.
  • params (Params, required): The MAP estimate of the network parameters.
  • arguments (PriorArguments, required): The calibrated prior arguments.
  • data (Data | Iterator[Data], required): The test data for evaluation.
  • metrics (DefaultMetrics | list[Callable] | Callable | str, default DefaultMetrics.REGRESSION): Metrics to compute during evaluation.
  • predictive_type (Predictive | str, default Predictive.NONE): Type of predictive distribution to use.
  • pushforward_type (Pushforward | str, default Pushforward.LINEAR): Type of pushforward approximation to use.
  • pushforward_fns (list[Callable] | None, default None): Custom pushforward functions to use.
  • reduce (Callable, default identity): Function to reduce metrics. In the source it is applied only when data is a single batch; when data is a data loader, results are combined with jnp.concatenate regardless of this argument.
  • sample_key (KeyType, default DEFAULT_KEY): Random key for sampling, by default jax.random.key(0).
  • num_samples (int, default 10): Number of samples for Monte Carlo predictions.
  • predictive_jit (bool, default True): Whether to jit the predictive distribution.

Returns:

  • tuple[dict[str, Array], Callable[[InputArray], dict[str, Array]]]: A tuple containing:
      • results (dict): Dictionary of computed metrics.
      • prob_predictive (Callable): The predictive distribution function.

Notes

Supported metrics:

  • REGRESSION: Default metrics for regression tasks
  • CLASSIFICATION: Default metrics for classification tasks
  • Custom metrics can be provided as a list of callables

The function supports both linearized and Monte Carlo predictions through different pushforward types.

Source code in laplax/api.py
def evaluation(
    posterior_fn: Callable[[PriorArguments, Float], Posterior],
    model_fn: ModelFn,
    params: Params,
    arguments: PriorArguments,
    data: Data | Iterator[Data],
    *,
    metrics: DefaultMetrics
    | list[Callable]
    | Callable
    | str = DefaultMetrics.REGRESSION,
    predictive_type: Predictive | str = Predictive.NONE,
    pushforward_type: Pushforward | str = Pushforward.LINEAR,
    pushforward_fns: list[Callable] | None = None,
    reduce: Callable = identity,
    sample_key: KeyType = DEFAULT_KEY,
    num_samples: int = 10,
    predictive_jit: bool = True,
) -> tuple[dict[str, Array], Callable[[InputArray], dict[str, Array]]]:
    """Evaluate the calibrated Laplace approximation.

    This function assesses the performance of the calibrated Laplace approximation
    by computing various metrics on the test data. It supports both regression and
    classification tasks with different predictive distributions.

    Args:
        posterior_fn: Function that generates samples from the posterior.
        model_fn: The neural network forward pass function.
        params: The MAP estimate of the network parameters.
        arguments: The calibrated prior arguments.
        data: The test data for evaluation.
        metrics: Metrics to compute during evaluation, by default
            DefaultMetrics.REGRESSION.
        predictive_type: Type of predictive distribution to use, by default
            Predictive.NONE.
        pushforward_type: Type of pushforward approximation to use, by default
            Pushforward.LINEAR.
        pushforward_fns: Custom pushforward functions to use, by default None.
        reduce: Function to reduce metrics across batches, by default identity.
        sample_key: Random key for sampling, by default jax.random.key(0).
        num_samples: Number of samples for Monte Carlo predictions, by default 10.
        predictive_jit: Whether to jit the predictive distribution, by default True.

    Returns:
        A tuple containing:

            - results : dict
                Dictionary of computed metrics.
            - prob_predictive : Callable
                The predictive distribution function.

    Notes:
        Supported metrics:

        - REGRESSION: Default metrics for regression tasks
        - CLASSIFICATION: Default metrics for classification tasks
        - Custom metrics can be provided as a list of callables

        The function supports both linearized and Monte Carlo predictions through
        different pushforward types.
    """
    metrics_list = _resolve_metrics(metrics)

    set_pushforward, pushforward_fns = _setup_pushforward(
        pushforward_type=pushforward_type,
        predictive_type=predictive_type,
        pushforward_fns=pushforward_fns,
    )

    # Build predictive distribution
    prob_predictive = set_pushforward(
        prior_arguments=arguments,
        model_fn=model_fn,
        mean_params=params,
        posterior_fn=posterior_fn,
        pushforward_fns=pushforward_fns,
        key=sample_key,
        num_samples=num_samples,
    )

    if predictive_jit:
        prob_predictive = jax.jit(prob_predictive)

    # Evaluate
    is_data_loader = _is_data_loader(data)
    transform = _validate_and_get_transform(
        next(iter(data)) if is_data_loader else data
    )

    if is_data_loader:
        results = evaluate_metrics_on_generator(
            pred_fn=prob_predictive,
            data_generator=cast("Iterator[Data]", data),
            metrics=metrics_list,
            transform=transform,
            reduce=jnp.concatenate,
            has_batch=True,
        )
    else:
        results = evaluate_metrics_on_dataset(
            pred_fn=prob_predictive,
            data=transform(data),
            metrics=metrics_list,
            reduce=reduce,
        )

    return results, prob_predictive
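
The two evaluation paths in the source (single batch versus data loader) can be sketched without Laplax as follows. Everything here is illustrative: `evaluate`, the metric dictionary, and the `{"input", "target"}` batch layout are assumptions, and Laplax's `evaluate_metrics_on_dataset` and `evaluate_metrics_on_generator` have their own signatures. What the sketch preserves is the behavior visible above: a single batch is passed through `reduce`, while batches from an iterable are concatenated.

```python
import jax.numpy as jnp


def evaluate(pred_fn, data, metrics, reduce=lambda x: x):
    """Compute metrics on one batch (a dict) or on an iterable of batches."""

    def on_batch(batch):
        pred = pred_fn(batch["input"])
        return {name: fn(pred, batch["target"]) for name, fn in metrics.items()}

    if isinstance(data, dict):
        # Single batch: apply the user-supplied reduction.
        return {name: reduce(value) for name, value in on_batch(data).items()}

    # Iterable of batches: per-sample results are concatenated, not reduced.
    per_batch = [on_batch(batch) for batch in data]
    return {name: jnp.concatenate([r[name] for r in per_batch]) for name in metrics}


pred_fn = lambda x: 2.0 * x  # hypothetical predictive function
metrics = {"sq_err": lambda pred, target: (pred - target) ** 2}

batch = {"input": jnp.array([1.0, 2.0]), "target": jnp.array([2.0, 5.0])}
loader = [batch, {"input": jnp.array([3.0]), "target": jnp.array([5.0])}]

single = evaluate(pred_fn, batch, metrics, reduce=jnp.mean)
streamed = evaluate(pred_fn, loader, metrics)
```

When per-sample values come back from a loader, any averaging has to be applied afterwards by the caller.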