laplax.curv.cov
Posterior covariance functions for various curvature estimates.
estimate_curvature
estimate_curvature(curv_type: CurvApprox | str, mv: CurvatureMV, layout: Layout | None = None, **kwargs: Kwargs) -> PyTree
Estimate the curvature based on the provided type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`curv_type` | `CurvApprox \| str` | Type of curvature approximation ( | required |
`mv` | `CurvatureMV` | Function representing the curvature-vector product. | required |
`layout` | `Layout \| None` | Defines the input layer format of the matrix-vector products. If None or an integer, no flattening/unflattening is used. | `None` |
`**kwargs` | `Kwargs` | Additional keyword arguments passed to the curvature estimation function. | `{}` |
Returns:
Type | Description |
---|---|
`PyTree` | The estimated curvature. |
Source code in laplax/curv/cov.py
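To illustrate what a "curvature-vector product" and a full curvature estimate can look like, here is a minimal JAX sketch. It does not call laplax: the toy loss, the matrix `A`, and the helper `dense_from_mv` are hypothetical names introduced only for this example, and the assumption is that a full approximation amounts to materializing the matrix behind `mv`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: a quadratic loss whose Hessian is the matrix A.
A = jnp.array([[2.0, 0.5], [0.5, 1.0]])
params = jnp.zeros(2)


def loss(p):
    return 0.5 * p @ A @ p


def mv(vec):
    """Curvature-vector product: Hessian of `loss` at `params` applied to `vec`."""
    # Forward-over-reverse differentiation; the Hessian is never built here.
    return jax.jvp(jax.grad(loss), (params,), (vec,))[1]


def dense_from_mv(mv_fn, dim):
    """Materialize the matrix behind `mv_fn` by applying it to basis vectors."""
    # Row i of the vmapped result is H @ e_i, so transpose to get columns.
    return jax.vmap(mv_fn)(jnp.eye(dim)).T


full_curvature = dense_from_mv(mv, 2)
```

For this quadratic loss the materialized matrix coincides with `A`. A dense estimate needs one matrix-vector product per parameter, which is one reason cheaper approximations are used for larger models.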
set_posterior_fn
set_posterior_fn(curv_type: CurvatureKeyType, curv_estimate: PyTree, *, layout: Layout, **kwargs: Kwargs) -> Callable
Set the posterior function based on the curvature estimate.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`curv_type` | `CurvatureKeyType` | Type of curvature approximation ( | required |
`curv_estimate` | `PyTree` | Estimated curvature. | required |
`layout` | `Layout` | Defines the input/output layout of the corresponding curvature-vector products. If | required |
`**kwargs` | `Kwargs` | Additional keyword arguments (unused). | `{}` |
Returns:
Type | Description |
---|---|
`Callable` | A function that computes the posterior state. |
Raises:
Type | Description |
---|---|
`ValueError` | When layout is neither an integer, a PyTree, nor None. |
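As a rough sketch of the idea of "a function that computes the posterior state", the following assumes the common Laplace construction in which the posterior precision is the curvature plus an isotropic prior precision. The names `toy_set_posterior_fn`, `prior_prec`, and the keys of the returned dictionary are hypothetical and are not taken from laplax.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def toy_set_posterior_fn(curv_estimate):
    """Return a function mapping prior arguments to a toy posterior state."""

    def posterior_fn(prior_arguments):
        dim = curv_estimate.shape[0]
        # Assumption: precision = curvature + prior_prec * identity.
        precision = curv_estimate + prior_arguments["prior_prec"] * jnp.eye(dim)
        # Lower-triangular Cholesky factor of the precision matrix.
        chol = jnp.linalg.cholesky(precision)

        def cov_mv(vec):
            # Covariance-vector product: solve precision @ x = vec.
            return cho_solve((chol, True), vec)

        return {"precision": precision, "cov_mv": cov_mv}

    return posterior_fn


curv = jnp.array([[2.0, 0.5], [0.5, 1.0]])
state = toy_set_posterior_fn(curv)({"prior_prec": 1.0})
vec = jnp.array([1.0, -2.0])
cov_times_vec = state["cov_mv"](vec)
```

Returning a closure over the curvature estimate means the expensive estimate is computed once, while the prior arguments can be varied cheaply afterwards.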
create_posterior_fn
create_posterior_fn(curv_type: CurvApprox | str, mv: CurvatureMV, layout: Layout | None = None, **kwargs: Kwargs) -> Callable
Factory function to create the posterior function given a curvature type.
This sets up the posterior function, which can then be initialized using `prior_arguments`, by computing a specified curvature approximation and encoding the sequential computational order of:
1. `CURVATURE_PRIOR_METHODS`
2. `CURVATURE_TO_POSTERIOR_STATE`
3. `CURVATURE_STATE_TO_SCALE`
4. `CURVATURE_STATE_TO_COV`
All methods are selected from the corresponding dictionary by the `curv_type` argument. New methods can be registered using the `laplax.register.register_curvature_method` function. See the `laplax.register` module for more details.
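The dictionary-based dispatch described above can be sketched in plain Python. This is an illustration of the pattern only: the `toy_*` registries, `register_toy_method`, the `"toy_diagonal"` key, and the `prior_prec` argument are hypothetical stand-ins, and the real dictionaries and registration function in `laplax.register` may have different signatures.

```python
import math

# Hypothetical stand-ins for the four registries, keyed by curvature type.
toy_prior_methods = {}
toy_to_posterior_state = {}
toy_state_to_scale = {}
toy_state_to_cov = {}


def register_toy_method(name, prior, to_state, to_scale, to_cov):
    """Register one curvature type in all four toy registries."""
    toy_prior_methods[name] = prior
    toy_to_posterior_state[name] = to_state
    toy_state_to_scale[name] = to_scale
    toy_state_to_cov[name] = to_cov


# A toy diagonal curvature, stored as a list of diagonal entries.
def diag_prior(curv, prior_prec):
    return [c + prior_prec for c in curv]


def diag_to_state(precision):
    return {"precision": precision}


def diag_to_scale(state):
    return lambda v: [x / math.sqrt(p) for x, p in zip(v, state["precision"])]


def diag_to_cov(state):
    return lambda v: [x / p for x, p in zip(v, state["precision"])]


register_toy_method(
    "toy_diagonal", diag_prior, diag_to_state, diag_to_scale, diag_to_cov
)


def toy_create_posterior_fn(curv_type, curv_estimate):
    """Chain the four registered steps in the order listed above."""

    def posterior_fn(prior_arguments):
        precision = toy_prior_methods[curv_type](curv_estimate, **prior_arguments)
        state = toy_to_posterior_state[curv_type](precision)
        return {
            "state": state,
            "scale_mv": toy_state_to_scale[curv_type](state),
            "cov_mv": toy_state_to_cov[curv_type](state),
        }

    return posterior_fn


posterior = toy_create_posterior_fn("toy_diagonal", [3.0, 1.0])({"prior_prec": 1.0})
cov_result = posterior["cov_mv"]([1.0, 1.0])
```

Keeping each step in its own registry lets a new curvature type plug into the same pipeline by registering four callables, without changing the factory itself.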
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`curv_type` | `CurvApprox \| str` | Type of curvature approximation ( | required |
`mv` | `CurvatureMV` | Function representing the curvature. | required |
`layout` | `Layout \| None` | Defines the format of the layout for matrix-vector products. If | `None` |
`**kwargs` | `Kwargs` | Additional keyword arguments passed to the curvature estimation function. | `{}` |
Returns:
Type | Description |
---|---|
`Callable` | A posterior function that takes the |