laplax.curv.low_rank
Low-rank curvature approximation.
create_low_rank_curvature
create_low_rank_curvature(mv: CurvatureMV, layout: Layout, low_rank_method: LowRankMethod = LANCZOS, **kwargs: Kwargs) -> LowRankTerms
Generate a low-rank curvature approximation.
The low-rank curvature is computed as an approximation to the full curvature matrix
using the provided matrix-vector product function and either the Lanczos or LOBPCG
algorithm. The low-rank approximation is returned as a LowRankTerms object.
The low-rank approximation is computed as:
\[
\text{low\_rank}(v) = U S U^\top v
\]
where \(U\) are the eigenvectors and \(S\) is the diagonal matrix of eigenvalues. The
LowRankTerms object holds the eigenvectors, the eigenvalues, and a scalar factor. The
latter can be used to express an isotropic Gaussian prior.
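To make the idea concrete, here is a minimal NumPy sketch of the Lanczos route, assuming only a symmetric matrix-vector product on flat vectors. It is not laplax's implementation: the helper lanczos_low_rank and its num_iter and seed parameters are hypothetical, and it uses full (two-pass) reorthogonalization for simplicity. The returned U and S play the roles of the eigenvectors and eigenvalues in \(U S U^\top v\).

```python
import numpy as np


def lanczos_low_rank(mv, dim, rank, num_iter=None, seed=0):
    """Approximate the top-`rank` eigenpairs of a symmetric operator.

    Hypothetical helper: plain Lanczos with full reorthogonalization that
    only accesses the matrix through the matrix-vector product `mv`.
    """
    num_iter = num_iter or min(dim, 2 * rank + 10)
    rng = np.random.default_rng(seed)
    Q = np.zeros((dim, num_iter))  # Lanczos basis vectors (columns)
    alpha = np.zeros(num_iter)     # diagonal of the tridiagonal matrix T
    beta = np.zeros(num_iter)      # off-diagonal of T
    q = rng.standard_normal(dim)
    q /= np.linalg.norm(q)
    k = num_iter
    for j in range(num_iter):
        Q[:, j] = q
        w = mv(q)
        alpha[j] = q @ w
        # Orthogonalize against all previous basis vectors ("twice is enough").
        for _ in range(2):
            w = w - Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        beta[j] = np.linalg.norm(w)
        if beta[j] < 1e-10:  # Krylov subspace became invariant; stop early
            k = j + 1
            break
        q = w / beta[j]
    T = np.diag(alpha[:k]) + np.diag(beta[: k - 1], 1) + np.diag(beta[: k - 1], -1)
    evals, evecs = np.linalg.eigh(T)
    top = np.argsort(evals)[::-1][:rank]  # indices of the largest Ritz values
    U = Q[:, :k] @ evecs[:, top]          # Ritz vectors, shape (dim, rank)
    S = evals[top]                        # Ritz values, shape (rank,)
    return U, S


# Usage: a 20x20 symmetric matrix with known eigenvalues 1..20.
rng = np.random.default_rng(1)
V, _ = np.linalg.qr(rng.standard_normal((20, 20)))
lam = np.arange(1.0, 21.0)
A = V @ np.diag(lam) @ V.T
U, S = lanczos_low_rank(lambda v: A @ v, dim=20, rank=3, num_iter=20)
```

With num_iter equal to the dimension the Krylov space is complete, so the Ritz values match the exact top eigenvalues; for real networks num_iter would be far smaller than the parameter count and the result is an approximation.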
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mv | CurvatureMV | Matrix-vector product function representing the curvature. | required |
layout | Layout | Structure defining the parameter layout that is assumed by the matrix-vector product function. | required |
low_rank_method | LowRankMethod | Method to use for computing the low-rank approximation: either the Lanczos or the LOBPCG algorithm. | LANCZOS |
**kwargs | Kwargs | Additional arguments passed to the low-rank method. | {} |
Returns:
Type | Description |
---|---|
LowRankTerms | A LowRankTerms object representing the low-rank curvature approximation. |
create_low_rank_mv
create_low_rank_mv(low_rank_terms: LowRankTerms) -> Callable[[FlatParams], FlatParams]
Create a low-rank matrix-vector product function.
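The low-rank matrix-vector product is the sum of the input vector scaled by the scalar factor and the low-rank term, i.e. the eigenvectors applied to the eigenvalue-weighted projection of the vector onto the eigenvectors:
\[
\text{mv}(v) = s\,v + U S U^\top v
\]
where \(s\) is the scalar factor, \(U\) are the eigenvectors, and \(S\) is the diagonal matrix of eigenvalues.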
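The formula can be evaluated without ever forming the dense matrix. The NumPy sketch below assumes a hypothetical Terms named tuple as a stand-in for LowRankTerms (the field names U, S, and scalar are assumptions) and a hypothetical make_low_rank_mv helper; it is not laplax's code.

```python
from typing import Callable, NamedTuple

import numpy as np


class Terms(NamedTuple):
    """Hypothetical stand-in for LowRankTerms (field names are assumptions)."""
    U: np.ndarray   # eigenvectors, shape (dim, rank), orthonormal columns
    S: np.ndarray   # eigenvalues, shape (rank,)
    scalar: float   # isotropic term s


def make_low_rank_mv(terms: Terms) -> Callable[[np.ndarray], np.ndarray]:
    """Return v -> s * v + U diag(S) U^T v without forming a dense matrix."""
    U, S, s = terms

    def mv(v: np.ndarray) -> np.ndarray:
        # Project onto U, weight by the eigenvalues, map back: O(dim * rank).
        return s * v + U @ (S * (U.T @ v))

    return mv


rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((50, 4)))
terms = Terms(U=U, S=np.array([9.0, 5.0, 2.0, 1.0]), scalar=0.5)
mv = make_low_rank_mv(terms)
v = rng.standard_normal(50)
dense = terms.scalar * np.eye(50) + U @ np.diag(terms.S) @ U.T  # for checking only
```

Each product costs two thin matrix multiplications with U, which is what makes the representation practical when the number of parameters is large.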
Parameters:
Name | Type | Description | Default |
---|---|---|---|
low_rank_terms | LowRankTerms | Low-rank curvature approximation. | required |
Returns:
Type | Description |
---|---|
Callable[[FlatParams], FlatParams] | A function that computes the low-rank matrix-vector product. |
low_rank_square
low_rank_square(state: LowRankTerms) -> LowRankTerms
Square the low-rank curvature approximation.
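This returns the LowRankTerms that correspond to the squared low-rank
approximation. Since the eigenvectors are orthonormal (\(U^\top U = I\)), the squared
approximation again has the low-rank-plus-scalar form:
\[
(U S U^\top + s I)^2 = s^2 I + U\left((S + sI)^2 - s^2 I\right)U^\top
\]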
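The identity only touches the scalar and the eigenvalues, so squaring is cheap. Below is a minimal NumPy sketch that checks it against the dense square; Terms, square_terms, and to_dense are hypothetical names (Terms stands in for LowRankTerms, with assumed fields U, S, scalar), not laplax's API.

```python
from typing import NamedTuple

import numpy as np


class Terms(NamedTuple):
    """Hypothetical stand-in for LowRankTerms (field names are assumptions)."""
    U: np.ndarray   # orthonormal eigenvectors, shape (dim, rank)
    S: np.ndarray   # eigenvalues, shape (rank,)
    scalar: float   # isotropic term s


def square_terms(terms: Terms) -> Terms:
    """Low-rank terms of (U diag(S) U^T + s I)^2, valid for orthonormal U."""
    U, S, s = terms
    return Terms(U=U, S=(S + s) ** 2 - s**2, scalar=s**2)


def to_dense(terms: Terms) -> np.ndarray:
    """Materialize s I + U diag(S) U^T (only sensible for small checks)."""
    U, S, s = terms
    return s * np.eye(U.shape[0]) + U @ np.diag(S) @ U.T


rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((40, 3)))
terms = Terms(U=U, S=np.array([4.0, 1.0, 0.25]), scalar=2.0)
squared = square_terms(terms)
```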
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | LowRankTerms | Low-rank curvature approximation. | required |
Returns:
Type | Description |
---|---|
LowRankTerms | A LowRankTerms object representing the squared low-rank approximation. |
low_rank_curvature_to_precision
low_rank_curvature_to_precision(curv_estimate: LowRankTerms, prior_arguments: PriorArguments, loss_scaling_factor: Float = 1.0) -> LowRankTerms
Add prior precision to the low-rank curvature estimate.
The prior precision (of an isotropic Gaussian prior) is read from the
prior_arguments dictionary and added to the scalar component of the
LowRankTerms object.
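Because an isotropic prior precision is a multiple of the identity, adding it only shifts the scalar term and leaves the eigenvectors and eigenvalues untouched. The sketch below illustrates this with a hypothetical Terms stand-in for LowRankTerms (assumed fields U, S, scalar) and a hypothetical add_prior_precision helper; it deliberately ignores loss_scaling_factor, whose exact role is not spelled out here.

```python
from typing import NamedTuple

import numpy as np


class Terms(NamedTuple):
    """Hypothetical stand-in for LowRankTerms (field names are assumptions)."""
    U: np.ndarray
    S: np.ndarray
    scalar: float


def add_prior_precision(curv: Terms, prior_arguments: dict) -> Terms:
    """Add an isotropic prior precision to the scalar term (sketch only).

    loss_scaling_factor is intentionally omitted from this sketch.
    """
    prior_prec = prior_arguments["prior_prec"]
    return Terms(U=curv.U, S=curv.S, scalar=curv.scalar + prior_prec)


def to_dense(terms: Terms) -> np.ndarray:
    """Materialize s I + U diag(S) U^T for checking small examples."""
    U, S, s = terms
    return s * np.eye(U.shape[0]) + U @ np.diag(S) @ U.T


rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((30, 2)))
curv = Terms(U=U, S=np.array([7.0, 3.0]), scalar=0.0)
prec = add_prior_precision(curv, {"prior_prec": 2.0})
```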
Parameters:
Name | Type | Description | Default |
---|---|---|---|
curv_estimate | LowRankTerms | Low-rank curvature approximation. | required |
prior_arguments | PriorArguments | Dictionary containing prior precision as 'prior_prec'. | required |
loss_scaling_factor | Float | Factor by which the user-provided loss function is scaled. Defaults to 1.0. | 1.0 |
Returns:
Name | Type | Description |
---|---|---|
LowRankTerms | LowRankTerms | Updated low-rank curvature approximation with added prior precision. |
low_rank_prec_to_posterior_state
low_rank_prec_to_posterior_state(curv_estimate: LowRankTerms) -> dict[str, LowRankTerms]
Convert the low-rank precision representation to a posterior state.
The scalar component and eigenvalues of the low-rank curvature estimate are
transformed to represent the posterior scale, yielding another LowRankTerms
representation. The scale matrix is the inverse square root of the low-rank
precision matrix; it is computed analytically via the Woodbury identity, so the
full matrix is never formed or inverted.
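For orthonormal \(U\), the precision \(P = sI + U S U^\top\) has eigenvalues \(S + s\) on the span of \(U\) and \(s\) on its complement, so \(P^{-1/2} = s^{-1/2} I + U\left((S + sI)^{-1/2} - s^{-1/2} I\right)U^\top\). The NumPy sketch below uses this spectral argument directly rather than spelling out the Woodbury identity. Terms (a stand-in for LowRankTerms with assumed fields U, S, scalar) and precision_to_scale are hypothetical, and the sketch returns the terms themselves rather than the dictionary the real function returns.

```python
from typing import NamedTuple

import numpy as np


class Terms(NamedTuple):
    """Hypothetical stand-in for LowRankTerms (field names are assumptions)."""
    U: np.ndarray
    S: np.ndarray
    scalar: float


def precision_to_scale(prec: Terms) -> Terms:
    """Low-rank terms of P^{-1/2} for P = s I + U diag(S) U^T (orthonormal U)."""
    U, S, s = prec
    s_inv_sqrt = 1.0 / np.sqrt(s)
    # On span(U), P has eigenvalues S + s; on its orthogonal complement P = s I.
    return Terms(U=U, S=1.0 / np.sqrt(S + s) - s_inv_sqrt, scalar=s_inv_sqrt)


def to_dense(terms: Terms) -> np.ndarray:
    """Materialize s I + U diag(S) U^T for checking small examples."""
    U, S, s = terms
    return s * np.eye(U.shape[0]) + U @ np.diag(S) @ U.T


rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((25, 3)))
prec = Terms(U=U, S=np.array([10.0, 4.0, 1.0]), scalar=0.5)
P = to_dense(prec)
scale = to_dense(precision_to_scale(prec))
```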
Parameters:
Name | Type | Description | Default |
---|---|---|---|
curv_estimate | LowRankTerms | Low-rank curvature estimate. | required |
Returns:
Type | Description |
---|---|
dict[str, LowRankTerms] | A dictionary with the posterior state represented as LowRankTerms. |
low_rank_posterior_state_to_scale
low_rank_posterior_state_to_scale(state: dict[str, LowRankTerms]) -> Callable[[FlatParams], FlatParams]
Create a matrix-vector product function for the scale matrix.
The state dictionary, which contains the low-rank representation of the posterior scale, is used to create a function that computes the matrix-vector product with the scale matrix. In the eigenbasis of the precision, the scale matrix is diagonal, with entries equal to the inverse square roots of the precision eigenvalues.
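A typical use of the scale matrix-vector product is drawing posterior samples as \(\theta = \theta_{\text{MAP}} + \text{scale}\,z\) with \(z \sim \mathcal{N}(0, I)\). The NumPy sketch below assumes a hypothetical Terms stand-in for LowRankTerms (fields U, S, scalar are assumptions), a hypothetical scale_mv helper, and a placeholder zero mean; it takes the scale terms directly instead of a state dictionary.

```python
from typing import Callable, NamedTuple

import numpy as np


class Terms(NamedTuple):
    """Hypothetical stand-in for LowRankTerms (field names are assumptions)."""
    U: np.ndarray
    S: np.ndarray
    scalar: float


def scale_mv(scale_terms: Terms) -> Callable[[np.ndarray], np.ndarray]:
    """Return v -> (s I + U diag(S) U^T) v for low-rank scale terms."""
    U, S, s = scale_terms
    return lambda v: s * v + U @ (S * (U.T @ v))


dim = 25
rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((dim, 3)))
S_prec, s_prec = np.array([10.0, 4.0, 1.0]), 0.5
# Scale terms of P^{-1/2} for the precision P = s_prec I + U diag(S_prec) U^T.
scale_terms = Terms(
    U=U,
    S=1.0 / np.sqrt(S_prec + s_prec) - 1.0 / np.sqrt(s_prec),
    scalar=1.0 / np.sqrt(s_prec),
)
apply_scale = scale_mv(scale_terms)

# Posterior samples theta = theta_map + scale @ z with z ~ N(0, I).
theta_map = np.zeros(dim)  # placeholder mean, for illustration only
samples = np.stack([theta_map + apply_scale(z) for z in rng.standard_normal((1000, dim))])
```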
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | dict[str, LowRankTerms] | Dictionary containing the low-rank scale. | required |
Returns:
Type | Description |
---|---|
Callable[[FlatParams], FlatParams] | A function that computes the scale matrix-vector product. |
low_rank_posterior_state_to_cov
low_rank_posterior_state_to_cov(state: dict[str, LowRankTerms]) -> Callable[[FlatParams], FlatParams]
Create a matrix-vector product function for the covariance matrix.
The state dictionary, which contains the low-rank representation of the posterior scale, is used to create a function that computes the matrix-vector product with the covariance matrix. The covariance matrix is the square of the low-rank scale matrix.
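Squaring the scale terms with the identity \((U S U^\top + sI)^2 = s^2 I + U\left((S + sI)^2 - s^2 I\right)U^\top\) yields low-rank terms for the covariance, which should act as \(P^{-1}\). The NumPy sketch below checks this against a dense solve. Terms (a stand-in for LowRankTerms with assumed fields U, S, scalar), square_terms, and cov_mv are hypothetical names, and the scale terms are built inline rather than read from a state dictionary.

```python
from typing import Callable, NamedTuple

import numpy as np


class Terms(NamedTuple):
    """Hypothetical stand-in for LowRankTerms (field names are assumptions)."""
    U: np.ndarray
    S: np.ndarray
    scalar: float


def square_terms(terms: Terms) -> Terms:
    """Low-rank terms of (U diag(S) U^T + s I)^2 for orthonormal U."""
    U, S, s = terms
    return Terms(U=U, S=(S + s) ** 2 - s**2, scalar=s**2)


def cov_mv(scale_terms: Terms) -> Callable[[np.ndarray], np.ndarray]:
    """Covariance matrix-vector product: square the scale, then apply it."""
    U, S, s = square_terms(scale_terms)
    return lambda v: s * v + U @ (S * (U.T @ v))


dim = 25
rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((dim, 3)))
S_prec, s_prec = np.array([10.0, 4.0, 1.0]), 0.5
P = s_prec * np.eye(dim) + U @ np.diag(S_prec) @ U.T  # dense precision, for checking
scale_terms = Terms(
    U=U,
    S=1.0 / np.sqrt(S_prec + s_prec) - 1.0 / np.sqrt(s_prec),
    scalar=1.0 / np.sqrt(s_prec),
)
apply_cov = cov_mv(scale_terms)
v = rng.standard_normal(dim)
```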
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | dict[str, LowRankTerms] | Dictionary containing the low-rank scale. | required |
Returns:
Type | Description |
---|---|
Callable[[FlatParams], FlatParams] | A function that computes the covariance matrix-vector product. |