laplax.eval.calibrate
Calibration utilities for optimizing prior precision in probabilistic models.
This module provides utilities for optimizing the prior precision of probabilistic models. It includes functions to:
- Evaluate metrics for given prior arguments and datasets.
- Perform grid search to optimize prior precision using objective functions.
- Optimize prior precision over a logarithmic grid interval.
The module uses JAX for numerical operations, Loguru for logging, and custom utilities from the `laplax` package.
evaluate_for_given_prior_arguments
evaluate_for_given_prior_arguments(*, data: Data, set_prob_predictive: Callable, metric: Callable = chi_squared_zero, **kwargs: Kwargs) -> Float
Evaluate the metric for a given set of prior arguments and data.
This function computes predictions for the input data using a probabilistic predictive function generated by `set_prob_predictive`. It then evaluates the specified metric on these predictions.
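The evaluation flow described above can be illustrated with a minimal pure-Python sketch. It is not the laplax implementation: laplax works on JAX arrays, and the names `evaluate_sketch`, `chi_squared_gap`, and `toy_set_prob_predictive`, the dictionary keys `input`, `target`, `pred_mean`, and `pred_var`, and the calibration formula used by the toy metric are all hypothetical choices made for illustration (in particular, `chi_squared_gap` is not claimed to match laplax's `chi_squared_zero`).

```python
from typing import Callable, Iterable


def evaluate_sketch(
    *,
    data: Iterable[dict],
    set_prob_predictive: Callable[..., Callable[[float], dict]],
    metric: Callable[..., float],
    **kwargs,
) -> float:
    """Hypothetical sketch: build a predictive from kwargs, predict, then score."""
    # Build the probabilistic predictive from the (prior) keyword arguments.
    prob_predictive = set_prob_predictive(**kwargs)
    means, variances, targets = [], [], []
    for dp in data:
        # Assumed layout: each data point is a dict with "input" and "target".
        pred = prob_predictive(dp["input"])
        means.append(pred["pred_mean"])
        variances.append(pred["pred_var"])
        targets.append(dp["target"])
    # Score all predictions at once with the chosen metric.
    return metric(pred_mean=means, pred_var=variances, target=targets)


def chi_squared_gap(*, pred_mean, pred_var, target) -> float:
    """Hypothetical calibration metric: |mean((y - mu)^2 / var) - 1|."""
    n = len(target)
    chi2 = sum((y - m) ** 2 / v for y, m, v in zip(target, pred_mean, pred_var)) / n
    return abs(chi2 - 1.0)


def toy_set_prob_predictive(*, prior_prec: float) -> Callable[[float], dict]:
    """Toy model: mean 2x, predictive variance 1 / prior_prec."""
    def predictive(x: float) -> dict:
        return {"pred_mean": 2.0 * x, "pred_var": 1.0 / prior_prec}
    return predictive


# Toy data whose residuals around the mean 2x are all +1 or -1.
data = [
    {"input": x, "target": 2.0 * x + r}
    for x, r in [(0.0, 1.0), (1.0, -1.0), (2.0, 1.0), (3.0, -1.0)]
]
score = evaluate_sketch(
    data=data,
    set_prob_predictive=toy_set_prob_predictive,
    metric=chi_squared_gap,
    prior_prec=1.0,
)
```

With residuals of ±1, a prior precision of 1.0 gives a predictive variance of 1.0, which this toy metric scores as perfectly calibrated (a gap of 0.0); a prior precision of 4.0 shrinks the variance to 0.25 and makes the predictive overconfident (a gap of 3.0).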
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`data` | Data | Dataset containing inputs and targets. | required |
`set_prob_predictive` | Callable | A callable that generates a probabilistic predictive function. | required |
`metric` | Callable | A callable metric function to evaluate the predictions (default: chi_squared_zero). | chi_squared_zero |
`**kwargs` | Kwargs | Additional arguments passed to | {} |
Returns:
Type | Description |
---|---|
Float | The evaluated metric value. |
Source code in laplax/eval/calibrate.py
grid_search
grid_search(prior_prec_interval: Array, objective: Callable[[PriorArguments], float], patience: int | None = None, max_iterations: int | None = None) -> Float
Perform grid search to optimize prior precision.
This function iteratively evaluates an objective function over a range of prior precisions. It tracks the performance and stops early if the results increase for a specified number of consecutive iterations (`patience`).
The prior precision with the lowest objective value is returned.
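A minimal pure-Python sketch of grid search with patience-based early stopping, as described above. The function name `grid_search_sketch` is hypothetical, the prior arguments are assumed to be passed to the objective as a dict with a `prior_prec` key, and treating NaN objective values as infinitely bad is a choice of this sketch rather than documented laplax behavior.

```python
import math
from typing import Callable, Optional, Sequence


def grid_search_sketch(
    prior_prec_interval: Sequence[float],
    objective: Callable[[dict], float],
    patience: Optional[int] = None,
    max_iterations: Optional[int] = None,
) -> float:
    """Return the evaluated prior precision with the lowest objective value."""
    results: list[float] = []
    prior_precs: list[float] = []
    increasing_count = 0
    previous = None
    for iteration, prior_prec in enumerate(prior_prec_interval):
        # Respect the evaluation budget, if one is given.
        if max_iterations is not None and iteration >= max_iterations:
            break
        result = objective({"prior_prec": prior_prec})
        # Sketch-specific choice: a NaN objective counts as infinitely bad.
        if math.isnan(result):
            result = math.inf
        results.append(result)
        prior_precs.append(prior_prec)
        if previous is not None:
            # Count consecutive increases; any non-increase resets the counter.
            increasing_count = increasing_count + 1 if result > previous else 0
            if patience is not None and increasing_count >= patience:
                break
        previous = result
    if not results:
        raise ValueError("no prior precision was evaluated")
    # Index of the first minimum among the evaluated candidates.
    best = min(range(len(results)), key=results.__getitem__)
    return prior_precs[best]
```

Because early stopping only looks at consecutive increases, a noisy objective can end the search before the best value on the grid has been visited; a larger `patience` trades extra evaluations for robustness to such noise.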
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`prior_prec_interval` | Array | An array of prior precision values to search. | required |
`objective` | Callable[[PriorArguments], float] | A callable objective function that takes `PriorArguments` as input and returns a float. | required |
`patience` | int \| None | The number of consecutive iterations with increasing results to tolerate before stopping. | None |
`max_iterations` | int \| None | The maximum number of iterations to perform (default: None). | None |
Returns:
Type | Description |
---|---|
Float | The prior precision value that minimizes the objective function. |
Source code in laplax/eval/calibrate.py
optimize_prior_prec
optimize_prior_prec(objective: Callable[[PriorArguments], float], log_prior_prec_min: float = -5.0, log_prior_prec_max: float = 6.0, grid_size: int = 300, **kwargs: Kwargs) -> Float
Optimize prior precision using logarithmic grid search.
This function creates a logarithmically spaced interval of prior precision values and performs a grid search to find the optimal value that minimizes the specified objective function.
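The logarithmic grid can be sketched in pure Python as follows. Exponents are spaced evenly between the two bounds and mapped through 10**x, which is what `numpy.logspace` and `jax.numpy.logspace` compute with their default base of 10. The names `log_grid` and `optimize_prior_prec_sketch` are hypothetical, the prior arguments are assumed to be a dict with a `prior_prec` key, and for brevity this sketch scans the whole grid instead of forwarding extra keyword arguments to a patience-aware grid search.

```python
from typing import Callable


def log_grid(log_min: float, log_max: float, grid_size: int) -> list[float]:
    """Return grid_size values 10**x with exponents x evenly spaced in [log_min, log_max]."""
    if grid_size < 1:
        raise ValueError("grid_size must be positive")
    if grid_size == 1:
        return [10.0 ** log_min]
    step = (log_max - log_min) / (grid_size - 1)
    return [10.0 ** (log_min + i * step) for i in range(grid_size)]


def optimize_prior_prec_sketch(
    objective: Callable[[dict], float],
    log_prior_prec_min: float = -5.0,
    log_prior_prec_max: float = 6.0,
    grid_size: int = 300,
) -> float:
    """Scan the full log grid (no early stopping) and return the best prior precision."""
    grid = log_grid(log_prior_prec_min, log_prior_prec_max, grid_size)
    # min() with a key returns the first grid value attaining the lowest objective.
    return min(grid, key=lambda prior_prec: objective({"prior_prec": prior_prec}))
```

With the default bounds and grid size, the 300 candidates span 1e-5 to 1e6, so adjacent exponents differ by 11/299 ≈ 0.037 decades.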
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`objective` | Callable[[PriorArguments], float] | A callable objective function that takes `PriorArguments` as input and returns a float. | required |
`log_prior_prec_min` | float | The base-10 logarithm of the minimum prior precision value (default: -5.0). | -5.0 |
`log_prior_prec_max` | float | The base-10 logarithm of the maximum prior precision value (default: 6.0). | 6.0 |
`grid_size` | int | The number of points in the grid interval (default: 300). | 300 |
`**kwargs` | Kwargs | Additional arguments passed to | {} |
Returns:
Type | Description |
---|---|
Float | The optimized prior precision value. |
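One plausible way to connect the pieces is shown below as a self-contained toy: a keyword-only evaluator is wrapped into an objective that accepts a prior-arguments dict, and that objective is minimized over log-spaced candidates. The adapter `make_objective`, the evaluator `toy_evaluate`, its variance model 0.5 + 1/prior_prec, and the dict-shaped prior arguments are all assumptions for illustration, not laplax API.

```python
from typing import Callable


def make_objective(evaluate: Callable[..., float], **fixed) -> Callable[[dict], float]:
    """Hypothetical adapter: turn a keyword-only evaluator into objective(prior_arguments)."""

    def objective(prior_arguments: dict) -> float:
        # Merge the searched prior arguments (e.g. {"prior_prec": 2.0}) with the fixed inputs.
        return evaluate(**fixed, **prior_arguments)

    return objective


def toy_evaluate(*, residuals: list[float], prior_prec: float) -> float:
    """Toy calibration gap for a model whose predictive variance is 0.5 + 1/prior_prec."""
    var = 0.5 + 1.0 / prior_prec
    chi2 = sum(r * r / var for r in residuals) / len(residuals)
    return abs(chi2 - 1.0)


objective = make_objective(toy_evaluate, residuals=[1.0, -1.0, 1.0, -1.0])
# Log-spaced candidates 10**-1, 10**-0.5, ..., 10**2 (a small stand-in for a full grid).
candidates = [10.0 ** (k / 2) for k in range(-2, 5)]
best = min(candidates, key=lambda p: objective({"prior_prec": p}))
```

Under this toy model the calibration gap is exactly zero at a prior precision of 2.0; on the coarse half-decade grid, the closest candidate, 10**0.5 ≈ 3.16, is selected.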