laplax.eval.calibrate
Calibration utilities for optimizing prior precision in probabilistic models.
The module includes functions to:
- Evaluate metrics for given prior arguments and datasets.
- Perform grid search to optimize prior precision using objective functions.
- Optimize prior precision over a logarithmic grid interval.
The module uses JAX for numerical operations, Loguru for logging, and custom
utilities from the laplax package.
evaluate_for_given_prior_arguments
evaluate_for_given_prior_arguments(*, data: Data, set_prob_predictive: Callable, metric: Callable = chi_squared_zero, **kwargs: Kwargs) -> Float
Evaluate the metric for a given set of prior arguments and data.
This function computes predictions for the input data using a probabilistic
predictive function generated by set_prob_predictive. It then evaluates a
specified metric using these predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Data | Dataset containing inputs and targets. | required |
| set_prob_predictive | Callable | A callable that generates a probabilistic predictive function. | required |
| metric | Callable | A callable metric function to evaluate the predictions (default: chi_squared_zero). | chi_squared_zero |
| **kwargs | Kwargs | Additional arguments passed to … | {} |
Returns:
| Type | Description |
|---|---|
| Float | The evaluated metric value. |
Source code in laplax/eval/calibrate.py
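The three steps described above (build the predictive function, predict on the data, score the predictions) can be sketched in plain Python. This is only an illustration, not the laplax implementation: `evaluate_sketch`, `make_predictive`, and `chi_squared_gap` are hypothetical names, the `"input"`/`"target"` keys and the `pred_mean`/`pred_var` fields are assumed layouts, and the toy metric is an assumption about what a chi-squared-style calibration score might look like.

```python
def evaluate_sketch(*, data, set_prob_predictive, metric, **kwargs):
    """Toy stand-in for evaluate_for_given_prior_arguments (not laplax code)."""
    # Build the predictive function from the prior arguments in **kwargs.
    prob_predictive = set_prob_predictive(**kwargs)
    # Predict every input; the "input"/"target" keys are an assumed layout.
    predictions = [prob_predictive(x) for x in data["input"]]
    # Score the predictions against the targets.
    return metric(predictions, data["target"])


def make_predictive(prior_prec):
    """Hypothetical factory: predictive mean 2*x, variance 1/prior_prec."""
    def predictive(x):
        return {"pred_mean": 2.0 * x, "pred_var": 1.0 / prior_prec}
    return predictive


def chi_squared_gap(predictions, targets):
    """Hypothetical metric: |mean squared standardized residual - 1|."""
    ratios = [
        (p["pred_mean"] - y) ** 2 / p["pred_var"]
        for p, y in zip(predictions, targets)
    ]
    return abs(sum(ratios) / len(ratios) - 1.0)


data = {"input": [0.0, 1.0, 2.0], "target": [0.5, 2.5, 3.5]}

# Every residual is +/-0.5, so a predictive variance of 0.25 matches the data.
matched = evaluate_sketch(
    data=data, set_prob_predictive=make_predictive,
    metric=chi_squared_gap, prior_prec=4.0,
)
# A smaller prior precision gives a variance of 1.0, which is too wide here.
too_wide = evaluate_sketch(
    data=data, set_prob_predictive=make_predictive,
    metric=chi_squared_gap, prior_prec=1.0,
)
```

In this toy setup a lower score means the predictive variance matches the residuals more closely, which is the kind of scalar an objective for prior precision search would return.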
grid_search
grid_search(prior_prec_interval: Array, objective: Callable[[PriorArguments], float], patience: int | None = None, max_iterations: int | None = None) -> Float
Perform grid search to optimize prior precision.
This function iteratively evaluates an objective function over a range of
prior precisions. It tracks the objective values and stops early if they
increase for a specified number of consecutive iterations (patience).
The prior precision with the lowest objective value is returned.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prior_prec_interval | Array | An array of prior precision values to search. | required |
| objective | Callable[[PriorArguments], float] | A callable objective function that takes PriorArguments and returns a float. | required |
| patience | int \| None | The number of consecutive iterations with increasing results to tolerate before stopping (default: 5). | None |
| max_iterations | int \| None | The maximum number of iterations to perform (default: None). | None |
Returns:
| Type | Description |
|---|---|
| Float | The prior precision value that minimizes the objective function. |
Source code in laplax/eval/calibrate.py
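The early-stopping behaviour described above can be illustrated with a minimal plain-Python sketch. It is not the laplax source: `grid_search_sketch` is a hypothetical name, passing the candidate as `{"prior_prec": value}` is an assumption about how prior arguments are structured, and the fallback to a patience of 5 when `patience` is None is an assumption based on the parameter description.

```python
import math


def grid_search_sketch(prior_prec_interval, objective,
                       patience=None, max_iterations=None):
    """Toy grid search with patience-based early stopping (not laplax code)."""
    # Assumed fallbacks: patience of 5, and no iteration cap beyond the grid.
    patience = 5 if patience is None else patience
    limit = len(prior_prec_interval) if max_iterations is None else max_iterations

    best_prec, best_value = None, math.inf
    previous_value = math.inf
    increases = 0  # number of consecutive increases seen so far

    for iteration, prior_prec in enumerate(prior_prec_interval):
        if iteration >= limit:
            break
        value = objective({"prior_prec": prior_prec})
        # Remember the best candidate seen so far.
        if value < best_value:
            best_prec, best_value = prior_prec, value
        # Count consecutive increases; any non-increase resets the counter.
        if value > previous_value:
            increases += 1
            if increases >= patience:
                break
        else:
            increases = 0
        previous_value = value
    return best_prec


calls = []


def toy_objective(prior_args):
    """Hypothetical objective with its minimum at prior_prec == 3."""
    calls.append(prior_args["prior_prec"])
    return (prior_args["prior_prec"] - 3) ** 2


best = grid_search_sketch([1, 2, 3, 4, 5, 6, 7, 8], toy_objective, patience=2)
capped = grid_search_sketch([1, 2, 3, 4, 5, 6, 7, 8], toy_objective,
                            max_iterations=2)
```

With `patience=2` the toy objective rises at 4 and again at 5, so the search stops after five evaluations and never visits 6, 7, or 8. Early stopping of this kind assumes the objective is roughly unimodal along the grid; a noisy objective may need a larger patience.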
optimize_prior_prec
optimize_prior_prec(objective: Callable[[PriorArguments], float], log_prior_prec_min: float = -5.0, log_prior_prec_max: float = 6.0, grid_size: int = 300, **kwargs: Kwargs) -> Float
Optimize prior precision using logarithmic grid search.
This function creates a logarithmically spaced interval of prior precision values and performs a grid search to find the optimal value that minimizes the specified objective function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| objective | Callable[[PriorArguments], float] | A callable objective function that takes PriorArguments and returns a float. | required |
| log_prior_prec_min | float | The base-10 logarithm of the minimum prior precision value (default: -5.0). | -5.0 |
| log_prior_prec_max | float | The base-10 logarithm of the maximum prior precision value (default: 6.0). | 6.0 |
| grid_size | int | The number of points in the grid interval (default: 300). | 300 |
| **kwargs | Kwargs | Additional arguments passed to … | {} |
Returns:
| Type | Description |
|---|---|
| Float | The optimized prior precision value. |
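To make the logarithmic grid concrete, the following plain-Python sketch builds base-10 log-spaced candidates between `log_prior_prec_min` and `log_prior_prec_max` and picks the one with the lowest objective value. It is a simplified illustration rather than the laplax implementation: `log_grid` and `optimize_sketch` are hypothetical names, the objective is assumed to receive `{"prior_prec": value}`, and the search here is an exhaustive minimum with no early stopping.

```python
import math


def log_grid(log_min, log_max, grid_size):
    """Return grid_size values spaced evenly in base-10 log space."""
    if grid_size == 1:
        return [10.0 ** log_min]
    step = (log_max - log_min) / (grid_size - 1)
    return [10.0 ** (log_min + i * step) for i in range(grid_size)]


def optimize_sketch(objective, log_prior_prec_min=-5.0,
                    log_prior_prec_max=6.0, grid_size=300):
    """Toy stand-in: exhaustive search over a log-spaced grid."""
    grid = log_grid(log_prior_prec_min, log_prior_prec_max, grid_size)
    # Assumed calling convention: the objective receives a dict of prior args.
    return min(grid, key=lambda prec: objective({"prior_prec": prec}))


def toy_objective(prior_args):
    """Hypothetical objective with its minimum at prior_prec == 100."""
    return abs(math.log10(prior_args["prior_prec"]) - 2.0)


# Twelve points from 1e-5 to 1e6 land exactly on integer powers of ten.
grid = log_grid(-5.0, 6.0, 12)
best = optimize_sketch(toy_objective, grid_size=12)
```

Spacing candidates in log space gives equal resolution to each order of magnitude, which suits a quantity like prior precision that can plausibly range from 1e-5 to 1e6.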