Curvature Module
Curvatures
Currently supported curvature-vector products are:
- GGN-mv (Generalized Gauss-Newton):
  \[ v \mapsto \sum_{n=1}^{N} \mathcal{J}_\theta^\top(f_{\theta^*}(x_n)) \nabla^2_{f_{\theta^*}(x_n),f_{\theta^*}(x_n)} \ell(f_\theta(x_n), y_n) \mathcal{J}_\theta(f_{\theta^*}(x_n))\, v \]
- Hessian-mv (Hessian):
  \[ v \mapsto \sum_{n=1}^{N} \nabla_{\theta \theta}^2 \ell(f_\theta(x_n), y_n)\,v \]
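As a rough illustration of the two products above, the following JAX sketch builds both from automatic differentiation. It is not laplax's implementation: `model_fn`, `loss_fn`, `ggn_mv`, and `hessian_mv` are hypothetical names, the model is a toy single layer, and the squared-error loss is an assumed stand-in for \(\ell\).

```python
import jax
import jax.numpy as jnp


def model_fn(params, x):
    # Hypothetical toy model: one linear layer followed by tanh.
    return jnp.tanh(params["W"] @ x + params["b"])


def loss_fn(f, y):
    # Squared error, used here as an assumed stand-in for the loss.
    return 0.5 * jnp.sum((f - y) ** 2)


def ggn_mv(params, xs, ys, v):
    """Sum over samples of J^T H_loss J v; v is a PyTree shaped like params."""

    def single(x, y):
        f = lambda p: model_fn(p, x)
        # J v via forward-mode differentiation.
        out, jv = jax.jvp(f, (params,), (v,))
        # Hessian of the loss w.r.t. the model output, applied to J v.
        h_jv = jax.jvp(jax.grad(lambda o: loss_fn(o, y)), (out,), (jv,))[1]
        # J^T (H J v) via reverse-mode differentiation.
        _, vjp_fn = jax.vjp(f, params)
        return vjp_fn(h_jv)[0]

    per_sample = jax.vmap(single)(xs, ys)
    return jax.tree_util.tree_map(lambda a: a.sum(axis=0), per_sample)


def hessian_mv(params, xs, ys, v):
    """Hessian of the summed loss w.r.t. params, applied to v."""

    def total_loss(p):
        preds = jax.vmap(lambda x: model_fn(p, x))(xs)
        return jnp.sum(jax.vmap(loss_fn)(preds, ys))

    # Forward-over-reverse Hessian-vector product.
    return jax.jvp(jax.grad(total_loss), (params,), (v,))[1]


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {"W": jax.random.normal(k1, (2, 3)), "b": jnp.zeros(2)}
xs = jax.random.normal(k2, (5, 3))
ys = jax.random.normal(k3, (5, 2))
v = {"W": jax.random.normal(k4, (2, 3)), "b": jnp.ones(2)}

gv = ggn_mv(params, xs, ys, v)
hv = hessian_mv(params, xs, ys, v)
```

Both results are PyTrees with the same structure as `params`. For a squared-error loss with zero residuals the two products coincide, which can serve as a sanity check.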
Curvature estimators/approximations
For both curvature-vector products, the following methods approximate the curvature and turn it into a weight-space covariance matrix-vector product:
- `CurvApprox.FULL` materializes the curvature-vector product as a dense matrix. The posterior function is then given by
  \[ (\tau, \mathcal{C}) \mapsto \left[ v \mapsto \left(\textbf{Curv}(\mathcal{C}) + \tau I \right)^{-1} v \right]. \]
- `CurvApprox.DIAGONAL` approximates the curvature using only its diagonal, obtained by evaluating the curvature-vector product with standard basis vectors from both sides. This leads to:
  \[ (\tau, \mathcal{C}) \mapsto \left[ v \mapsto \left(\text{diag}(\textbf{Curv}(\mathcal{C})) + \tau I \right)^{-1} v \right]. \]
- Low-Rank employs either a custom Lanczos routine (`CurvApprox.LANCZOS`) or a variant of the LOBPCG algorithm (`CurvApprox.LOBPCG`). These methods approximate the top eigenvectors \(U\) and eigenvalues \(S\) of the curvature via matrix-vector products. The posterior is then given by a low-rank term plus a scaled identity:
  \[ (\tau, \mathcal{C}) \mapsto \left[ v \mapsto \left(\big[U S U^\top\big](\mathcal{C}) + \tau I \right)^{-1} v \right]. \]
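The three posterior maps above can be sketched on a small explicit example. This is an illustration under stated assumptions, not laplax's code: all function names are hypothetical, the curvature acts on flat vectors, and `jnp.linalg.eigh` on the dense matrix stands in for the Lanczos/LOBPCG routines, which would obtain \(U\) and \(S\) from matrix-vector products alone.

```python
import jax
import jax.numpy as jnp


def to_dense(mv, size):
    # Apply mv to every standard basis vector; the results form the columns.
    return jax.vmap(mv, in_axes=0, out_axes=1)(jnp.eye(size))


def full_posterior_mv(curv, tau):
    # v -> (Curv + tau I)^{-1} v using a dense solve.
    prec = curv + tau * jnp.eye(curv.shape[0])
    return lambda v: jnp.linalg.solve(prec, v)


def diag_posterior_mv(mv, size, tau):
    # e_i^T Curv e_i for each basis vector gives the diagonal.
    diag = jax.vmap(lambda e: e @ mv(e))(jnp.eye(size))
    return lambda v: v / (diag + tau)


def low_rank_posterior_mv(U, S, tau):
    # (U S U^T + tau I)^{-1} v via the Woodbury identity, assuming that
    # U has orthonormal columns:
    #   v / tau - U diag(S / (tau (S + tau))) U^T v
    coeff = S / (tau * (S + tau))
    return lambda v: v / tau - U @ (coeff * (U.T @ v))


A = jnp.array([[4.0, 1.0, 0.0], [1.0, 3.0, 0.0], [0.0, 0.0, 0.5]])
mv = lambda v: A @ v
tau = 1.0
v = jnp.array([1.0, 2.0, 3.0])

curv = to_dense(mv, 3)
v_full = full_posterior_mv(curv, tau)(v)
v_diag = diag_posterior_mv(mv, 3, tau)(v)

# Stand-in for Lanczos/LOBPCG: take the top-2 eigenpairs of the dense matrix.
evals, evecs = jnp.linalg.eigh(curv)  # eigenvalues in ascending order
v_low = low_rank_posterior_mv(evecs[:, -2:], evals[-2:], tau)(v)
```

When all eigenpairs are kept, the low-rank map agrees with the full solve; with fewer eigenpairs, directions outside the span of \(U\) are scaled by \(1/\tau\) only.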
Main computational scaffold
This pipeline is controlled via the following three functions:
- `laplax.curv.estimate_curvature`: Estimates the curvature based on the provided type and curvature-vector product.
- `laplax.curv.set_posterior_fn`: Takes an estimated curvature and returns a function that maps `prior_arguments` to the posterior.
- `laplax.curv.create_posterior_fn`: Combines `estimate_curvature` and `set_posterior_fn`.
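The division of labour between the three functions can be mimicked with toy stand-ins. The sketch below does not call laplax: the `toy_*` functions, the diagonal approximation, and the `"prior_prec"` key are assumptions chosen for illustration. It shows why the split is convenient: the curvature is estimated once and can then be reused for several prior settings.

```python
import jax
import jax.numpy as jnp


def toy_estimate_curvature(mv, size):
    # Stand-in for the estimation step: extract the curvature diagonal.
    return jax.vmap(lambda e: e @ mv(e))(jnp.eye(size))


def toy_set_posterior_fn(curv_estimate):
    # Returns a function that maps prior arguments to a covariance matvec.
    def posterior_fn(prior_arguments):
        tau = prior_arguments["prior_prec"]  # hypothetical key name
        return lambda v: v / (curv_estimate + tau)

    return posterior_fn


def toy_create_posterior_fn(mv, size):
    # Composition of the two steps above.
    return toy_set_posterior_fn(toy_estimate_curvature(mv, size))


A = jnp.diag(jnp.array([1.0, 3.0]))
posterior_fn = toy_create_posterior_fn(lambda v: A @ v, 2)

# The same curvature estimate serves different prior precisions.
cov_a = posterior_fn({"prior_prec": 1.0})(jnp.ones(2))
cov_b = posterior_fn({"prior_prec": 9.0})(jnp.ones(2))
```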
laplax.curv.estimate_curvature
Estimate the curvature based on the provided type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`curv_type` | `CurvApprox \| str` | Type of curvature approximation (…) | required |
`mv` | `CurvatureMV` | Function representing the curvature-vector product. | required |
`layout` | `Layout \| None` | Defines the input layer format of the matrix-vector products. If `None` or an integer, no flattening/unflattening is used. | `None` |
`**kwargs` | `Kwargs` | Additional keyword arguments passed to the curvature estimation function. | `{}` |
Returns:
Type | Description |
---|---|
`PyTree` | The estimated curvature. |
Source code in laplax/curv/cov.py
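The flattening/unflattening mentioned for `layout` can be pictured with `jax.flatten_util.ravel_pytree`. This is one possible reading of a PyTree layout, not laplax's implementation; `wrap_pytree_mv` and the toy `mv_tree` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def wrap_pytree_mv(mv_tree, layout):
    """Turn a PyTree -> PyTree matvec into a flat-vector matvec."""
    # unravel maps a flat vector back to the PyTree structure of `layout`.
    _, unravel = ravel_pytree(layout)

    def mv_flat(v_flat):
        out_tree = mv_tree(unravel(v_flat))
        return ravel_pytree(out_tree)[0]

    return mv_flat


params = {"W": jnp.ones((2, 2)), "b": jnp.zeros(2)}
# Toy curvature: scales every leaf by 2.
mv_tree = lambda tree: jax.tree_util.tree_map(lambda leaf: 2.0 * leaf, tree)

mv_flat = wrap_pytree_mv(mv_tree, params)
out = mv_flat(jnp.arange(6.0))
```

With a wrapper like this, dense or eigenvalue-based estimators can operate on flat vectors even when the model parameters are stored as a PyTree.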
laplax.curv.set_posterior_fn
Set the posterior function based on the curvature estimate.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`curv_type` | `CurvatureKeyType` | Type of curvature approximation (…) | required |
`curv_estimate` | `PyTree` | Estimated curvature. | required |
`layout` | `Layout` | Defines the input/output layout of the corresponding curvature-vector products. If … | required |
`**kwargs` | `Kwargs` | Additional keyword arguments (unused). | `{}` |
Returns:
Type | Description |
---|---|
`Callable` | A function that computes the posterior state. |
Raises:
Type | Description |
---|---|
`ValueError` | When `layout` is neither an integer, a PyTree, nor `None`. |
Source code in laplax/curv/cov.py
laplax.curv.create_posterior_fn
Factory function to create the posterior function given a curvature type.
This computes the specified curvature approximation and sets up the posterior function, which can then be initialized using `prior_arguments`. It encodes the following sequential order of computation:
1. `CURVATURE_PRIOR_METHODS`
2. `CURVATURE_TO_POSTERIOR_STATE`
3. `CURVATURE_STATE_TO_SCALE`
4. `CURVATURE_STATE_TO_COV`
All methods are selected from the corresponding dictionary by the `curv_type` argument. New methods can be registered using `laplax.register.register_curvature_method`. See the `laplax.register` module for more details.
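The dictionary-based dispatch can be sketched with a toy registry. Only the four dictionary names in the list come from the text; the registries below, the `register` helper, and the role assigned to each stage (add the prior, build a state, derive scale and covariance matvecs) are assumptions made for illustration and do not reflect the actual signatures in `laplax.register`.

```python
import jax.numpy as jnp

# Hypothetical registries, one per stage, keyed by a curvature-type string.
PRIOR_METHODS = {}
TO_POSTERIOR_STATE = {}
STATE_TO_SCALE = {}
STATE_TO_COV = {}


def register(curv_type, prior, to_state, to_scale, to_cov):
    # Store one callable per stage under the same key.
    PRIOR_METHODS[curv_type] = prior
    TO_POSTERIOR_STATE[curv_type] = to_state
    STATE_TO_SCALE[curv_type] = to_scale
    STATE_TO_COV[curv_type] = to_cov


register(
    "toy_diagonal",
    # Stage 1 (assumed): add the prior precision to the curvature estimate.
    prior=lambda curv, prior_prec: curv + prior_prec,
    # Stage 2 (assumed): turn the precision into a reusable posterior state.
    to_state=lambda prec: {"scale": 1.0 / jnp.sqrt(prec)},
    # Stage 3 (assumed): matvec with a square root of the covariance.
    to_scale=lambda state: (lambda v: state["scale"] * v),
    # Stage 4 (assumed): matvec with the covariance itself.
    to_cov=lambda state: (lambda v: state["scale"] ** 2 * v),
)


def make_posterior(curv_type, curv_estimate, prior_prec):
    # Run the four stages in order, each looked up by curv_type.
    prec = PRIOR_METHODS[curv_type](curv_estimate, prior_prec)
    state = TO_POSTERIOR_STATE[curv_type](prec)
    return {
        "state": state,
        "scale_mv": STATE_TO_SCALE[curv_type](state),
        "cov_mv": STATE_TO_COV[curv_type](state),
    }


post = make_posterior("toy_diagonal", jnp.array([3.0, 0.0]), prior_prec=1.0)
cov = post["cov_mv"](jnp.ones(2))
```

Keeping each stage in its own dictionary means a new approximation only has to supply four small callables under a new key, without touching the dispatch code.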
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`curv_type` | `CurvApprox \| str` | Type of curvature approximation (…) | required |
`mv` | `CurvatureMV` | Function representing the curvature. | required |
`layout` | `Layout \| None` | Defines the format of the layout for matrix-vector products. If … | `None` |
`**kwargs` | `Kwargs` | Additional keyword arguments passed to the curvature estimation function. | `{}` |
Returns:
Type | Description |
---|---|
`Callable` | A posterior function that takes the … |