laplax.curv.diagonal
Diagonal curvature approximation.
create_diagonal_curvature
Generate a diagonal curvature estimate.
The diagonal of the curvature matrix represented by the matrix-vector product function is computed and used as an approximation to the full matrix.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mv | CurvatureMV | Matrix-vector product function representing the curvature. | required |
layout | Layout | Structure defining the parameter layout that is assumed by the matrix-vector product function. | required |
**kwargs | Kwargs | Additional arguments (unused). | {} |
Returns:
Type | Description |
---|---|
FlatParams | A 1D array representing the diagonal curvature. |
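When a matrix is only available through its matrix-vector product, its diagonal can be read off by probing with standard basis vectors: entry i of mv(e_i) is the diagonal element i. The JAX sketch below illustrates this idea only; it is not laplax's implementation, the helper name `diagonal_from_mv` is hypothetical, and the plain integer `num_params` stands in for the `layout` argument, whose actual structure may be richer. It costs one matrix-vector product per parameter, so it is only practical for small models.

```python
import jax
import jax.numpy as jnp


def diagonal_from_mv(mv, num_params):
    """Hypothetical helper: extract the diagonal of a linear operator from its mv.

    Assumption: parameters are a flat vector of length `num_params`
    (a simplification of laplax's `layout` argument).
    """
    basis = jnp.eye(num_params)

    def diag_entry(i):
        # mv(e_i) is column i of the matrix; its i-th entry is A[i, i].
        return mv(basis[i])[i]

    # Evaluate all probes in a vectorized way.
    return jax.vmap(diag_entry)(jnp.arange(num_params))


# Usage: a small symmetric "curvature" matrix exposed only through its mv.
A = jnp.array([[2.0, 0.5, 0.0],
               [0.5, 3.0, 0.1],
               [0.0, 0.1, 4.0]])
curv_diag = diagonal_from_mv(lambda v: A @ v, 3)
```

Here `curv_diag` holds the values 2.0, 3.0 and 4.0, matching `jnp.diag(A)`; the off-diagonal entries are discarded, which is exactly the approximation a diagonal curvature makes.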
diagonal_curvature_to_precision
diagonal_curvature_to_precision(curv_estimate: FlatParams, prior_arguments: PriorArguments, loss_scaling_factor: Float = 1.0) -> FlatParams
Add prior precision to the diagonal curvature estimate.
The prior precision (of an isotropic Gaussian prior) is read from the prior_arguments dictionary and added to the diagonal curvature estimate. The curvature (here: diagonal) is scaled by the \(\sigma^2\) parameter.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
curv_estimate | FlatParams | Diagonal curvature estimate. | required |
prior_arguments | PriorArguments | Dictionary containing prior precision as 'prior_prec'. | required |
loss_scaling_factor | Float | Factor by which the user-provided loss function is scaled. Defaults to 1.0. | 1.0 |
Returns:
Type | Description |
---|---|
FlatParams | Updated diagonal curvature with added prior precision. |
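For an isotropic Gaussian prior with precision prior_prec, the prior precision matrix is prior_prec times the identity, so adding it to a diagonal curvature estimate amounts to adding prior_prec to every entry. The sketch below is illustrative only: the function and argument names are hypothetical, it assumes "scaled by \(\sigma^2\)" means multiplication, and it leaves out loss_scaling_factor because its exact role in the update is not described here.

```python
import jax.numpy as jnp


def diagonal_curv_to_prec_sketch(curv_diag, prior_prec, sigma_squared=1.0):
    """Hypothetical sketch: posterior precision diagonal under an isotropic prior.

    Assumptions: the curvature is multiplied by sigma_squared, and the
    isotropic prior adds prior_prec to each diagonal entry.
    loss_scaling_factor is intentionally not modelled.
    """
    return sigma_squared * curv_diag + prior_prec


# Usage: curvature diagonal [1, 2, 3], sigma^2 = 2, prior precision 0.5.
prec_diag = diagonal_curv_to_prec_sketch(
    jnp.array([1.0, 2.0, 3.0]), prior_prec=0.5, sigma_squared=2.0
)
# prec_diag -> values 2.5, 4.5, 6.5
```

Because both terms are diagonal, the update stays a cheap element-wise operation on a 1D array instead of a dense matrix addition.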
diagonal_prec_to_posterior_state
Convert precision matrix to scale matrix.
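The provided diagonal precision matrix is converted to the corresponding diagonal scale matrix, which is returned as a PosteriorState dictionary. Since the covariance is the inverse of the precision and the scale matrix L satisfies L @ L.T = covariance, the scale matrix is the diagonal matrix whose entries are the inverse square roots of the precision's diagonal elements.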
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prec | FlatParams | Precision matrix to convert. | required |
Returns:
Type | Description |
---|---|
dict[str, FlatParams] | PosteriorState dictionary containing the diagonal scale matrix L, where L @ L.T is the covariance matrix. |
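For diagonal matrices, (L @ L.T)[i, i] = L[i, i]^2 must equal 1 / P[i, i], so each scale entry is the inverse square root of the corresponding precision entry. The JAX sketch below shows this conversion; the function name and the dictionary key "scale" are hypothetical choices for illustration, not necessarily what laplax uses.

```python
import jax.numpy as jnp


def diag_prec_to_state_sketch(prec_diag):
    """Hypothetical sketch: diagonal precision -> dict holding the diagonal scale.

    The key name "scale" is an assumption made for this example.
    """
    # L_ii = 1 / sqrt(P_ii), so that L_ii^2 = 1 / P_ii = Sigma_ii.
    return {"scale": 1.0 / jnp.sqrt(prec_diag)}


# Usage: precision diagonal [4, 0.25, 1].
prec_diag = jnp.array([4.0, 0.25, 1.0])
state = diag_prec_to_state_sketch(prec_diag)
# state["scale"] -> values 0.5, 2.0, 1.0
```

Squaring the stored scale and multiplying by the precision recovers ones, confirming that L @ L.T inverts the precision.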
diagonal_posterior_state_to_scale
diagonal_posterior_state_to_scale(state: dict[str, FlatParams]) -> Callable[[FlatParams], FlatParams]
Create a scale matrix-vector product function.
The diagonal scale matrix is read from the state dictionary and wrapped in a matrix-vector product function that applies it to a vector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | dict[str, FlatParams] | Dictionary containing the diagonal scale matrix. | required |
Returns:
Type | Description |
---|---|
Callable[[FlatParams], FlatParams] | A function that computes the diagonal scale matrix-vector product. |
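Applying a diagonal matrix to a vector is an element-wise product, so the returned function never needs to materialize the matrix. The sketch below builds such a closure; the function name and the "scale" key are hypothetical, and the dense comparison is included only to check the result.

```python
import jax.numpy as jnp


def make_diag_scale_mv(state):
    """Hypothetical sketch: return v -> L @ v for a diagonal scale L.

    Assumes the diagonal of L is stored under the key "scale".
    """
    scale = state["scale"]

    def scale_mv(vec):
        # diag(scale) @ vec computed as an element-wise product.
        return scale * vec

    return scale_mv


# Usage with a hand-written state.
state = {"scale": jnp.array([0.5, 2.0, 1.0])}
scale_mv = make_diag_scale_mv(state)
vec = jnp.array([2.0, 1.0, 3.0])
out = scale_mv(vec)
# out -> values 1.0, 2.0, 3.0, the same as jnp.diag(state["scale"]) @ vec
```

The closure stores only the 1D diagonal, so memory stays linear in the number of parameters.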
diagonal_posterior_state_to_cov
Create a covariance matrix-vector product function.
The diagonal covariance matrix is computed as the product of the diagonal scale matrix with itself (L @ L.T, which for a diagonal L amounts to squaring its entries).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | dict[str, FlatParams] | Dictionary containing the diagonal scale matrix. | required |
Returns:
Type | Description |
---|---|
Callable[[FlatParams], FlatParams] | A function that computes the diagonal covariance matrix-vector product. |