laplax.eval.pushforward
Pushforward Functions for Weight Space Uncertainty.
This module provides functions to propagate uncertainty in weight space to
output uncertainty. It includes methods for ensemble-based Monte Carlo
predictions and linearized approximations for uncertainty estimation, as well as a
function to create the posterior_gp_kernel.
set_get_weight_sample
set_get_weight_sample(key: KeyType | None, mean_params: Params, scale_mv: Callable[[Array], Array], num_samples: int, **kwargs: Kwargs) -> Callable[[int], Params]
Creates a function to sample weights from a Gaussian distribution.
This function generates weight samples from a Gaussian distribution characterized by the mean and the scale matrix-vector product function. It supports precomputation of samples for efficiency and assumes a fixed number of required samples.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | KeyType \| None | PRNG key for generating random samples. | required |
mean_params | Params | Mean of the weight-space Gaussian distribution. | required |
scale_mv | Callable[[Array], Array] | Function for the scale matrix-vector product. | required |
num_samples | int | Number of weight samples to generate. | required |
**kwargs | Kwargs | Additional arguments, including: | {} |
Returns:
Type | Description |
---|---|
Callable[[int], Params] | A function that generates a specific weight sample by index. |
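A minimal NumPy sketch of the sampling scheme described above, under assumptions that differ from laplax: parameters are a flat vector rather than a JAX pytree, an integer seed replaces the PRNG key, and the scale S is square so that the weight covariance is S S^T. Each sample is mean + S z with z drawn from a standard normal, and all samples are precomputed because num_samples is fixed. The function name make_get_weight_sample is hypothetical.

```python
import numpy as np


def make_get_weight_sample(seed, mean_params, scale_mv, num_samples):
    """Return a function that maps a sample index to a weight sample.

    Hypothetical NumPy stand-in: each sample is mean + S @ z with z ~ N(0, I),
    and all samples are precomputed because num_samples is fixed.
    """
    rng = np.random.default_rng(seed)
    dim = mean_params.shape[0]
    z = rng.standard_normal((num_samples, dim))  # one standard-normal draw per sample
    samples = np.stack([mean_params + scale_mv(z_i) for z_i in z])

    def get_weight_sample(idx):
        return samples[idx]

    return get_weight_sample


# Usage: zero mean and a diagonal scale S = diag(0.1, 0.2, 0.3).
mean = np.zeros(3)
scale = np.array([0.1, 0.2, 0.3])
get_sample = make_get_weight_sample(0, mean, lambda v: scale * v, num_samples=10_000)
```

Precomputing trades memory (num_samples full weight vectors) for cheap repeated lookups by index.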
Source code in laplax/eval/pushforward.py
get_dist_state
get_dist_state(mean_params: Params, model_fn: ModelFn, posterior_state: PosteriorState, *, linearized: bool = False, num_samples: int = 0, key: KeyType | None = None, **kwargs: Kwargs) -> DistState
Construct the distribution state for uncertainty propagation.
The distribution state contains information needed to propagate uncertainty from the posterior over weights to predictions. It forms the state for both linearized and ensemble-based Monte Carlo approaches.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mean_params | Params | Mean of the posterior (model parameters). | required |
model_fn | ModelFn | The model function to evaluate. | required |
posterior_state | PosteriorState | The posterior distribution state. | required |
linearized | bool | Whether to use a linearized approximation. | False |
num_samples | int | Number of weight samples for Monte Carlo methods. | 0 |
key | KeyType \| None | PRNG key for generating random samples. | None |
**kwargs | Kwargs | Additional arguments, including: | {} |
Returns:
Type | Description |
---|---|
DistState | A dictionary containing functions and parameters for uncertainty propagation. |
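To illustrate what a distribution state bundles, the sketch below builds a plain dictionary for a toy linear model f(x, p) = x @ p, whose Jacobian with respect to p is known in closed form. The linearized branch stores JVP/VJP closures and the Monte Carlo branch stores a weight sampler. The dictionary keys and the function make_dist_state are hypothetical and do not reproduce laplax's DistState layout.

```python
import numpy as np


def make_dist_state(mean_params, weight_cov, *, linearized=False, num_samples=0, seed=None):
    """Hypothetical distribution state for the linear model f(x, p) = x @ p.

    The linearized path needs JVP/VJP closures; the Monte Carlo path needs a
    weight sampler. Both share the posterior covariance.
    """
    dist_state = {"weight_cov": weight_cov}
    if linearized:
        # For f(x, p) = x @ p the Jacobian with respect to p is x itself.
        dist_state["jvp"] = lambda x, v: x @ v
        dist_state["vjp"] = lambda x, u: x.T @ u
    if num_samples > 0:
        rng = np.random.default_rng(seed)
        chol = np.linalg.cholesky(weight_cov)  # weight_cov = chol @ chol.T
        z = rng.standard_normal((num_samples, mean_params.size))
        draws = mean_params + z @ chol.T       # each row ~ N(mean_params, weight_cov)
        dist_state["get_weight_sample"] = lambda i: draws[i]
        dist_state["num_samples"] = num_samples
    return dist_state


lin_state = make_dist_state(np.zeros(2), np.eye(2), linearized=True)
mc_state = make_dist_state(np.zeros(2), np.eye(2), num_samples=3, seed=0)
```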
nonlin_setup
nonlin_setup(results: dict[str, Array], aux: dict[str, Any], input: InputArray, dist_state: DistState, **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Prepare ensemble-based Monte Carlo predictions.
This function generates predictions for multiple weight samples and stores them in the auxiliary dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data, including the model function. | required |
input | InputArray | Input data for prediction. | required |
dist_state | DistState | Distribution state containing weight sampling functions. | required |
**kwargs | Kwargs | Additional arguments, including: | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
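The ensemble step can be sketched as evaluating the full (nonlinear) model once per weight sample and stacking the outputs along a new leading axis. The model, the shapes, and the name nonlin_ensemble are hypothetical; laplax works with JAX pytrees and may batch this loop differently.

```python
import numpy as np


def model_fn(x, params):
    # Hypothetical model: a linear layer followed by tanh.
    return np.tanh(x @ params)


def nonlin_ensemble(x, get_weight_sample, num_samples):
    """Evaluate the model once per weight sample and stack the outputs."""
    return np.stack([model_fn(x, get_weight_sample(i)) for i in range(num_samples)])


rng = np.random.default_rng(0)
weight_samples = rng.normal(size=(8, 3, 2))  # 8 samples of a 3x2 weight matrix
x = rng.normal(size=(5, 3))                  # batch of 5 inputs
pred_ensemble = nonlin_ensemble(x, lambda i: weight_samples[i], num_samples=8)
# pred_ensemble has shape (num_samples, batch, outputs) = (8, 5, 2).
```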
nonlin_pred_mean
nonlin_pred_mean(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Compute the mean of ensemble predictions.
This function calculates the mean of the prediction ensemble generated from multiple weight samples in an ensemble-based Monte Carlo approach.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data containing the prediction ensemble. | required |
**kwargs | Kwargs | Additional arguments (ignored). | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
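A sketch of the mean step, written in the (results, aux) -> (results, aux) shape shared by the functions in this module. The dictionary keys pred_ensemble and pred_mean and the function name ensemble_mean are hypothetical, not necessarily the keys laplax uses.

```python
import numpy as np


def ensemble_mean(results, aux, **kwargs):
    """Hypothetical stand-in following the (results, aux) -> (results, aux) pattern."""
    # Average over the leading axis, which indexes the weight samples.
    results["pred_mean"] = aux["pred_ensemble"].mean(axis=0)
    return results, aux


# Hypothetical ensemble: 4 weight samples, 2 inputs, 3 outputs per input.
aux = {"pred_ensemble": np.arange(24, dtype=float).reshape(4, 2, 3)}
results, aux = ensemble_mean({}, aux)
```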
nonlin_pred_cov
nonlin_pred_cov(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Compute the covariance of ensemble predictions.
This function calculates the empirical covariance of the ensemble of predictions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data containing the prediction ensemble. | required |
**kwargs | Kwargs | Additional arguments (ignored). | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
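For a single input with an ensemble of S predictions of dimension D, the empirical covariance is the average outer product of the centred predictions. This sketch divides by S - 1, matching numpy.cov's default; whether laplax divides by S or S - 1 is not stated here, so treat the normalisation as an assumption.

```python
import numpy as np


def empirical_cov(pred_ensemble):
    """Empirical output covariance for one input; pred_ensemble has shape (S, D)."""
    centered = pred_ensemble - pred_ensemble.mean(axis=0)
    # Unbiased normalisation by S - 1 is an assumption; dividing by S is also common.
    return centered.T @ centered / (pred_ensemble.shape[0] - 1)


preds = np.array([[1.0, 2.0], [3.0, 2.0], [2.0, 5.0]])  # 3 samples, 2 outputs
cov = empirical_cov(preds)
```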
nonlin_pred_var
nonlin_pred_var(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Compute the variance of ensemble predictions.
This function calculates the empirical variance of the ensemble of predictions. If the covariance is already available, it extracts the diagonal.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data containing the prediction ensemble. | required |
**kwargs | Kwargs | Additional arguments (ignored). | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
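The reuse logic described above can be sketched as follows: if a covariance is already in the results, take its diagonal over the last two axes (so batched covariances also work); otherwise compute the sample variance directly. The keys pred_cov, pred_var, and pred_ensemble, the function name ensemble_var, and the ddof=1 normalisation are all assumptions.

```python
import numpy as np


def ensemble_var(results, aux, **kwargs):
    """Hypothetical sketch: reuse a stored covariance, else compute the sample variance."""
    if "pred_cov" in results:
        # Marginal variances are the diagonal of the (possibly batched) covariance.
        results["pred_var"] = np.diagonal(results["pred_cov"], axis1=-2, axis2=-1)
    else:
        # Sample variance over the leading sample axis (ddof=1 is an assumption).
        results["pred_var"] = aux["pred_ensemble"].var(axis=0, ddof=1)
    return results, aux


aux = {"pred_ensemble": np.array([[1.0, 2.0], [3.0, 2.0], [2.0, 5.0]])}
from_samples, _ = ensemble_var({}, aux)
from_cov, _ = ensemble_var({"pred_cov": np.array([[1.0, 0.0], [0.0, 3.0]])}, aux)
```

Both paths agree here because the stored covariance is the empirical covariance of the same ensemble.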
nonlin_pred_std
nonlin_pred_std(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Compute the standard deviation of ensemble predictions.
This function calculates the empirical standard deviation of the ensemble of predictions. If the variance is already available, then it takes the square root.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data containing the prediction ensemble. | required |
**kwargs | Kwargs | Additional arguments (ignored). | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
nonlin_samples
nonlin_samples(results: dict[str, Array], aux: dict[str, Any], num_samples: int = 5, **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Select samples from the ensemble.
This function selects a subset of samples from the ensemble of predictions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data containing the prediction ensemble. | required |
num_samples | int | Number of samples to select. | 5 |
**kwargs | Kwargs | Additional arguments (ignored). | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
nonlin_special_pred_act
nonlin_special_pred_act(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Apply special predictive methods to nonlinear Laplace for classification.
This function applies special predictive methods (Laplace Bridge, Mean Field-0, Mean Field-1, or Mean Field-2) to nonlinear Laplace for classification. These methods transform the predictions into probability space using specific formulations rather than Monte Carlo sampling.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data containing prediction information. | required |
**kwargs | Kwargs | Additional arguments, including: | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
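As one example of such a special predictive, the sketch below implements the basic Laplace bridge. It maps a Gaussian over K logits with means mu and marginal variances var to a Dirichlet with parameters alpha_k = (1 - 2/K + exp(mu_k) / K^2 * sum_l exp(-mu_l)) / var_k, then uses the Dirichlet mean as the class probabilities. How laplax obtains mu and var in the nonlinear case (for example, from the ensemble mean and variance) is an assumption here, and the sketch omits any preprocessing of the Gaussian that an actual implementation may apply.

```python
import numpy as np


def laplace_bridge_probs(mu, var):
    """Basic Laplace bridge: Gaussian over K logits -> mean of the matching Dirichlet.

    mu, var: logit means and marginal variances, shape (..., K).
    """
    K = mu.shape[-1]
    sum_exp_neg = np.sum(np.exp(-mu), axis=-1, keepdims=True)
    alpha = (1.0 - 2.0 / K + np.exp(mu) / K**2 * sum_exp_neg) / var
    # The Dirichlet mean alpha / sum(alpha) serves as the predictive probability.
    return alpha / alpha.sum(axis=-1, keepdims=True)


mu = np.array([2.0, 0.0, -1.0])  # e.g. ensemble mean of the logits (hypothetical)
var = np.array([0.5, 1.0, 2.0])  # e.g. ensemble variance of the logits (hypothetical)
probs = laplace_bridge_probs(mu, var)
```

No sampling is involved: the probabilities follow in closed form from mu and var.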
nonlin_mc_pred_act
nonlin_mc_pred_act(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Compute Monte Carlo predictions for nonlinear Laplace classification.
This function generates Monte Carlo predictions for classification by averaging softmax probabilities across different weight samples. If samples are not already available, it generates them first.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data containing prediction information. | required |
**kwargs | Kwargs | Additional arguments passed to sample generation. | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
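Monte Carlo classification averages probabilities, not logits. The sketch below uses hypothetical logits and NumPy instead of JAX. It contrasts the two approaches on an ensemble whose samples disagree strongly: averaging the per-sample softmax outputs yields a much less confident prediction than applying softmax to the averaged logits.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


# Hypothetical logit ensemble: 2 weight samples, 1 input, 2 classes.
logit_ensemble = np.array([[[6.0, 0.0]], [[-2.0, 0.0]]])

mc_probs = softmax(logit_ensemble).mean(axis=0)      # average the probabilities
plugin_probs = softmax(logit_ensemble.mean(axis=0))  # softmax of the averaged logits
# mc_probs[0, 0] is about 0.56, while plugin_probs[0, 0] is about 0.88.
```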
set_output_mv
set_output_mv(posterior_state: Posterior, input: InputArray, jvp: Callable[[InputArray, Params], PredArray], vjp: Callable[[InputArray, PredArray], Params]) -> dict
Create matrix-vector product functions for output covariance and scale.
This function propagates uncertainty from weight space to output space by constructing matrix-vector product functions for the output covariance and scale matrices. These functions utilize the posterior's covariance and scale operators in conjunction with Jacobian-vector products (JVP) and vector-Jacobian products (VJP).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
posterior_state | Posterior | The posterior state containing covariance and scale operators. | required |
input | InputArray | Input data for the model. | required |
jvp | Callable[[InputArray, Params], PredArray] | Function for computing Jacobian-vector products. | required |
vjp | Callable[[InputArray, PredArray], Params] | Function for computing vector-Jacobian products. | required |
Returns:
Type | Description |
---|---|
dict | A dictionary with the matrix-vector product functions for the output covariance and scale. |
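The composition can be sketched with an explicit Jacobian J at one fixed input, a weight covariance Sigma, and a scale S with S S^T = Sigma. The covariance MVP is u -> J Sigma J^T u (VJP, then Sigma, then JVP), and the scale MVP is z -> J S z. Here jvp and vjp take only the vector because the input is fixed, unlike the two-argument callables in the signature above; everything is a hypothetical NumPy stand-in.

```python
import numpy as np

rng = np.random.default_rng(0)
num_weights, num_outputs = 4, 2
J = rng.normal(size=(num_outputs, num_weights))  # Jacobian at one fixed input (hypothetical)
A = rng.normal(size=(num_weights, num_weights))
weight_cov = A @ A.T + np.eye(num_weights)       # symmetric positive definite Sigma
weight_scale = np.linalg.cholesky(weight_cov)    # S with S @ S.T == Sigma

jvp = lambda v: J @ v    # weight-space tangent -> output space
vjp = lambda u: J.T @ u  # output-space cotangent -> weight space


def output_cov_mv(u):
    """u -> J Sigma J^T u: pull back with the VJP, apply Sigma, push forward with the JVP."""
    return jvp(weight_cov @ vjp(u))


def output_scale_mv(z):
    """z -> J S z: maps a standard-normal weight-space vector to an output-space draw."""
    return jvp(weight_scale @ z)
```

Neither function materialises J Sigma J^T, which is the point of working with matrix-vector products when the number of weights is large.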
lin_setup
lin_setup(results: dict[str, Array], aux: dict[str, Any], input: InputArray, dist_state: DistState, **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Prepare linearized pushforward functions for uncertainty propagation.
This function sets up matrix-vector product functions for the output covariance and scale matrices in a linearized pushforward framework. It verifies the validity of input components (posterior state, JVP, and VJP) and stores the resulting functions in the auxiliary dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data to store matrix-vector product functions. | required |
input | InputArray | Input data for the model. | required |
dist_state | DistState | Distribution state containing posterior state, JVP, and VJP functions. | required |
**kwargs | Kwargs | Additional arguments (ignored). | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
Raises:
Type | Description |
---|---|
TypeError | If posterior_state, vjp, or jvp has an incorrect type. |
lin_pred_mean
lin_pred_mean(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Restore the linearized predictions.
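This function extracts the prediction from the results dictionary and stores it as the mean prediction. Under the linearized approximation, the predictive mean equals the model output at the posterior mean, so no further computation is needed.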
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data (ignored). | required |
**kwargs | Kwargs | Additional arguments (ignored). | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
Note
This function is used for the linearized mean prediction.
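The reason the linearized mean requires no sampling follows from the first-order Taylor expansion of the model around the posterior mean theta* (mean_params), with the posterior over weights taken as N(theta*, Sigma) and J the Jacobian of the model output with respect to the weights:

$$
f_{\mathrm{lin}}(x, \theta) = f(x, \theta^\ast) + J_{\theta^\ast}(x)\,(\theta - \theta^\ast)
$$

$$
\mathbb{E}\big[f_{\mathrm{lin}}(x, \theta)\big] = f(x, \theta^\ast), \qquad
\operatorname{Cov}\big[f_{\mathrm{lin}}(x, \theta)\big] = J_{\theta^\ast}(x)\,\Sigma\,J_{\theta^\ast}(x)^\top
$$

The linear term has zero expectation because E[theta - theta*] = 0, which leaves the prediction at the posterior mean; the covariance on the right is the output covariance that the variance and covariance functions below work with.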
lin_pred_var
lin_pred_var(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Compute and store the variance of the linearized predictions.
This function calculates the variance of predictions by extracting the diagonal of the output covariance matrix.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary containing computed results. | required |
aux | dict[str, Any] | Auxiliary data, including covariance matrix functions. | required |
**kwargs | Kwargs | Additional arguments (ignored). | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
lin_pred_std
lin_pred_std(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Compute and store the standard deviation of the linearized predictions.
This function calculates the standard deviation by taking the square root of the predicted variance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary containing computed results. | required |
aux | dict[str, Any] | Auxiliary data (ignored). | required |
**kwargs | Kwargs | Additional arguments. | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
lin_pred_cov
lin_pred_cov(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Compute and store the covariance of the linearized predictions.
This function computes the full output covariance matrix in dense form using the covariance matrix-vector product function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary containing computed results. | required |
aux | dict[str, Any] | Auxiliary data containing covariance matrix-vector product functions. | required |
**kwargs | Kwargs | Additional arguments (ignored). | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
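Materialising a dense matrix from a matrix-vector product amounts to applying the product to each unit vector, since M e_j is the j-th column of M. This NumPy sketch builds J Sigma J^T that way for a hypothetical Jacobian and diagonal weight covariance. It costs one MVP per output dimension, so it is only practical when the output dimension is small. The helper name dense_from_mv is hypothetical.

```python
import numpy as np


def dense_from_mv(mv, dim):
    """Materialise a dim x dim matrix from its matrix-vector product.

    Column j of M is M @ e_j, so applying mv to every unit vector recovers M.
    """
    return np.stack([mv(e) for e in np.eye(dim)], axis=1)


rng = np.random.default_rng(0)
J = rng.normal(size=(3, 5))  # hypothetical output Jacobian (3 outputs, 5 weights)
weight_cov = np.diag(rng.uniform(0.1, 1.0, size=5))
output_cov_mv = lambda u: J @ (weight_cov @ (J.T @ u))  # u -> J Sigma J^T u

output_cov = dense_from_mv(output_cov_mv, 3)
```

The diagonal of output_cov gives the per-output variances.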
lin_samples
lin_samples(results: dict[str, Array], aux: dict[str, Any], dist_state: DistState, **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Generate and store samples from the linearized distribution.
This function computes samples in the output space by applying the scale matrix to weight samples generated from the posterior distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data containing the scale matrix function. | required |
dist_state | DistState | Distribution state containing sampling functions and sample count. | required |
**kwargs | Kwargs | Additional arguments, including: | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
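Sampling from the linearized predictive can be sketched as drawing standard-normal weight-space vectors z and mapping each through the output scale J L, where Sigma = L L^T, then shifting by the prediction at the posterior mean. The Jacobian, scale, and MAP prediction below are hypothetical. Whether laplax returns shifted samples or centred deviations is not specified here.

```python
import numpy as np

rng = np.random.default_rng(0)
num_weights, num_draws = 3, 20000
J = rng.normal(size=(2, num_weights))                     # output Jacobian at one input (hypothetical)
L = np.tril(rng.normal(size=(num_weights, num_weights))) + 2 * np.eye(num_weights)
pred_map = np.array([1.0, -1.0])                          # model output at the posterior mean

z = rng.standard_normal((num_draws, num_weights))         # standard-normal weight-space draws
# Each output sample is pred_map + J L z; row-wise this is z @ (J L)^T.
output_samples = pred_map + z @ (J @ L).T
```

The samples have covariance J L L^T J^T = J Sigma J^T, matching the linearized output covariance.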
lin_special_pred_act
lin_special_pred_act(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Apply special predictive methods to linearized Laplace for classification.
This function applies special predictive methods (Laplace Bridge, Mean Field-0, Mean Field-1, or Mean Field-2) to linearized Laplace for classification. These methods transform the predictions into probability space using specific formulations rather than Monte Carlo sampling.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data containing prediction information. | required |
**kwargs | Kwargs | Additional arguments, including: | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
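A sketch of the simplest mean-field variant, assuming the probit-style form in which each logit mean is divided by sqrt(1 + lambda * variance) before the softmax, with lambda = pi/8. Both the form and the constant are assumptions about what Mean Field-0 denotes here. The variances would be those of the linearized output distribution; the values below are hypothetical.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def mean_field_0(mu, var, lam=np.pi / 8.0):
    """Probit-style mean-field approximation to E[softmax(f)], f ~ N(mu, diag(var)).

    Each logit is shrunk by its own variance before the softmax; lam = pi/8 is assumed.
    """
    return softmax(mu / np.sqrt(1.0 + lam * var))


mu = np.array([3.0, 0.0])
confident = mean_field_0(mu, np.array([0.1, 0.1]))
uncertain = mean_field_0(mu, np.array([50.0, 50.0]))
# Larger logit variance pulls the prediction towards uniform: uncertain[0] < confident[0].
```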
lin_mc_pred_act
lin_mc_pred_act(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]
Compute Monte Carlo predictions for linearized Laplace classification.
This function generates Monte Carlo predictions for classification by averaging softmax probabilities across different weight samples. If samples are not already available, it generates them first.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
results | dict[str, Array] | Dictionary to store computed results. | required |
aux | dict[str, Any] | Auxiliary data containing prediction information. | required |
**kwargs | Kwargs | Additional arguments passed to sample generation. | {} |
Returns:
Type | Description |
---|---|
tuple[dict[str, Array], dict[str, Any]] | The updated results and aux dictionaries. |
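In the linearized case the logits themselves are Gaussian, so Monte Carlo prediction can be sketched as sampling logits from N(mu, cov) and averaging the softmax outputs. The mean and covariance below are hypothetical stand-ins for the linearized output mean and J Sigma J^T.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
mu = np.array([2.0, 0.0, -1.0])  # linearized logit mean f(x, theta*) (hypothetical)
cov = np.diag([1.0, 0.5, 2.0])   # linearized logit covariance J Sigma J^T (hypothetical)
chol = np.linalg.cholesky(cov)

logit_samples = mu + rng.standard_normal((5000, 3)) @ chol.T  # rows ~ N(mu, cov)
mc_probs = softmax(logit_samples).mean(axis=0)                 # average the probabilities
```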
set_prob_predictive
set_prob_predictive(model_fn: ModelFn, mean_params: Params, dist_state: DistState, pushforward_fns: list[Callable], **kwargs: Kwargs) -> Callable[[InputArray], dict[str, Array]]
Create a probabilistic predictive function.
This function generates a predictive callable that computes uncertainty-aware predictions using a set of pushforward functions. The generated function can evaluate mean predictions and propagate uncertainty from the posterior over weights to output space.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_fn | ModelFn | The model function to evaluate, which takes input and parameters. | required |
mean_params | Params | The mean of the posterior distribution over model parameters. | required |
dist_state | DistState | The distribution state for uncertainty propagation, containing functions and parameters related to the posterior. | required |
pushforward_fns | list[Callable] | A list of pushforward functions, such as mean, variance, and covariance. | required |
**kwargs | Kwargs | Additional arguments passed to the pushforward functions. | {} |
Returns:
Type | Description |
---|---|
Callable[[InputArray], dict[str, Array]] | A function that takes an input array and returns a dictionary of predictions and uncertainty metrics. |
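The composition pattern can be sketched as follows: the returned callable evaluates the model at the mean parameters, then threads (results, aux) through each pushforward function in order and returns results. The names make_prob_predictive, setup_ensemble, and pred_mean, and all dictionary keys, are hypothetical; laplax's initial contents of results and aux may differ.

```python
import numpy as np


def make_prob_predictive(model_fn, mean_params, dist_state, pushforward_fns, **kwargs):
    """Hypothetical sketch: run the mean model, then thread (results, aux) through each fn."""

    def prob_predictive(x):
        results = {"map": model_fn(x, mean_params)}  # prediction at the posterior mean
        aux = {"model_fn": model_fn, "mean_params": mean_params}
        for fn in pushforward_fns:
            results, aux = fn(results, aux, input=x, dist_state=dist_state, **kwargs)
        return results

    return prob_predictive


def setup_ensemble(results, aux, input, dist_state, **kwargs):
    # Evaluate the model once per stored weight sample.
    aux["pred_ensemble"] = np.stack(
        [aux["model_fn"](input, w) for w in dist_state["weight_samples"]]
    )
    return results, aux


def pred_mean(results, aux, **kwargs):
    results["pred_mean"] = aux["pred_ensemble"].mean(axis=0)
    return results, aux


model_fn = lambda x, w: x @ w
dist_state = {"weight_samples": [np.array([1.0, 0.0]), np.array([0.0, 1.0])]}
predict = make_prob_predictive(
    model_fn, np.array([0.5, 0.5]), dist_state, [setup_ensemble, pred_mean]
)
out = predict(np.array([[2.0, 4.0]]))
```

Because every function shares one signature, the list order alone determines which intermediate values (such as the ensemble) are available to later steps.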
set_nonlin_pushforward
set_nonlin_pushforward(model_fn: ModelFn, mean_params: Params, posterior_fn: Callable[[PriorArguments, Int], Posterior], prior_arguments: PriorArguments, *, key: KeyType, loss_scaling_factor: Float = 1.0, pushforward_fns: list = DEFAULT_NONLIN_FINALIZE_FNS, num_samples: int = 100, **kwargs: Kwargs) -> Callable
Construct a Monte Carlo pushforward predictive function.
This function creates a probabilistic predictive callable that computes ensemble-based Monte Carlo (MC) predictions and propagates uncertainty from weight space to output space using sampling.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_fn | ModelFn | The model function to evaluate, which takes input and parameters. | required |
mean_params | Params | The mean of the posterior distribution over model parameters. | required |
posterior_fn | Callable[[PriorArguments, Int], Posterior] | A callable that generates the posterior state from prior arguments. | required |
prior_arguments | PriorArguments | Arguments for defining the prior distribution. | required |
key | KeyType | PRNG key for generating random samples. | required |
loss_scaling_factor | Float | Factor by which the user-provided loss function is scaled. Defaults to 1.0. | 1.0 |
pushforward_fns | list | A list of Monte Carlo pushforward functions. | DEFAULT_NONLIN_FINALIZE_FNS |
num_samples | int | Number of weight samples for Monte Carlo predictions. | 100 |
**kwargs | Kwargs | Additional arguments passed to the pushforward functions. | {} |
Returns:
Type | Description |
---|---|
Callable | A probabilistic predictive function that computes predictions and uncertainty metrics using Monte Carlo sampling. |
set_lin_pushforward
set_lin_pushforward(model_fn: ModelFn, mean_params: Params, posterior_fn: Callable[[PriorArguments, Int], Posterior], prior_arguments: PriorArguments, loss_scaling_factor: Float = 1.0, pushforward_fns: list = DEFAULT_LIN_FINALIZE_FNS, **kwargs: Kwargs) -> Callable
Construct a linearized pushforward predictive function.
This function generates a probabilistic predictive callable that computes predictions and propagates uncertainty using a linearized approximation of the model function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_fn | ModelFn | The model function to evaluate, which takes input and parameters. | required |
mean_params | Params | The mean of the posterior distribution over model parameters. | required |
posterior_fn | Callable[[PriorArguments, Int], Posterior] | A callable that generates the posterior state from prior arguments. | required |
prior_arguments | PriorArguments | Arguments for defining the prior distribution. | required |
loss_scaling_factor | Float | Factor by which the user-provided loss function is scaled. Defaults to 1.0. | 1.0 |
pushforward_fns | list | A list of linearized pushforward functions. | DEFAULT_LIN_FINALIZE_FNS |
**kwargs | Kwargs | Additional arguments passed to the pushforward functions, including: | {} |
Returns:
Type | Description |
---|---|
Callable | A probabilistic predictive function that computes predictions and uncertainty metrics using a linearized approximation. |
set_posterior_gp_kernel
set_posterior_gp_kernel(model_fn: ModelFn, mean: Params, posterior_fn: Callable[[PriorArguments, Int], Posterior], prior_arguments: PriorArguments, loss_scaling_factor: Float = 1.0, **kwargs: Kwargs) -> tuple[Callable, DistState]
Construct a kernel matrix-vector product function for a posterior GP.
This function generates a callable for the kernel matrix-vector product (MVP) in a posterior GP framework. The kernel MVP is constructed using the posterior state and propagates uncertainty in weight space to output space via linearization. The resulting kernel MVP can optionally return a dense matrix representation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_fn | ModelFn | The model function to evaluate, which takes input and parameters. | required |
mean | Params | The mean of the posterior distribution over model parameters. | required |
posterior_fn | Callable[[PriorArguments, Int], Posterior] | A callable that generates the posterior state from prior arguments. | required |
prior_arguments | PriorArguments | Arguments for defining the prior distribution. | required |
loss_scaling_factor | Float | Factor by which the user-provided loss function is scaled. Defaults to 1.0. | 1.0 |
**kwargs | Kwargs | Additional arguments, including: | {} |
Returns:
Type | Description |
---|---|
tuple[Callable, DistState] | A kernel MVP callable or a dense kernel matrix function, and the distribution state containing posterior information. |
Raises:
Type | Description |
---|---|
ValueError | If |
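Under linearization, the posterior GP kernel between inputs x and x' is k(x, x') = J(x) Sigma J(x')^T, where J is the Jacobian of the model output with respect to the weights at the posterior mean and Sigma is the weight covariance. The sketch uses a hypothetical model that is linear in its parameters (quadratic features of a scalar input), so its Jacobian is exactly the feature matrix. It exposes the kernel both as an MVP and, by probing unit vectors, as a dense matrix. The names features and make_kernel_mv are hypothetical.

```python
import numpy as np


def features(x):
    # Hypothetical feature map; for f(x, theta) = features(x) @ theta the
    # Jacobian with respect to theta is exactly features(x).
    return np.stack([np.ones_like(x), x, x**2], axis=-1)


def make_kernel_mv(weight_cov, x1, x2):
    J1, J2 = features(x1), features(x2)       # shapes (N1, P), (N2, P)

    def kernel_mv(v):                         # v has shape (N2,)
        return J1 @ (weight_cov @ (J2.T @ v))  # K v with K = J1 Sigma J2^T

    return kernel_mv


weight_cov = np.diag([1.0, 0.5, 0.1])
x = np.linspace(-1.0, 1.0, 4)
kernel_mv = make_kernel_mv(weight_cov, x, x)
# Dense kernel matrix: apply the MVP to each unit vector to recover its columns.
K_dense = np.stack([kernel_mv(e) for e in np.eye(x.size)], axis=1)
```

The MVP form never builds the N1 x N2 kernel matrix, which matters when many inputs are evaluated at once.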