laplax.eval.utils
Pushforward utilities for evaluating probabilistic predictions on datasets.
This module provides utilities for evaluating probabilistic models on datasets and managing metric computations.
Key features include:
- Wrapping functions to store outputs in a structured format.
- Finalizing multiple functions and collecting results in a dictionary.
- Applying prediction functions across datasets to generate predictions and evaluating them against their targets.
- Computing and transforming evaluation metrics for datasets using custom or default metrics.
These utilities streamline dataset evaluation workflows and ensure flexibility in metric computation and result aggregation.
finalize_fns
finalize_fns(fns: list[Callable], results: dict, aux: dict[str, Any] | None = None, **kwargs: Kwargs) -> dict
Execute a set of functions and store their results in a dictionary.
This function iterates over a list of functions, executes each
function with the provided keyword arguments, and updates the results
dictionary with their outputs. Each function is responsible for writing
its output under its own key in the results dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`fns` | `list[Callable]` | A list of callables to execute. | required |
`results` | `dict` | A dictionary to store the outputs of the functions. | required |
`aux` | `dict[str, Any] \| None` | Auxiliary data passed to the functions. | `None` |
`**kwargs` | `Kwargs` | Additional arguments passed to each function. | `{}` |
Returns:
Type | Description |
---|---|
`dict` | The updated `results` dictionary. |
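The pattern described above can be illustrated with a minimal plain-Python sketch. It is not the library's implementation: `finalize_sketch`, `store_mean`, and `store_count` are hypothetical names, and the convention that every function receives and returns the pair `(results, aux)` is an assumption made for this example.

```python
from typing import Any, Callable, Optional


def finalize_sketch(
    fns: list,
    results: dict,
    aux: Optional[dict] = None,
    **kwargs: Any,
) -> dict:
    """Hypothetical stand-in for the pattern described above (not laplax code)."""
    aux = {} if aux is None else aux
    for fn in fns:
        # Assumed convention: each function takes and returns (results, aux).
        results, aux = fn(results=results, aux=aux, **kwargs)
    return results


def store_mean(results: dict, aux: dict, **kwargs: Any):
    # This function "knows" its own key: it always writes to "mean".
    samples = aux["samples"]
    results["mean"] = sum(samples) / len(samples)
    return results, aux


def store_count(results: dict, aux: dict, **kwargs: Any):
    # Writes the number of samples under the key "count".
    results["count"] = len(aux["samples"])
    return results, aux


out = finalize_sketch(
    [store_mean, store_count],
    results={},
    aux={"samples": [1.0, 2.0, 3.0]},
)
print(out)
```

Because every function chooses its own key, the caller only has to supply the list of functions and an initial dictionary.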
Source code in laplax/eval/utils.py
evaluate_on_dataset
evaluate_on_dataset(pred_fn: Callable[[InputArray], dict[str, Array]], data: Data, **kwargs: Kwargs) -> dict
Evaluate a prediction function on a dataset.
This function applies a probabilistic predictive function (`pred_fn`) to
each data point in the dataset, combining the predictions with the target
labels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`pred_fn` | `Callable[[InputArray], dict[str, Array]]` | A callable that takes an input array and returns predictions as a dictionary. | required |
`data` | `Data` | A dataset, where each data point is a dictionary containing "input" and "target". | required |
`**kwargs` | `Kwargs` | Additional arguments, including: … | `{}` |
Returns:
Type | Description |
---|---|
`dict` | A dictionary containing predictions and target labels for the entire dataset. |
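As a rough illustration of "apply `pred_fn` to each data point and combine with the target", the following plain-Python sketch may help. It assumes the dataset is a list of `{"input", "target"}` dictionaries and collects values into Python lists. The library operates on arrays, so `evaluate_on_dataset_sketch` and `toy_pred_fn` are hypothetical stand-ins only.

```python
from typing import Callable


def evaluate_on_dataset_sketch(pred_fn: Callable, data: list) -> dict:
    """Plain-Python illustration; not the library's array-based version."""
    collected: dict = {}
    for point in data:
        # Merge the prediction dictionary with the target of this data point.
        row = {**pred_fn(point["input"]), "target": point["target"]}
        for key, value in row.items():
            collected.setdefault(key, []).append(value)
    return collected


def toy_pred_fn(x: float) -> dict:
    # Hypothetical predictive function returning a mean and a variance.
    return {"pred_mean": 2.0 * x, "pred_var": 0.5}


data = [{"input": 1.0, "target": 2.1}, {"input": 2.0, "target": 3.9}]
evaluated = evaluate_on_dataset_sketch(toy_pred_fn, data)
print(evaluated)
```

The returned dictionary holds one entry per prediction key plus `"target"`, each covering the whole dataset.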
apply_fns
apply_fns(*funcs: Callable, names: list[str] | None = None, field: str = 'results', **kwargs: Kwargs) -> Callable
Apply multiple functions and store their results in a dictionary.
This function takes a sequence of functions, applies them to the provided inputs, and stores their results in either the 'results' or 'aux' dictionary under specified names. This function is useful for applying multiple metrics to the results of a pushforward function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`*funcs` | `Callable` | Variable number of callable functions to be applied. | `()` |
`names` | `list[str] \| None` | Optional list of names for the functions' results. If None, function names will be used. | `None` |
`field` | `str` | String indicating where to store results, either 'results' or 'aux' (default: 'results'). | `'results'` |
`**kwargs` | `Kwargs` | Mapping of argument names to keys in results/aux dictionaries that will be passed to the functions. | `{}` |
Returns:
Type | Description |
---|---|
`Callable` | A function that takes 'results' and 'aux' dictionaries along with additional kwargs, applies the functions, and returns the updated dictionaries. |
Raises:
Type | Description |
---|---|
`TypeError` | If any of the provided functions is not callable. |
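The sketch below re-creates the documented behaviour in plain Python to show how the keyword mapping resolves argument names to dictionary keys. It is an assumption-laden illustration, not the library's code: `apply_fns_sketch`, `abs_error`, and `squared_error` are hypothetical, and the lookup order (results first, then aux) is a choice made for this example.

```python
from typing import Any, Callable, Optional


def apply_fns_sketch(
    *funcs: Callable,
    names: Optional[list] = None,
    field: str = "results",
    **key_map: str,
) -> Callable:
    """Hypothetical re-creation of the described behaviour (not laplax code)."""
    for fn in funcs:
        if not callable(fn):
            raise TypeError(f"{fn!r} is not callable")
    labels = names if names is not None else [fn.__name__ for fn in funcs]

    def apply(results: dict, aux: dict, **extra: Any):
        # Entries in results take precedence over entries in aux.
        lookup = {**aux, **results}
        args = {arg: lookup[key] for arg, key in key_map.items()}
        target = results if field == "results" else aux
        for label, fn in zip(labels, funcs):
            target[label] = fn(**args, **extra)
        return results, aux

    return apply


def abs_error(pred: float, target: float) -> float:
    return abs(pred - target)


def squared_error(pred: float, target: float) -> float:
    return (pred - target) ** 2


step = apply_fns_sketch(
    abs_error,
    squared_error,
    names=["mae", "mse"],
    pred="pred_mean",  # argument `pred` is read from key "pred_mean"
    target="target",   # argument `target` is read from key "target"
)
results, aux = step({"pred_mean": 3.0, "target": 1.0}, {})
print(results)
```

Returning a closure rather than computing immediately lets the same step be reused for every data point or batch.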
transfer_entry
transfer_entry(mapping: dict[str, str] | list[str], field: str = 'results', access_from: str = 'aux') -> Callable
Transfer entries between results and auxiliary dictionaries.
This function creates a callable that copies values between the results and auxiliary dictionaries based on the provided mapping.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`mapping` | `dict[str, str] \| list[str]` | Either a dictionary mapping destination keys to source keys, or a list of keys to copy with the same names. | required |
`field` | `str` | String indicating where to store entries, either 'results' or 'aux' (default: 'results'). | `'results'` |
`access_from` | `str` | String indicating which dictionary to read from, either 'results' or 'aux' (default: 'aux'). | `'aux'` |
Returns:
Type | Description |
---|---|
`Callable` | A function that takes 'results' and 'aux' dictionaries, transfers the specified entries, and returns the updated dictionaries. |
Raises:
Type | Description |
---|---|
`ValueError` | If field is not 'results' or 'aux'. |
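To make the destination-to-source direction of `mapping` concrete, here is a small plain-Python sketch of the described behaviour. `transfer_entry_sketch` is a hypothetical stand-in written for this example and should not be read as the library's implementation.

```python
from typing import Any, Callable, Union


def transfer_entry_sketch(
    mapping: Union[dict, list],
    field: str = "results",
    access_from: str = "aux",
) -> Callable:
    """Hypothetical re-creation of the described behaviour (not laplax code)."""
    if field not in ("results", "aux"):
        raise ValueError(f"field must be 'results' or 'aux', got {field!r}")
    # A list copies keys under the same name; a dict maps destination -> source.
    pairs = mapping if isinstance(mapping, dict) else {key: key for key in mapping}

    def transfer(results: dict, aux: dict, **kwargs: Any):
        stores = {"results": results, "aux": aux}
        for dest_key, source_key in pairs.items():
            stores[field][dest_key] = stores[access_from][source_key]
        return results, aux

    return transfer


# Copy aux["pred_std"] into results["std"]; other aux entries stay where they are.
copy_step = transfer_entry_sketch({"std": "pred_std"})
results, aux = copy_step({}, {"pred_std": 0.3, "samples": [1, 2]})
print(results)
```

With the default arguments, values are read from `aux` and written to `results`, which matches the documented defaults.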
evaluate_metrics_on_dataset
evaluate_metrics_on_dataset(pred_fn: Callable[[InputArray], dict[str, Array]], data: Data, *, metrics: list | None = None, metrics_dict: dict[str, Callable] | None = None, reduce: Callable = identity, **kwargs: Kwargs) -> dict
Evaluate a set of metrics on a dataset.
This function computes specified metrics for predictions generated by a
probabilistic predictive function (`pred_fn`) over a dataset. The results
can optionally be transformed using the `reduce` function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`pred_fn` | `Callable[[InputArray], dict[str, Array]]` | A callable that takes an input array and returns predictions as a dictionary. | required |
`data` | `Data` | A dataset, where each data point is a dictionary containing "input" and "target". | required |
`metrics` | `list \| None` | A list of metrics to compute, this should use the … | `None` |
`metrics_dict` | `dict[str, Callable] \| None` | A dictionary of metrics to compute, where keys are metric names and values are callables. | `None` |
`reduce` | `Callable` | A callable to transform the evaluated metrics (default: identity). | `identity` |
`**kwargs` | `Kwargs` | Additional arguments, including: … | `{}` |
Returns:
Type | Description |
---|---|
`dict` | A dictionary containing the evaluated metrics for the entire dataset. |
Raises:
Type | Description |
---|---|
`ValueError` | When metrics and metrics_dict are both None. |
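The interplay of `metrics_dict` and `reduce` can be sketched in plain Python as follows. Several points are assumptions made for illustration: the dataset is a list of `{"input", "target"}` dictionaries, each metric receives the prediction entries and the target as keyword arguments, and `reduce` is applied to each metric's collected values. `evaluate_metrics_sketch`, `toy_pred_fn`, and `abs_error` are hypothetical names.

```python
from typing import Callable, Optional


def evaluate_metrics_sketch(
    pred_fn: Callable,
    data: list,
    *,
    metrics_dict: Optional[dict] = None,
    reduce: Callable = lambda values: values,
) -> dict:
    """Plain-Python illustration of the described workflow (not laplax code)."""
    if metrics_dict is None:
        raise ValueError("metrics_dict must be provided in this sketch")
    per_metric: dict = {name: [] for name in metrics_dict}
    for point in data:
        pred = {**pred_fn(point["input"]), "target": point["target"]}
        for name, metric in metrics_dict.items():
            # Assumption: metrics take prediction entries as keyword arguments.
            per_metric[name].append(metric(**pred))
    # Assumption: reduce is applied to each metric's per-point values.
    return {name: reduce(values) for name, values in per_metric.items()}


def toy_pred_fn(x: float) -> dict:
    return {"pred_mean": 2.0 * x}


def abs_error(pred_mean: float, target: float) -> float:
    return abs(pred_mean - target)


data = [{"input": 1.0, "target": 2.5}, {"input": 2.0, "target": 3.0}]
raw = evaluate_metrics_sketch(toy_pred_fn, data, metrics_dict={"mae": abs_error})
averaged = evaluate_metrics_sketch(
    toy_pred_fn,
    data,
    metrics_dict={"mae": abs_error},
    reduce=lambda values: sum(values) / len(values),
)
print(raw, averaged)
```

With the identity `reduce`, per-point values are returned unchanged. Passing an averaging function collapses them to one number per metric.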
evaluate_metrics_on_generator
evaluate_metrics_on_generator(pred_fn: Callable[[InputArray], dict[str, Array]], data_generator: Iterator[Data], *, metrics: list | None = None, metrics_dict: dict[str, Callable] | None = None, transform: Callable = identity, reduce: Callable = identity, vmap_over_data: bool = True, **kwargs: Kwargs) -> dict
Evaluate a set of metrics on a data generator.
Similar to evaluate_metrics_on_dataset, but works with a generator of data points instead of a dataset array. This is useful for cases where the data doesn't fit in memory or is being streamed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`pred_fn` | `Callable[[InputArray], dict[str, Array]]` | A callable that takes an input array and returns predictions as a dictionary. | required |
`data_generator` | `Iterator[Data]` | An iterator yielding data points, where each data point is a dictionary containing "input" and "target". | required |
`metrics` | `list \| None` | A list of metrics to compute, this should use the … | `None` |
`metrics_dict` | `dict[str, Callable] \| None` | A dictionary of metrics to compute, where keys are metric names and values are callables. | `None` |
`transform` | `Callable` | A transform applied to individual data points (default: identity). | `identity` |
`reduce` | `Callable` | A callable to transform the evaluated metrics (default: identity). | `identity` |
`vmap_over_data` | `bool` | Whether data batches from the generator have a batch dimension that is not yet accounted for and should be vmapped over (default: True). | `True` |
`**kwargs` | `Kwargs` | Additional keyword arguments passed to the metrics functions. | `{}` |
|
Returns:
Type | Description |
---|---|
`dict` | A dictionary containing the evaluated metrics for all data points. |
Raises:
Type | Description |
---|---|
`ValueError` | If neither metrics nor metrics_dict is provided. |
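The streaming use case can be sketched with an ordinary Python generator. The sketch is hypothetical and only illustrates lazy consumption plus the per-point `transform`. It does not model `vmap_over_data`, and it assumes metrics receive prediction entries as keyword arguments. `evaluate_stream_sketch`, `stream_points`, and the toy functions are illustrative names.

```python
from typing import Callable, Iterator


def evaluate_stream_sketch(
    pred_fn: Callable,
    data_generator: Iterator,
    *,
    metrics_dict: dict,
    transform: Callable = lambda point: point,
    reduce: Callable = lambda values: values,
) -> dict:
    """Plain-Python illustration of streaming evaluation (not laplax code)."""
    collected: dict = {name: [] for name in metrics_dict}
    for point in data_generator:  # consumed lazily, one item at a time
        point = transform(point)
        pred = {**pred_fn(point["input"]), "target": point["target"]}
        for name, metric in metrics_dict.items():
            collected[name].append(metric(**pred))
    return {name: reduce(values) for name, values in collected.items()}


def stream_points(n: int) -> Iterator:
    # Yields data points on demand, so the full dataset never sits in memory.
    for i in range(n):
        yield {"input": float(i), "target": float(i) + 1.0}


def identity_pred_fn(x: float) -> dict:
    return {"pred_mean": x}


def squared_error(pred_mean: float, target: float) -> float:
    return (pred_mean - target) ** 2


worst = evaluate_stream_sketch(
    identity_pred_fn,
    stream_points(4),
    metrics_dict={"se": squared_error},
    reduce=max,
)
shifted = evaluate_stream_sketch(
    identity_pred_fn,
    stream_points(4),
    metrics_dict={"se": squared_error},
    transform=lambda p: {**p, "input": p["input"] + 1.0},
)
print(worst, shifted)
```

Only the metric values are retained between iterations, which is what makes this pattern suitable for data that is streamed or too large to hold at once.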