Curvature Module

Curvatures

Currently supported curvature-vector products are:

  • GGN-mv (Generalized Gauss-Newton):

    \[ v \mapsto \sum_{n=1}^{N} \mathcal{J}_\theta^\top(f_{\theta^*}(x_n)) \nabla^2_{f_{\theta^*}(x_n),f_{\theta^*}(x_n)} \ell(f_\theta(x_n), y_n) \mathcal{J}_\theta(f_{\theta^*}(x_n))\, v \]
  • Hessian-mv (Hessian):

    \[ v \mapsto \sum_{n=1}^{N} \nabla_{\theta \theta}^2 \ell(f_\theta(x_n), y_n)\,v \]
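
As an illustration, both products can be sketched in plain JAX without forming any matrices. The toy model, the squared-error loss, and the names `model_fn`, `loss_fn`, `ggn_mv`, and `hessian_mv` are assumptions made for this example and are not part of laplax; the library's own implementation may differ.

```python
import jax
import jax.numpy as jnp


def model_fn(theta, x):
    # Hypothetical toy model: 3 inputs -> 2 outputs, parameters stored flat.
    return jnp.tanh(x @ theta.reshape(3, 2))


def loss_fn(f, y):
    # Hypothetical loss: squared error on the model output.
    return 0.5 * jnp.sum((f - y) ** 2)


def ggn_mv(theta, xs, ys, v):
    """Sum over the data of J^T (loss Hessian w.r.t. output) J v."""

    def per_sample(x, y):
        f = lambda p: model_fn(p, x)
        out, jv = jax.jvp(f, (theta,), (v,))  # J v
        grad_out = lambda o: jax.grad(loss_fn)(o, y)
        _, hjv = jax.jvp(grad_out, (out,), (jv,))  # loss Hessian times J v
        _, vjp_fn = jax.vjp(f, theta)
        return vjp_fn(hjv)[0]  # J^T applied to the result

    return jnp.sum(jax.vmap(per_sample)(xs, ys), axis=0)


def hessian_mv(theta, xs, ys, v):
    """Hessian-vector product of the summed loss (forward-over-reverse)."""

    def total_loss(p):
        preds = jax.vmap(lambda x: model_fn(p, x))(xs)
        return jnp.sum(jax.vmap(loss_fn)(preds, ys))

    return jax.jvp(jax.grad(total_loss), (theta,), (v,))[1]


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
theta = jax.random.normal(k1, (6,))
xs = jax.random.normal(k2, (5, 3))
ys = jax.random.normal(k3, (5, 2))
v = jax.random.normal(k4, (6,))

ggn_v = ggn_mv(theta, xs, ys, v)
hess_v = hessian_mv(theta, xs, ys, v)
```

For a nonlinear model the two results generally differ, because the GGN drops the term involving second derivatives of the model output with respect to the parameters.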

Curvature estimators/approximations

Both curvature-vector products can be approximated and turned into a weight-space covariance matrix-vector product by any of the following methods:

  • CurvApprox.FULL materializes the curvature-vector product as a dense matrix. The posterior function is then given by

    \[ (\tau, \mathcal{C}) \mapsto \left[ v \mapsto \left(\textbf{Curv}(\mathcal{C}) + \tau I \right)^{-1} v \right]. \]
  • CurvApprox.DIAGONAL approximates the curvature using only its diagonal, obtained by evaluating the curvature-vector product with standard basis vectors from both sides. This leads to:

    \[ (\tau, \mathcal{C}) \mapsto \left[ v \mapsto \left(\text{diag}(\textbf{Curv}(\mathcal{C})) + \tau I \right)^{-1} v \right]. \]
  • Low-Rank employs either a custom Lanczos routine (CurvApprox.LANCZOS) or a variant of the LOBPCG algorithm (CurvApprox.LOBPCG). These methods approximate the top eigenvectors \(U\) and eigenvalues \(S\) of the curvature using only matrix-vector products. The posterior is then given by the inverse of a low-rank term plus a scaled identity:

    \[ (\tau, \mathcal{C}) \mapsto \left[ v \mapsto \left(\big[U S U^\top\big](\mathcal{C}) + \tau I \right)^{-1} v \right]. \]
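
The three formulas above can be sketched for a small, explicit curvature as follows. This is a minimal illustration, not laplax's implementation: the matrix `curv`, the value of `tau`, and all function names are assumptions, and a dense `jnp.linalg.eigh` stands in for the Lanczos/LOBPCG routines purely to obtain \(U\) and \(S\).

```python
import jax
import jax.numpy as jnp

dim, rank, tau = 6, 3, 0.5
A = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
curv = A @ A.T  # hypothetical PSD curvature, accessed only through mv below
mv = lambda v: curv @ v
eye = jnp.eye(dim)

# FULL: apply mv to every standard basis vector to build the dense matrix.
dense = jax.vmap(mv)(eye).T


def full_cov_mv(v):
    return jnp.linalg.solve(dense + tau * eye, v)


# DIAGONAL: e_i^T Curv e_i for each basis vector e_i.
diag = jax.vmap(lambda e: e @ mv(e))(eye)


def diag_cov_mv(v):
    return v / (diag + tau)


# LOW-RANK: keep the top eigenpairs (dense eigh used here as a stand-in).
eigvals, eigvecs = jnp.linalg.eigh(dense)  # eigenvalues in ascending order
S, U = eigvals[-rank:], eigvecs[:, -rank:]


def low_rank_cov_mv(v):
    # (U S U^T + tau I)^{-1} v for orthonormal U, via the Woodbury identity.
    return (v - U @ ((S / (S + tau)) * (U.T @ v))) / tau


v = jnp.arange(1.0, dim + 1.0)
full_out, diag_out, low_rank_out = full_cov_mv(v), diag_cov_mv(v), low_rank_cov_mv(v)
```

The low-rank variant never forms a `dim × dim` matrix at evaluation time; each product costs two multiplications with the `dim × rank` factor `U`.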

Main computational scaffold

This pipeline is controlled via the following three functions:

  • laplax.curv.estimate_curvature: Estimates the curvature from the given curvature-vector product, using the requested approximation type.

  • laplax.curv.set_posterior_fn: Takes an estimated curvature and returns a function that maps prior_arguments to the posterior.

  • laplax.curv.create_posterior_fn: Combines estimate_curvature and set_posterior_fn in a single call.
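
The split into an estimation step and a posterior step can be mimicked with a small stand-in, shown below. Everything prefixed with `toy_` is hypothetical, and the `"prior_prec"` key used for `prior_arguments` is an assumption of this sketch; it only illustrates why the expensive estimate is computed once while the prior can be varied cheaply afterwards.

```python
import jax
import jax.numpy as jnp


def toy_estimate_curvature(mv, dim):
    """Expensive step: densify the matrix-vector product once."""
    return jax.vmap(mv)(jnp.eye(dim)).T


def toy_set_posterior_fn(curv_estimate):
    """Cheap step: close over the estimate and defer the prior to call time."""

    def posterior_fn(prior_arguments):
        tau = prior_arguments["prior_prec"]  # assumed key name
        precision = curv_estimate + tau * jnp.eye(curv_estimate.shape[0])
        cov = jnp.linalg.inv(precision)
        return {"state": precision, "cov_mv": lambda v: cov @ v}

    return posterior_fn


def toy_create_posterior_fn(mv, dim):
    """Combine both steps, mirroring the structure described above."""
    return toy_set_posterior_fn(toy_estimate_curvature(mv, dim))


curv = jnp.array([[2.0, 0.5], [0.5, 1.0]])
posterior_fn = toy_create_posterior_fn(lambda v: curv @ v, dim=2)

# Re-evaluate the posterior for several prior precisions without
# re-estimating the curvature.
v = jnp.array([1.0, -1.0])
results = {
    tau: posterior_fn({"prior_prec": tau})["cov_mv"](v) for tau in (0.1, 1.0, 10.0)
}
```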

laplax.curv.estimate_curvature

Estimate the curvature based on the provided type.

Parameters:

  • curv_type (CurvApprox | str, required): Type of curvature approximation (CurvApprox.FULL, CurvApprox.DIAGONAL, CurvApprox.LANCZOS, CurvApprox.LOBPCG) or the corresponding string ('full', 'diagonal', 'lanczos', 'lobpcg').

  • mv (CurvatureMV, required): Function representing the curvature-vector product.

  • layout (Layout | None, default None): Defines the input layout of the matrix-vector products. If None or an integer, no flattening/unflattening is used.

  • **kwargs (Kwargs, default {}): Additional keyword arguments passed to the curvature estimation function.

Returns:

  • PyTree: The estimated curvature.

Source code in laplax/curv/cov.py
def estimate_curvature(
    curv_type: CurvApprox | str,
    mv: CurvatureMV,
    layout: Layout | None = None,
    **kwargs: Kwargs,
) -> PyTree:
    """Estimate the curvature based on the provided type.

    Args:
        curv_type: Type of curvature approximation (`CurvApprox.FULL`,
            `CurvApprox.DIAGONAL`, `CurvApprox.LANCZOS`, `CurvApprox.LOBPCG`) or
            corresponding string (`'full'`, `'diagonal'`, `'lanczos'`, `'lobpcg'`).
        mv: Function representing the curvature-vector product.
        layout: Defines the input layout of the matrix-vector products. If None or an
            integer, no flattening/unflattening is used.
        **kwargs: Additional keyword arguments passed to the curvature estimation
            function.

    Returns:
        The estimated curvature.
    """
    curv_estimate = CURVATURE_METHODS[curv_type](mv, layout=layout, **kwargs)

    # Force evaluation: wait for asynchronously dispatched arrays to finish
    curv_estimate = jax.tree.map(
        lambda x: x.block_until_ready() if isinstance(x, jax.Array) else x,
        curv_estimate,
    )

    return curv_estimate
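
JAX dispatches computations asynchronously, so an array returned from a function may still be computing. The final step of estimate_curvature waits on every array leaf of the estimate. A minimal sketch of that pattern follows; the contents of `curv_estimate` here are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical estimate: a pytree mixing arrays with a non-array leaf.
curv_estimate = {
    "U": jnp.ones((4, 2)) @ jnp.ones((2, 2)),
    "S": jnp.arange(2.0),
    "rank": 2,
}

# Wait on array leaves; pass every other leaf through unchanged.
ready = jax.tree.map(
    lambda x: x.block_until_ready() if isinstance(x, jax.Array) else x,
    curv_estimate,
)
```

Blocking here means that the cost of the estimate, and any error it raises, shows up inside estimate_curvature rather than at the first later use of the result.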

laplax.curv.set_posterior_fn

Set the posterior function based on the curvature estimate.

Parameters:

  • curv_type (CurvatureKeyType, required): Type of curvature approximation (CurvApprox.FULL, CurvApprox.DIAGONAL, CurvApprox.LANCZOS, CurvApprox.LOBPCG) or the corresponding string ('full', 'diagonal', 'lanczos', 'lobpcg').

  • curv_estimate (PyTree, required): Estimated curvature.

  • layout (Layout, required): Defines the input/output layout of the corresponding curvature-vector products. If None or an integer, no flattening/unflattening is used.

  • **kwargs (Kwargs, default {}): Additional keyword arguments (unused).

Returns:

  • Callable: A function that computes the posterior state.

Raises:

  • ValueError: When layout is neither an integer, a PyTree, nor None.

Source code in laplax/curv/cov.py
def set_posterior_fn(
    curv_type: CurvatureKeyType,
    curv_estimate: PyTree,
    *,
    layout: Layout,
    **kwargs: Kwargs,
) -> Callable:
    """Set the posterior function based on the curvature estimate.

    Args:
        curv_type: Type of curvature approximation (`CurvApprox.FULL`,
            `CurvApprox.DIAGONAL`, `CurvApprox.LANCZOS`, `CurvApprox.LOBPCG`) or
            corresponding string (`'full'`, `'diagonal'`, `'lanczos'`, `'lobpcg'`).
        curv_estimate: Estimated curvature.
        layout: Defines the input/output layout of the corresponding curvature-vector
            products. If `None` or an integer, no flattening/unflattening is used.
        **kwargs: Additional keyword arguments (unused).

    Returns:
        A function that computes the posterior state.

    Raises:
        ValueError: When layout is neither an integer, a PyTree, nor None.
    """
    del kwargs
    if layout is not None and not isinstance(layout, int | PyTree):
        msg = "Layout must be an integer, PyTree or None."
        raise ValueError(msg)

    # Create functions for flattening and unflattening if required
    if layout is None or isinstance(layout, int):
        flatten = unflatten = None
    else:
        # Use custom flatten/unflatten functions for complex pytrees
        flatten, unflatten = create_pytree_flattener(layout)

    def posterior_fn(
        prior_arguments: PriorArguments,
        loss_scaling_factor: Float = 1.0,
    ) -> PosteriorState:
        """Compute the posterior state.

        Args:
            prior_arguments: Prior arguments for the posterior.
            loss_scaling_factor: Factor by which the user-provided loss function is
                scaled. Defaults to 1.0.

        Returns:
            PosteriorState: Dictionary containing:

                - 'state': Updated state of the posterior.
                - 'cov_mv': Function to compute covariance matrix-vector product.
                - 'scale_mv': Function to compute scale matrix-vector product.
        """
        # Calculate posterior precision.
        precision = CURVATURE_PRECISION_METHODS[curv_type](
            curv_estimate=curv_estimate,
            prior_arguments=prior_arguments,
            loss_scaling_factor=loss_scaling_factor,
        )

        # Calculate posterior state
        state = CURVATURE_TO_POSTERIOR_STATE[curv_type](precision)

        # Extract matrix-vector product
        scale_mv_from_state = CURVATURE_STATE_TO_SCALE[curv_type]
        cov_mv_from_state = CURVATURE_STATE_TO_COV[curv_type]

        return Posterior(
            state=state,
            cov_mv=wrap_factory(cov_mv_from_state, flatten, unflatten),
            scale_mv=wrap_factory(scale_mv_from_state, flatten, unflatten),
        )

    return posterior_fn
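
When layout is a PyTree, the matrix-vector products returned in the posterior are wrapped so that they accept and return pytrees while working on flat vectors internally. The sketch below shows the idea with `jax.flatten_util.ravel_pytree` as a stand-in; it does not reproduce laplax's `create_pytree_flattener` or `wrap_factory`, and `wrap_flat_mv`, `params`, and `cov_diag` are hypothetical names.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter pytree that defines the layout.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
flat, unravel = ravel_pytree(params)


def wrap_flat_mv(flat_mv):
    """Lift a flat-vector product to one that maps pytrees to pytrees."""

    def pytree_mv(tree):
        vec, _ = ravel_pytree(tree)  # pytree -> flat vector
        return unravel(flat_mv(vec))  # flat vector -> pytree

    return pytree_mv


# Hypothetical diagonal covariance acting on the flat parameter vector.
cov_diag = jnp.linspace(0.1, 0.9, flat.size)
cov_mv = wrap_flat_mv(lambda v: cov_diag * v)

out = cov_mv(params)  # same structure and shapes as params
```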

laplax.curv.create_posterior_fn

Factory function to create the posterior function given a curvature type.

This sets up the posterior function, which can then be evaluated with prior_arguments. It computes the specified curvature approximation and encodes the following sequence of computations:

1. `CURVATURE_PRECISION_METHODS`
2. `CURVATURE_TO_POSTERIOR_STATE`
3. `CURVATURE_STATE_TO_SCALE`
4. `CURVATURE_STATE_TO_COV`

All methods are selected from the corresponding dictionary by the curv_type argument. New methods can be registered with laplax.register.register_curvature_method; see the laplax.register module for more details.
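
The dictionary-based dispatch described above can be illustrated with a small, self-contained sketch. The `Toy*` names and `register_toy_method` are hypothetical and do not reflect the signature of laplax.register.register_curvature_method; the scale step is omitted for brevity.

```python
from enum import Enum


class ToyCurvApprox(Enum):
    DIAGONAL = "diagonal"


# One dictionary per pipeline stage, all keyed by the curvature type.
TOY_PRECISION_METHODS = {}
TOY_TO_POSTERIOR_STATE = {}
TOY_STATE_TO_COV = {}


def register_toy_method(name, precision_fn, state_fn, cov_fn):
    """Register the stage functions of one curvature type."""
    key = ToyCurvApprox(name)
    TOY_PRECISION_METHODS[key] = precision_fn
    TOY_TO_POSTERIOR_STATE[key] = state_fn
    TOY_STATE_TO_COV[key] = cov_fn


def toy_posterior(curv_type, curv_estimate, tau):
    """Run the stages in order and return a covariance-vector product."""
    key = ToyCurvApprox(curv_type)  # accepts a member or its string value
    precision = TOY_PRECISION_METHODS[key](curv_estimate, tau)
    state = TOY_TO_POSTERIOR_STATE[key](precision)
    return lambda v: TOY_STATE_TO_COV[key](state, v)


register_toy_method(
    "diagonal",
    precision_fn=lambda diag, tau: [d + tau for d in diag],
    state_fn=lambda prec: [1.0 / p for p in prec],
    cov_fn=lambda state, v: [s * x for s, x in zip(state, v)],
)

cov_mv = toy_posterior("diagonal", [1.0, 3.0], tau=1.0)
result = cov_mv([2.0, 4.0])
```

Because every stage is looked up by the same key, adding a new approximation only requires registering its stage functions; the driver code stays unchanged.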

Parameters:

  • curv_type (CurvApprox | str, required): Type of curvature approximation (CurvApprox.FULL, CurvApprox.DIAGONAL, CurvApprox.LANCZOS, CurvApprox.LOBPCG) or the corresponding string ('full', 'diagonal', 'lanczos', 'lobpcg').

  • mv (CurvatureMV, required): Function representing the curvature-vector product.

  • layout (Layout | None, default None): Defines the layout used for matrix-vector products. If None or an integer, no flattening/unflattening is used.

  • **kwargs (Kwargs, default {}): Additional keyword arguments passed to the curvature estimation function.

Returns:

  • Callable: A posterior function that takes the prior_arguments and returns the posterior_state.

Source code in laplax/curv/cov.py
def create_posterior_fn(
    curv_type: CurvApprox | str,
    mv: CurvatureMV,
    layout: Layout | None = None,
    **kwargs: Kwargs,
) -> Callable:
    """Factory function to create the posterior function given a curvature type.

    This sets up the posterior function, which can then be initiated using
    `prior_arguments` by computing a specified curvature approximation and encoding the
    sequential computational order of:

        1. `CURVATURE_PRECISION_METHODS`
        2. `CURVATURE_TO_POSTERIOR_STATE`
        3. `CURVATURE_STATE_TO_SCALE`
        4. `CURVATURE_STATE_TO_COV`

    All methods are selected from the corresponding dictionary by the `curv_type`
    argument. New methods can be registered using the
    :func:`laplax.register.register_curvature_method` method.
    See the :mod:`laplax.register` module for more details.

    Args:
        curv_type: Type of curvature approximation (`CurvApprox.FULL`,
            `CurvApprox.DIAGONAL`, `CurvApprox.LANCZOS`, `CurvApprox.LOBPCG`) or
            corresponding string (`'full'`, `'diagonal'`, `'lanczos'`, `'lobpcg'`).
        mv: Function representing the curvature.
        layout: Defines the format of the layout for matrix-vector products. If `None`
            or an integer, no flattening/unflattening is used.
        **kwargs: Additional keyword arguments passed to the curvature estimation
            function.

    Returns:
        A posterior function that takes the `prior_arguments` and returns the
            `posterior_state`.
    """
    # Retrieve the curvature estimator based on the provided type
    curv_estimate = estimate_curvature(curv_type, mv=mv, layout=layout, **kwargs)

    # Set posterior fn based on curv_estimate
    posterior_fn = set_posterior_fn(curv_type, curv_estimate, layout=layout)

    return posterior_fn