laplax.register
register_calibration_method
Register a new calibration method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`method_name` | `str` | Name of the calibration method. | required |
`method_fn` | `Callable` | Function implementing the calibration method. | required |
Notes
The method function should have the signature `method_fn(objective: Callable, **kwargs) -> float`.
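As an illustration of that signature, here is a minimal sketch of a function that could serve as `method_fn`. It is not taken from laplax: the name `grid_search_method`, the `grid` keyword argument, and the assumption that `objective` maps a single float to a scalar loss are all hypothetical, chosen only to make the shape of the callable concrete.

```python
from collections.abc import Callable


def grid_search_method(objective: Callable, **kwargs) -> float:
    """Return the candidate value with the lowest objective.

    Assumption: `objective` maps one float to a scalar loss, and the
    candidates arrive through a hypothetical `grid` keyword argument.
    """
    grid = kwargs.get("grid", [0.01, 0.1, 1.0, 10.0])
    # min() with a key evaluates the objective once per candidate.
    best = min(grid, key=lambda value: float(objective(value)))
    return float(best)


# Toy objective with its minimum at 1.0, standing in for a real one.
def toy_objective(value: float) -> float:
    return (value - 1.0) ** 2


best_value = grid_search_method(toy_objective, grid=[0.1, 0.5, 1.0, 2.0])
print(best_value)
```

With laplax installed, a function of this shape would presumably be passed as `register_calibration_method("grid_search_demo", grid_search_method)`, where the method name is a made-up example.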
Source code in laplax/register.py
register_curvature_method
register_curvature_method(name: str, *, create_curvature_fn: Callable[[CurvatureMV, Layout, Any], Any] | None = None, curvature_to_precision_fn: Callable | None = None, prec_to_posterior_fn: Callable | None = None, posterior_state_to_scale_fn: Callable[[PosteriorState], Callable[[FlatParams], FlatParams]] | None = None, posterior_state_to_cov_fn: Callable[[PosteriorState], Callable[[FlatParams], FlatParams]] | None = None, marginal_log_likelihood_fn: Callable | None = None, default: CurvApprox | None = None) -> None
Register a new curvature method with optional custom functions.
This function registers a new curvature method together with the functions it needs for creating the curvature estimate, adding prior information, computing the posterior state, and deriving matrix-vector products for the scale and covariance. Any function that is not supplied can be inherited from the method given as `default`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`name` | `str` | Name of the new curvature method. | required |
`create_curvature_fn` | `Callable[[CurvatureMV, Layout, Any], Any] \| None` | Custom function to create the curvature estimate. Defaults to None. | `None` |
`curvature_to_precision_fn` | `Callable \| None` | Custom function to convert the curvature estimate to a posterior precision matrix. Defaults to None. | `None` |
`prec_to_posterior_fn` | `Callable \| None` | Custom function to convert the posterior precision matrix to a posterior state. Defaults to None. | `None` |
`posterior_state_to_scale_fn` | `Callable[[PosteriorState], Callable[[FlatParams], FlatParams]] \| None` | Custom function to compute scale matrix-vector products. Defaults to None. | `None` |
`posterior_state_to_cov_fn` | `Callable[[PosteriorState], Callable[[FlatParams], FlatParams]] \| None` | Custom function to compute covariance matrix-vector products. Defaults to None. | `None` |
`marginal_log_likelihood_fn` | `Callable \| None` | Custom function to compute the marginal log-likelihood. Defaults to None. | `None` |
`default` | `CurvApprox \| None` | Default method to inherit missing functionality from. Defaults to None. | `None` |
Raises:
Type | Description |
---|---|
`ValueError` | If no default is provided and required functions are missing. |
Source code in laplax/register.py