laplax.eval.metrics
Regression and Classification Metrics for Uncertainty Quantification.
This module provides a comprehensive suite of classification and regression metrics for evaluating probabilistic models.
Key Features¶
Classification Metrics¶
- Accuracy
- Top-k Accuracy
- Cross-Entropy
- Multiclass Brier Score
- Expected Calibration Error (ECE)
- Maximum Calibration Error (MCE)
Regression Metrics¶
- Root Mean Squared Error (RMSE)
- Chi-squared
- Negative Log-Likelihood (NLL) for Gaussian distributions
Bin Metrics¶
- Confidence and correctness metrics binned by confidence intervals
The module leverages JAX for efficient numerical computation and supports flexible evaluation for diverse model outputs.
correctness ¶

correctness(pred: Array, target: Array, **kwargs: Kwargs) -> Array

Determine if each target label matches the top-1 prediction.

Computes a binary indicator for whether the predicted class matches the target class. If the target is a 2D array, it is first reduced to its class index using `argmax`.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`pred` | Array | Array of predictions with shape `(n, num_classes)`. | required |
`target` | Array | Array of ground truth labels, either 1D (class indices) or 2D (one-hot encoded). | required |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
Array | Boolean array of shape `(n,)` indicating whether each prediction matches the target. |
Source code in laplax/eval/metrics.py
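A minimal usage sketch (the example arrays are illustrative, not from the library):

```python
import jax.numpy as jnp

from laplax.eval.metrics import correctness

# Three samples, four classes; rows hold per-class scores.
pred = jnp.array([
    [0.1, 0.7, 0.1, 0.1],
    [0.6, 0.2, 0.1, 0.1],
    [0.2, 0.2, 0.5, 0.1],
])
target = jnp.array([1, 2, 2])  # class indices

correct = correctness(pred=pred, target=target)  # expected: True, False, True
```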
accuracy ¶

accuracy(pred: Array, target: Array, top_k: tuple[int] = (1,), **kwargs: Kwargs) -> list[Array]

Compute top-k accuracy for specified values of k.

For each k in `top_k`, this function calculates the fraction of samples where the ground truth label is among the top-k predictions. If the target is a 2D array, it is reduced to its class index using `argmax`.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`pred` | Array | Array of predictions with shape `(n, num_classes)`. | required |
`target` | Array | Array of ground truth labels, either 1D (class indices) or 2D (one-hot encoded). | required |
`top_k` | tuple[int] | Tuple of integers specifying the values of k for top-k accuracy. | `(1,)` |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
list[Array] | A list of accuracies corresponding to each k in `top_k`. |
Source code in laplax/eval/metrics.py
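A minimal sketch reusing the illustrative arrays from the previous example:

```python
import jax.numpy as jnp

from laplax.eval.metrics import accuracy

pred = jnp.array([
    [0.1, 0.7, 0.1, 0.1],
    [0.6, 0.2, 0.1, 0.1],
    [0.2, 0.2, 0.5, 0.1],
])
target = jnp.array([1, 2, 2])

# One accuracy value per requested k.
top1, top2 = accuracy(pred=pred, target=target, top_k=(1, 2))
```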
cross_entropy ¶

cross_entropy(prob_p: Array, prob_q: Array, axis: int = -1, **kwargs: Kwargs) -> Array

Compute cross-entropy between two probability distributions.

This function calculates the cross-entropy of `prob_p` relative to `prob_q`, summing over the specified axis.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`prob_p` | Array | Array of true probability distributions. | required |
`prob_q` | Array | Array of predicted probability distributions. | required |
`axis` | int | Axis along which to compute the cross-entropy (default: -1). | `-1` |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
Array | Cross-entropy values for each sample. |
Source code in laplax/eval/metrics.py
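A minimal sketch, assuming the conventional definition \(-\sum_i p_i \log q_i\) along the chosen axis:

```python
import jax.numpy as jnp

from laplax.eval.metrics import cross_entropy

prob_p = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # true distributions (one-hot)
prob_q = jnp.array([[0.9, 0.1], [0.2, 0.8]])  # predicted distributions

ce = cross_entropy(prob_p, prob_q)  # roughly [0.105, 0.223]
```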
multiclass_brier ¶

multiclass_brier(prob: Array, target: Array, **kwargs: Kwargs) -> Array
Compute the multiclass Brier score.
The Brier score is a measure of the accuracy of probabilistic predictions. For multiclass classification, it calculates the mean squared difference between the predicted probabilities and the true target.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`prob` | Array | Array of predicted probabilities with shape `(n, num_classes)`. | required |
`target` | Array | Array of ground truth labels, either 1D (class indices) or 2D (one-hot encoded). | required |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
Array | Mean Brier score across all samples. |
Source code in laplax/eval/metrics.py
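A minimal usage sketch with illustrative probabilities:

```python
import jax.numpy as jnp

from laplax.eval.metrics import multiclass_brier

prob = jnp.array([[0.8, 0.1, 0.1], [0.3, 0.4, 0.3]])
target = jnp.array([0, 1])  # class indices; one-hot targets also work

score = multiclass_brier(prob=prob, target=target)  # mean over samples
```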
calculate_bin_metrics ¶
calculate_bin_metrics(confidence: Array, correctness: Array, num_bins: int = 15, **kwargs: Kwargs) -> tuple[Array, Array, Array]
Calculate bin-wise metrics for confidence and correctness.
Computes the proportion of samples, average confidence, and average accuracy within each bin, where the bins are defined by evenly spaced confidence intervals.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`confidence` | Array | Array of predicted confidence values with shape `(n,)`. | required |
`correctness` | Array | Array of correctness labels (0 or 1) with shape `(n,)`. | required |
`num_bins` | int | Number of bins for dividing the confidence range (default: 15). | `15` |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
tuple[Array, Array, Array] | Tuple of arrays of shape `(num_bins,)` containing: the proportion of samples in each bin, the average confidence in each bin, and the average accuracy in each bin. |
Source code in laplax/eval/metrics.py
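A minimal sketch (the confidence and correctness values are illustrative):

```python
import jax.numpy as jnp

from laplax.eval.metrics import calculate_bin_metrics

confidence = jnp.array([0.95, 0.80, 0.55, 0.30, 0.90])
correct = jnp.array([1.0, 1.0, 0.0, 0.0, 1.0])

bin_prop, bin_conf, bin_acc = calculate_bin_metrics(
    confidence=confidence, correctness=correct, num_bins=5
)
# bin_prop[i]: fraction of samples falling into bin i
# bin_conf[i]: average confidence within bin i
# bin_acc[i]:  average correctness within bin i
```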
calibration_error ¶
calibration_error(confidence: Array, correctness: Array, num_bins: int, norm: CalibrationErrorNorm, **kwargs: Kwargs) -> Array
Compute the expected/maximum calibration error.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`confidence` | Array | Float tensor of shape `(n,)` containing predicted confidences. | required |
`correctness` | Array | Float tensor of shape `(n,)` containing the true correctness labels. | required |
`num_bins` | int | Number of equally sized bins. | required |
`norm` | CalibrationErrorNorm | Whether to return ECE (L1 norm) or MCE (inf norm). | required |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
Array | The ECE/MCE. |
Source code in laplax/eval/metrics.py
expected_calibration_error ¶
expected_calibration_error(confidence: Array, correctness: Array, num_bins: int, **kwargs: Kwargs) -> Array
Compute the expected calibration error.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`confidence` | Array | Float tensor of shape `(n,)` containing predicted confidences. | required |
`correctness` | Array | Float tensor of shape `(n,)` containing the true correctness labels. | required |
`num_bins` | int | Number of equally sized bins. | required |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
Array | The expected calibration error (ECE). |
Source code in laplax/eval/metrics.py
maximum_calibration_error ¶
maximum_calibration_error(confidence: Array, correctness: Array, num_bins: int, **kwargs: Kwargs) -> Array
Compute the maximum calibration error.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`confidence` | Array | Float tensor of shape `(n,)` containing predicted confidences. | required |
`correctness` | Array | Float tensor of shape `(n,)` containing the true correctness labels. | required |
`num_bins` | int | Number of equally sized bins. | required |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
Array | The maximum calibration error (MCE). |
Source code in laplax/eval/metrics.py
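A minimal sketch covering both wrappers, reusing the illustrative arrays from above:

```python
import jax.numpy as jnp

from laplax.eval.metrics import (
    expected_calibration_error,
    maximum_calibration_error,
)

confidence = jnp.array([0.95, 0.80, 0.55, 0.30, 0.90])
correct = jnp.array([1.0, 1.0, 0.0, 0.0, 1.0])

ece = expected_calibration_error(confidence, correct, num_bins=10)
mce = maximum_calibration_error(confidence, correct, num_bins=10)
```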
chi_squared ¶

chi_squared(pred_mean: Array, pred_std: Array, target: Array, *, averaged: bool = True, **kwargs: Kwargs) -> Float

Estimate the \(\chi^2\)-value for predictions.

The \(\chi^2\)-value is a measure of the squared error normalized by the predicted variance. Mathematically:

\[
\chi^2 = \sum_{i=1}^{n} \frac{(y_i - \mu_i)^2}{\sigma_i^2},
\]

where \(\mu_i\) and \(\sigma_i\) are the predicted mean and standard deviation for sample \(i\), and \(y_i\) is the target. The result is averaged over samples when `averaged=True` and summed otherwise.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`pred_mean` | Array | Array of predicted means. | required |
`pred_std` | Array | Array of predicted standard deviations. | required |
`target` | Array | Array of ground truth labels. | required |
`averaged` | bool | Whether to return the mean or the sum of the per-sample values. | `True` |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
Float | The estimated \(\chi^2\)-value. |
Source code in laplax/eval/metrics.py
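A minimal sketch with illustrative predictions; values near 1 indicate that the predicted variances match the observed errors:

```python
import jax.numpy as jnp

from laplax.eval.metrics import chi_squared

pred_mean = jnp.array([0.9, 2.1, 3.2])
pred_std = jnp.array([0.2, 0.3, 0.5])
target = jnp.array([1.0, 2.0, 3.0])

# Averaged over samples by default (averaged=True).
chi2 = chi_squared(pred_mean=pred_mean, pred_std=pred_std, target=target)
```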
chi_squared_zero ¶

chi_squared_zero(**predictions: Kwargs) -> Float

Compute a calibration metric for a given set of predictions.

The metric is the ratio between the error of the prediction and the variance of the output uncertainty.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`**predictions` | Kwargs | Keyword arguments representing the model predictions, typically including mean, variance, and target. | `{}` |

Returns:

Type | Description |
---|---|
Float | The calibration metric value. |
Source code in laplax/eval/metrics.py
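A hedged sketch; the keyword names below (`pred_mean`, `pred_std`, `target`) are an assumption based on the other regression metrics in this module, so check the source for the exact keys the function expects:

```python
import jax.numpy as jnp

from laplax.eval.metrics import chi_squared_zero

# NOTE: the keyword names are assumed, not confirmed by this page;
# see laplax/eval/metrics.py for the expected keys.
value = chi_squared_zero(
    pred_mean=jnp.array([0.9, 2.1]),
    pred_std=jnp.array([0.2, 0.3]),
    target=jnp.array([1.0, 2.0]),
)
```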
estimate_rmse ¶

estimate_rmse(pred_mean: Array, target: Array, **kwargs: Kwargs) -> Float

Estimate the root mean squared error (RMSE) for predictions.

Mathematically:

\[
\mathrm{RMSE} = \sqrt{\frac{1}{n} \sum_{i=1}^{n} (\mu_i - y_i)^2},
\]

where \(\mu_i\) is the predicted mean and \(y_i\) the target for sample \(i\).
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`pred_mean` | Array | Array of predicted means. | required |
`target` | Array | Array of ground truth labels. | required |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
Float | The RMSE value. |
Source code in laplax/eval/metrics.py
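A minimal usage sketch:

```python
import jax.numpy as jnp

from laplax.eval.metrics import estimate_rmse

pred_mean = jnp.array([0.9, 2.1, 3.2])
target = jnp.array([1.0, 2.0, 3.0])

rmse = estimate_rmse(pred_mean=pred_mean, target=target)
```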
crps_gaussian ¶
crps_gaussian(pred_mean: Array, pred_std: Array, target: Array, *, scaled: bool = True, **kwargs: Kwargs) -> Float
The negatively oriented continuous ranked probability score for Gaussians.
Negatively oriented means a smaller value is more desirable.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`pred_mean` | Array | 1D array of the predicted means for the held-out dataset. | required |
`pred_std` | Array | 1D array of the predicted standard deviations for the held-out dataset. | required |
`target` | Array | 1D array of the true labels in the held-out dataset. | required |
`scaled` | bool | Whether to scale the score by the size of the held-out set. | `True` |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
Float | The CRPS for the held-out set. |

Raises:

Type | Description |
---|---|
ValueError | If `pred_mean`, `pred_std`, and `target` have incompatible shapes. |
Source code in laplax/eval/metrics.py
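A minimal sketch with illustrative 1D arrays; lower scores are better:

```python
import jax.numpy as jnp

from laplax.eval.metrics import crps_gaussian

pred_mean = jnp.array([0.9, 2.1, 3.2])
pred_std = jnp.array([0.2, 0.3, 0.5])
target = jnp.array([1.0, 2.0, 3.0])

# scaled=True (the default) averages the score over the held-out set.
score = crps_gaussian(pred_mean=pred_mean, pred_std=pred_std, target=target)
```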
nll_gaussian ¶

nll_gaussian(pred_mean: Array, pred_std: Array, target: Array, *, scaled: bool = True, **kwargs: Kwargs) -> Float

Compute the negative log-likelihood (NLL) for a Gaussian distribution.

The NLL quantifies how well the predictive distribution fits the data, assuming a Gaussian distribution characterized by `pred_mean` (mean) and `pred_std` (standard deviation). Mathematically:

\[
\mathrm{NLL} = \frac{1}{2} \sum_{i=1}^{n} \left[ \log(2 \pi \sigma_i^2) + \frac{(y_i - \mu_i)^2}{\sigma_i^2} \right],
\]

where \(\mu_i\) and \(\sigma_i\) are the predicted mean and standard deviation for sample \(i\), and \(y_i\) is the target. The sum is divided by \(n\) when `scaled=True`.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`pred_mean` | Array | Array of predicted means for the dataset. | required |
`pred_std` | Array | Array of predicted standard deviations for the dataset. | required |
`target` | Array | Array of ground truth labels for the dataset. | required |
`scaled` | bool | Whether to normalize the NLL by the number of samples (default: True). | `True` |
`**kwargs` | Kwargs | Additional arguments (ignored). | `{}` |

Returns:

Type | Description |
---|---|
Float | The computed NLL value. |

Raises:

Type | Description |
---|---|
ValueError | If `pred_mean`, `pred_std`, and `target` have incompatible shapes. |
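A minimal sketch mirroring the CRPS example above:

```python
import jax.numpy as jnp

from laplax.eval.metrics import nll_gaussian

pred_mean = jnp.array([0.9, 2.1, 3.2])
pred_std = jnp.array([0.2, 0.3, 0.5])
target = jnp.array([1.0, 2.0, 3.0])

# scaled=True (the default) divides the NLL by the number of samples.
nll = nll_gaussian(pred_mean=pred_mean, pred_std=pred_std, target=target)
```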