laplax.curv.full
Full curvature approximation.
create_full_curvature
Generate a full curvature approximation.
The curvature is densified and flattened into a 2D array that corresponds to the flattened parameter layout.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mv | CurvatureMV | Matrix-vector product function representing the curvature. | required |
layout | Layout | Structure defining the parameter layout that is assumed by the matrix-vector product function. If | required |
**kwargs | Kwargs | Additional arguments (unused). | {} |
Returns:
Type | Description |
---|---|
Num[Array, 'P P'] | A dense matrix representing the full curvature approximation. |
Source code in laplax/curv/full.py
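To illustrate what densifying a curvature means, the sketch below materializes a P x P matrix from a matrix-vector product by applying it to every standard basis vector. It assumes mv already acts on flat vectors of length P, so the layout-based flattening that create_full_curvature handles is skipped. dense_from_mv is a hypothetical helper, not part of laplax.

```python
import jax
import jax.numpy as jnp


def dense_from_mv(mv, num_params):
    """Hypothetical helper: build the dense matrix of a flat linear map `mv`."""
    identity = jnp.eye(num_params)
    # vmap applies mv to each row e_j of the identity, so row j of the
    # stacked result holds mv(e_j), i.e. column j of the matrix.
    # Transposing puts those columns back in place.
    return jax.vmap(mv)(identity).T


# Usage: a non-symmetric matrix shows that the transpose is needed.
A = jnp.array([[2.0, 1.0], [0.5, 3.0]])
dense = dense_from_mv(lambda v: A @ v, num_params=2)
```

This costs P matrix-vector products and O(P^2) memory for the result, so a full curvature is practical only for models with few parameters.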
full_curvature_to_precision
full_curvature_to_precision(curv_estimate: Num[Array, 'P P'], prior_arguments: PriorArguments, loss_scaling_factor: Float = 1.0) -> Num[Array, 'P P']
Add prior precision to the curvature estimate.
The prior precision (of an isotropic Gaussian prior) is read from the prior_arguments dictionary and added to the curvature estimate. The curvature is scaled by the \(\sigma^2\) parameter.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
curv_estimate | Num[Array, 'P P'] | Full curvature estimate matrix. | required |
prior_arguments | PriorArguments | Dictionary containing the prior precision under the key 'prior_prec'. | required |
loss_scaling_factor | Float | Factor by which the user-provided loss function is scaled. Defaults to 1.0. | 1.0 |
Returns:
Type | Description |
---|---|
Num[Array, 'P P'] | Updated curvature matrix with added prior precision. |
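A minimal sketch of the operation described above, assuming the result takes the form \(\sigma^2 C + \lambda I\) for curvature C and prior precision \(\lambda\). The helper name and its explicit sigma_squared argument are hypothetical. The sketch also leaves out loss_scaling_factor, because this page does not say how it enters the computation.

```python
import jax.numpy as jnp


def curvature_to_precision_sketch(curv, prior_prec, sigma_squared=1.0):
    """Hypothetical helper: scale the curvature and add an isotropic prior precision."""
    num_params = curv.shape[-1]
    # An isotropic Gaussian prior N(0, prior_prec^{-1} I) contributes prior_prec * I.
    prior_term = prior_prec * jnp.eye(num_params, dtype=curv.dtype)
    return sigma_squared * curv + prior_term


curv = jnp.array([[2.0, 0.5], [0.5, 1.0]])
prec = curvature_to_precision_sketch(curv, prior_prec=0.1, sigma_squared=2.0)
# prec is approximately [[4.1, 1.0], [1.0, 2.1]]
```

The prior term is a positive multiple of the identity, so it raises every eigenvalue of the scaled curvature by prior_prec. This keeps the precision positive definite even when the curvature estimate is only positive semi-definite.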
full_prec_to_scale
Convert precision matrix to scale matrix using Cholesky decomposition.
The scale matrix is the lower triangular matrix L such that L @ L.T is the covariance matrix, i.e. the inverse of the precision matrix.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prec | Num[Array, 'P P'] | Precision matrix to convert. | required |
Returns:
Type | Description |
---|---|
Num[Array, 'P P'] | Scale matrix L where L @ L.T is the covariance matrix. |
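A plain Cholesky factor C of the precision (P = C @ C.T) gives the covariance factor inv(C).T, which is upper triangular rather than lower. One common way to obtain a lower-triangular L with L @ L.T = inv(P) without a general matrix inverse is a flip-Cholesky trick: reverse the rows and columns before and after the factorization. The sketch below shows this trick. prec_to_scale_tril is a hypothetical name, and this page does not confirm that full_prec_to_scale follows exactly this procedure.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def prec_to_scale_tril(prec):
    """Hypothetical helper: lower-triangular L with L @ L.T == inv(prec)."""
    # Factor the row- and column-reversed precision: flip(P) = Lf @ Lf.T.
    lf = jnp.linalg.cholesky(jnp.flip(prec, axis=(-2, -1)))
    # Flipping Lf back and transposing gives a lower-triangular L_inv
    # with P = L_inv.T @ L_inv.
    l_inv = jnp.swapaxes(jnp.flip(lf, axis=(-2, -1)), -2, -1)
    identity = jnp.eye(prec.shape[-1], dtype=prec.dtype)
    # L = inv(L_inv) is lower triangular and L @ L.T = inv(L_inv.T @ L_inv) = inv(P).
    return solve_triangular(l_inv, identity, lower=True)


prec = jnp.array([[4.0, 1.0], [1.0, 3.0]])
scale = prec_to_scale_tril(prec)
```

Only a triangular solve is needed after the factorization. Computing cholesky(inv(prec)) would instead require inverting a general matrix first.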
full_prec_to_posterior_state
Convert a precision matrix to a posterior state.
The provided precision matrix is converted with full_prec_to_scale into a scale matrix (the lower triangular matrix L such that L @ L.T is the covariance matrix), which is returned in a state dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prec | Num[Array, 'P P'] | Precision matrix to convert. | required |
Returns:
Type | Description |
---|---|
dict[str, Num[Array, 'P P']] | Dictionary containing the scale matrix L, where L @ L.T is the covariance matrix. |
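The shape of the returned state can be sketched as follows. The sketch assumes the dictionary key is "scale" (this page does not give the key name). It computes the scale as the Cholesky factor of the explicit covariance, a simpler substitute for full_prec_to_scale rather than its implementation. prec_to_posterior_state_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def prec_to_posterior_state_sketch(prec):
    """Hypothetical helper: wrap a lower-triangular scale matrix in a state dict."""
    # Simple route: Cholesky of the explicit covariance inv(prec). This forms
    # a general inverse first and is less numerically careful than a
    # dedicated precision-to-scale routine.
    cov = jnp.linalg.inv(prec)
    return {"scale": jnp.linalg.cholesky(cov)}  # "scale" is an assumed key name


prec = jnp.array([[4.0, 1.0], [1.0, 3.0]])
state = prec_to_posterior_state_sketch(prec)
scale = state["scale"]
```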
full_posterior_state_to_scale
full_posterior_state_to_scale(state: dict[str, Num[Array, 'P P']]) -> Callable[[FlatParams], FlatParams]
Create a scale matrix-vector product function.
The scale matrix is read from the state dictionary and is used to create a corresponding matrix-vector product function representing the action of the scale matrix on a vector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | dict[str, Num[Array, 'P P']] | Dictionary containing the scale matrix. | required |
Returns:
Type | Description |
---|---|
Callable[[FlatParams], FlatParams] | A function that computes the scale matrix-vector product. |
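A scale matrix-vector product is what turns standard normal noise into posterior samples. The minimal sketch below builds such a function from a state dictionary, assuming the key "scale" (the key name and posterior_state_to_scale_mv are hypothetical). It then draws one reparameterized sample around a zero mean.

```python
import jax
import jax.numpy as jnp


def posterior_state_to_scale_mv(state):
    """Hypothetical helper: return v -> L @ v for the scale matrix L in `state`."""
    scale = state["scale"]  # assumed key name

    def scale_mv(vec):
        return scale @ vec

    return scale_mv


scale = jnp.array([[2.0, 0.0], [0.5, 1.0]])
scale_mv = posterior_state_to_scale_mv({"scale": scale})

# Reparameterized sample: mean + L @ z with z ~ N(0, I) has covariance L @ L.T.
mean = jnp.zeros(2)
z = jax.random.normal(jax.random.PRNGKey(0), (2,))
sample = mean + scale_mv(z)
```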
full_posterior_state_to_cov
full_posterior_state_to_cov(state: dict[str, Num[Array, 'P P']]) -> Callable[[FlatParams], FlatParams]
Create a covariance matrix-vector product function.
The scale matrix is read from the state dictionary and is used to create a corresponding matrix-vector product function representing the action of the covariance matrix on a vector. The covariance matrix is the product of the scale matrix and its transpose, L @ L.T.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | dict[str, Num[Array, 'P P']] | Dictionary containing the scale matrix. | required |
Returns:
Type | Description |
---|---|
Callable[[FlatParams], FlatParams] | A function that computes the covariance matrix-vector product. |
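Because the covariance is L @ L.T, its action on a vector can be computed by applying L.T and then L, without ever forming the P x P covariance. The minimal sketch below shows this, assuming the state key "scale". The helper name posterior_state_to_cov_mv is hypothetical.

```python
import jax.numpy as jnp


def posterior_state_to_cov_mv(state):
    """Hypothetical helper: return v -> L @ (L.T @ v) for the scale matrix L."""
    scale = state["scale"]  # assumed key name

    def cov_mv(vec):
        # Two matrix-vector products; the covariance L @ L.T is never materialized.
        return scale @ (scale.T @ vec)

    return cov_mv


scale = jnp.array([[2.0, 0.0], [0.5, 1.0]])
cov_mv = posterior_state_to_cov_mv({"scale": scale})
# First column of L @ L.T = [[4.0, 1.0], [1.0, 1.25]]
result = cov_mv(jnp.array([1.0, 0.0]))
```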