Skip to content

laplax.eval.pushforward

Pushforward Functions for Weight Space Uncertainty.

This module provides functions to propagate uncertainty in weight space to output uncertainty. It includes methods for ensemble-based Monte Carlo predictions and linearized approximations for uncertainty estimation, as well as to create the posterior_gp_kernel.

set_get_weight_sample

set_get_weight_sample(key: KeyType | None, mean_params: Params, scale_mv: Callable[[Array], Array], num_samples: int, **kwargs: Kwargs) -> Callable[[int], Params]

Creates a function to sample weights from a Gaussian distribution.

This function generates weight samples from a Gaussian distribution characterized by the mean and the scale matrix-vector product function. It supports precomputation of samples for efficiency and assumes a fixed number of required samples.

Parameters:

Name Type Description Default
key KeyType | None

PRNG key for generating random samples.

required
mean_params Params

Mean of the weight-space Gaussian distribution.

required
scale_mv Callable[[Array], Array]

Function for the scale matrix-vector product.

required
num_samples int

Number of weight samples to generate.

required
**kwargs Kwargs

Additional arguments, including:

  • set_get_weight_sample_precompute: Controls whether samples are precomputed.
{}

Returns:

Type Description
Callable[[int], Params]

A function that generates a specific weight sample by index.

Source code in laplax/eval/pushforward.py
def set_get_weight_sample(
    key: KeyType | None,
    mean_params: Params,
    scale_mv: Callable[[Array], Array],
    num_samples: int,
    **kwargs: Kwargs,
) -> Callable[[int], Params]:
    """Creates a function to sample weights from a Gaussian distribution.

    This function generates weight samples from a Gaussian distribution
    characterized by the mean and the scale matrix-vector product function.
    It supports precomputation of samples for efficiency and assumes a fixed
    number of required samples.

    Args:
        key: PRNG key for generating random samples.
        mean_params: Mean of the weight-space Gaussian distribution.
        scale_mv: Function for the scale matrix-vector product.
        num_samples: Number of weight samples to generate.
        **kwargs: Additional arguments, including:

            - `set_get_weight_sample_precompute`: Controls whether samples are
              precomputed.

    Returns:
        A function that generates a specific weight sample by index.
    """
    if key is None:
        key = jax.random.key(0)

    keys = jax.random.split(key, num_samples)

    def get_weight_sample(idx):
        return util.tree.normal_like(keys[idx], mean=mean_params, scale_mv=scale_mv)

    return precompute_list(
        get_weight_sample,
        jnp.arange(num_samples),
        precompute=kwargs.get(
            "set_get_weight_sample_precompute", kwargs.get("precompute")
        ),
        **kwargs,
    )

get_dist_state

get_dist_state(mean_params: Params, model_fn: ModelFn, posterior_state: PosteriorState, *, linearized: bool = False, num_samples: int = 0, key: KeyType | None = None, **kwargs: Kwargs) -> DistState

Construct the distribution state for uncertainty propagation.

The distribution state contains information needed to propagate uncertainty from the posterior over weights to predictions. It forms the state for both linearized and ensemble-based Monte Carlo approaches.

Parameters:

Name Type Description Default
mean_params Params

Mean of the posterior (model parameters).

required
model_fn ModelFn

The model function to evaluate.

required
posterior_state PosteriorState

The posterior distribution state.

required
linearized bool

Whether to consider a linearized approximation.

False
num_samples int

Number of weight samples for Monte Carlo methods.

0
key KeyType | None

PRNG key for generating random samples.

None
**kwargs Kwargs

Additional arguments, including:

  • set_get_weight_sample_precompute.
{}

Returns:

Type Description
DistState

A dictionary containing functions and parameters for uncertainty propagation.

Source code in laplax/eval/pushforward.py
def get_dist_state(
    mean_params: Params,
    model_fn: ModelFn,
    posterior_state: PosteriorState,
    *,
    linearized: bool = False,
    num_samples: int = 0,
    key: KeyType | None = None,
    **kwargs: Kwargs,
) -> DistState:
    """Construct the distribution state for uncertainty propagation.

    The distribution state contains information needed to propagate uncertainty
    from the posterior over weights to predictions. It forms the state for both
    linearized and ensemble-based Monte Carlo approaches.

    Args:
        mean_params: Mean of the posterior (model parameters).
        model_fn: The model function to evaluate.
        posterior_state: The posterior distribution state.
        linearized: Whether to consider a linearized approximation.
        num_samples: Number of weight samples for Monte Carlo methods.
        key: PRNG key for generating random samples.
        **kwargs: Additional arguments, including:

            - `set_get_weight_sample_precompute`.

    Returns:
        A dictionary containing functions and parameters for uncertainty propagation.
    """
    dist_state = {
        "posterior_state": posterior_state,
        "num_samples": num_samples,
    }

    if linearized:
        # Create pushforward functions
        def pf_jvp(input: InputArray, vector: Params) -> PredArray:
            return jax.jvp(
                lambda p: model_fn(input=input, params=p),
                (mean_params,),
                (vector,),
            )[1]

        def pf_vjp(input: InputArray, vector: PredArray) -> Params:
            out, vjp_fun = jax.vjp(
                lambda p: model_fn(input=input, params=p), mean_params
            )
            return vjp_fun(vector.reshape(out.shape))

        dist_state["vjp"] = pf_vjp
        dist_state["jvp"] = pf_jvp

    if num_samples > 0:
        weight_sample_mean = (
            util.tree.zeros_like(mean_params) if linearized else mean_params
        )

        # Create weight sample function
        get_weight_samples = set_get_weight_sample(
            key,
            mean_params=weight_sample_mean,
            scale_mv=posterior_state.scale_mv(posterior_state.state),
            num_samples=num_samples,
            **kwargs,
        )
        dist_state["get_weight_samples"] = get_weight_samples

    return dist_state

nonlin_setup

nonlin_setup(results: dict[str, Array], aux: dict[str, Any], input: InputArray, dist_state: DistState, **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Prepare ensemble-based Monte Carlo predictions.

This function generates predictions for multiple weight samples and stores them in the auxiliary dictionary.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data, including the model function.

required
input InputArray

Input data for prediction.

required
dist_state DistState

Distribution state containing weight sampling functions.

required
**kwargs Kwargs

Additional arguments, including:

  • nonlin_setup_batch_size: Controls batch size for computing predictions.
{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def nonlin_setup(
    results: dict[str, Array],
    aux: dict[str, Any],
    input: InputArray,
    dist_state: DistState,
    **kwargs: Kwargs,
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Prepare ensemble-based Monte Carlo predictions.

    This function generates predictions for multiple weight samples and stores
    them in the auxiliary dictionary.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data, including the model function.
        input: Input data for prediction.
        dist_state: Distribution state containing weight sampling functions.
        **kwargs: Additional arguments, including:

            - `nonlin_setup_batch_size`: Controls batch size for computing predictions.

    Returns:
        Updated `results` and `aux`.
    """

    def compute_pred_ptw(idx: int) -> PredArray:
        weight_sample = dist_state["get_weight_samples"](idx)
        return aux["model_fn"](input=input, params=weight_sample)

    aux["pred_ensemble"] = jax.lax.map(
        compute_pred_ptw,
        jnp.arange(dist_state["num_samples"]),
        batch_size=kwargs.get(
            "nonlin_setup_batch_size", kwargs.get("weight_batch_size")
        ),
    )

    return results, aux

nonlin_pred_mean

nonlin_pred_mean(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Compute the mean of ensemble predictions.

This function calculates the mean of prediction ensemble generated from multiple weight samples in an ensemble-based Monte Carlo approach.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data containing the prediction ensemble.

required
**kwargs Kwargs

Additional arguments (ignored).

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def nonlin_pred_mean(
    results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Compute the mean of ensemble predictions.

    This function calculates the mean of prediction ensemble generated from
    multiple weight samples in an ensemble-based Monte Carlo approach.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data containing the prediction ensemble.
        **kwargs: Additional arguments (ignored).

    Returns:
        Updated `results` and `aux`.
    """
    del kwargs

    pred_ensemble = aux["pred_ensemble"]
    results["pred_mean"] = util.tree.mean(pred_ensemble, axis=0)
    return results, aux

nonlin_pred_cov

nonlin_pred_cov(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Compute the covariance of ensemble predictions.

This function calculates the empirical covariance of the ensemble of predictions.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data containing the prediction ensemble.

required
**kwargs Kwargs

Additional arguments (ignored).

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def nonlin_pred_cov(
    results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Compute the covariance of ensemble predictions.

    This function calculates the empirical covariance of the ensemble of predictions.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data containing the prediction ensemble.
        **kwargs: Additional arguments (ignored).

    Returns:
        Updated `results` and `aux`.
    """
    del kwargs

    pred_ensemble = aux["pred_ensemble"]

    results["pred_cov"] = util.tree.cov(
        pred_ensemble.reshape(pred_ensemble.shape[0], -1), rowvar=False
    )
    return results, aux

nonlin_pred_var

nonlin_pred_var(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Compute the variance of ensemble predictions.

This function calculates the empirical variance of the ensemble of predictions. If the covariance is already available, it extracts the diagonal.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data containing the prediction ensemble.

required
**kwargs Kwargs

Additional arguments (ignored).

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def nonlin_pred_var(
    results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Compute the variance of ensemble predictions.

    This function calculates the empirical variance of the ensemble of predictions.
    If the covariance is already available, it extracts the diagonal.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data containing the prediction ensemble.
        **kwargs: Additional arguments (ignored).

    Returns:
        Updated `results` and `aux`.
    """
    del kwargs

    if "pred_cov" in results:
        pred_cov = results["pred_cov"]
        if pred_cov.ndim > 0:
            pred_cov = jnp.diagonal(pred_cov)
        results["pred_var"] = pred_cov
    else:
        pred_ensemble = aux["pred_ensemble"]
        results["pred_var"] = util.tree.var(pred_ensemble, axis=0)
    return results, aux

nonlin_pred_std

nonlin_pred_std(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Compute the standard deviation of ensemble predictions.

This function calculates the empirical standard deviation of the ensemble of predictions. If the variance is already available, then it takes the square root.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data containing the prediction ensemble.

required
**kwargs Kwargs

Additional arguments (ignored).

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def nonlin_pred_std(
    results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Compute the standard deviation of ensemble predictions.

    This function calculates the empirical standard deviation of the ensemble of
    predictions. If the variance is already available, then it takes the square root.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data containing the prediction ensemble.
        **kwargs: Additional arguments (ignored).

    Returns:
        Updated `results` and `aux`.
    """
    del kwargs

    if "pred_var" in results:
        results["pred_std"] = jnp.sqrt(results["pred_var"])
    else:
        pred_ensemble = aux["pred_ensemble"]
        results["pred_std"] = util.tree.std(pred_ensemble, axis=0)
    return results, aux

nonlin_samples

nonlin_samples(results: dict[str, Array], aux: dict[str, Any], num_samples: int = 5, **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Select samples from ensemble.

This function selects a subset of samples from the ensemble of predictions.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data containing the prediction ensemble.

required
num_samples int

Number of samples to select.

5
**kwargs Kwargs

Additional arguments (ignored).

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def nonlin_samples(
    results: dict[str, Array],
    aux: dict[str, Any],
    num_samples: int = 5,
    **kwargs: Kwargs,
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Select samples from ensemble.

    This function selects a subset of samples from the ensemble of predictions.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data containing the prediction ensemble.
        num_samples: Number of samples to select.
        **kwargs: Additional arguments (ignored).

    Returns:
        Updated `results` and `aux`.
    """
    del kwargs

    pred_ensemble = aux["pred_ensemble"]
    results["samples"] = util.tree.tree_slice(pred_ensemble, 0, num_samples)
    return results, aux

nonlin_special_pred_act

nonlin_special_pred_act(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Apply special predictive methods to nonlinear Laplace for classification.

This function applies special predictive methods (Laplace Bridge, Mean Field-0, Mean Field-1, or Mean Field-2) to nonlinear Laplace for classification. These methods transform the predictions into probability space using specific formulations rather than Monte Carlo sampling.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data containing prediction information.

required
**kwargs Kwargs

Additional arguments, including:

  • special_pred_type: Type of special prediction ("laplace_bridge", "mean_field_0", "mean_field_1", or "mean_field_2")
  • use_correction: Whether to apply correction term for applicable methods.
{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def nonlin_special_pred_act(
    results: dict[str, Array],
    aux: dict[str, Any],
    **kwargs: Kwargs,
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Apply special predictive methods to nonlinear Laplace for classification.

    This function applies special predictive methods (Laplace Bridge, Mean Field-0,
    Mean Field-1, or Mean Field-2) to nonlinear Laplace for classification. These
    methods transform the predictions into probability space using specific formulations
    rather than Monte Carlo sampling.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data containing prediction information.
        **kwargs: Additional arguments, including:

            - `special_pred_type`: Type of special prediction ("laplace_bridge",
              "mean_field_0", "mean_field_1", or "mean_field_2")
            - `use_correction`: Whether to apply correction term for applicable methods.

    Returns:
        Updated `results` and `aux`.
    """
    return special_pred_act(results, aux, linearized=False, **kwargs)

nonlin_mc_pred_act

nonlin_mc_pred_act(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Compute Monte Carlo predictions for nonlinear Laplace classification.

This function generates Monte Carlo predictions for classification by averaging softmax probabilities across different weight samples. If samples are not already available, it generates them first.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data containing prediction information.

required
**kwargs Kwargs

Additional arguments passed to sample generation.

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def nonlin_mc_pred_act(
    results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Compute Monte Carlo predictions for nonlinear Laplace classification.

    This function generates Monte Carlo predictions for classification by averaging
    softmax probabilities across different weight samples. If samples are not already
    available, it generates them first.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data containing prediction information.
        **kwargs: Additional arguments passed to sample generation.

    Returns:
        Updated `results` and `aux`.
    """
    if "samples" not in results:
        results, aux = nonlin_samples(results=results, aux=aux, **kwargs)

    results["mc_pred_act"] = jnp.mean(
        jax.nn.softmax(results["samples"], axis=1), axis=0
    )

    return results, aux

set_output_mv

set_output_mv(posterior_state: Posterior, input: InputArray, jvp: Callable[[InputArray, Params], PredArray], vjp: Callable[[InputArray, PredArray], Params]) -> dict

Create matrix-vector product functions for output covariance and scale.

This function propagates uncertainty from weight space to output space by constructing matrix-vector product functions for the output covariance and scale matrices. These functions utilize the posterior's covariance and scale operators in conjunction with Jacobian-vector products (JVP) and vector-Jacobian products (VJP).

Parameters:

Name Type Description Default
posterior_state Posterior

The posterior state containing covariance and scale operators.

required
input InputArray

Input data for the model.

required
jvp Callable[[InputArray, Params], PredArray]

Function for computing Jacobian-vector products.

required
vjp Callable[[InputArray, PredArray], Params]

Function for computing vector-Jacobian products.

required

Returns:

Type Description
dict

A dictionary with:

  • cov_mv: Function for the output covariance matrix-vector product.
  • jac_mv: Function for the JVP with a fixed input.
Source code in laplax/eval/pushforward.py
def set_output_mv(
    posterior_state: Posterior,
    input: InputArray,
    jvp: Callable[[InputArray, Params], PredArray],
    vjp: Callable[[InputArray, PredArray], Params],
) -> dict:
    """Create matrix-vector product functions for output covariance and scale.

    This function propagates uncertainty from weight space to output space by
    constructing matrix-vector product functions for the output covariance and
    scale matrices. These functions utilize the posterior's covariance and scale
    operators in conjunction with Jacobian-vector products (JVP) and
    vector-Jacobian products (VJP).

    Args:
        posterior_state: The posterior state containing covariance and scale operators.
        input: Input data for the model.
        jvp: Function for computing Jacobian-vector products.
        vjp: Function for computing vector-Jacobian products.

    Returns:
        A dictionary with:

            - `cov_mv`: Function for the output covariance matrix-vector product.
            - `jac_mv`: Function for the JVP with a fixed input.
    """
    cov_mv = posterior_state.cov_mv(posterior_state.state)

    def output_cov_mv(vec: PredArray) -> PredArray:
        return jvp(input, cov_mv(vjp(input, vec)[0]))

    def output_jac_mv(vec: PredArray) -> PredArray:
        return jvp(input, vec)

    return {"cov_mv": output_cov_mv, "jac_mv": output_jac_mv}

lin_setup

lin_setup(results: dict[str, Array], aux: dict[str, Any], input: InputArray, dist_state: DistState, **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Prepare linearized pushforward functions for uncertainty propagation.

This function sets up matrix-vector product functions for the output covariance and scale matrices in a linearized pushforward framework. It verifies the validity of input components (posterior state, JVP, and VJP) and stores the resulting functions in the auxiliary dictionary.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data to store matrix-vector product functions.

required
input InputArray

Input data for the model.

required
dist_state DistState

Distribution state containing posterior state, JVP, and VJP functions.

required
**kwargs Kwargs

Additional arguments (ignored).

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Raises:

Type Description
TypeError

When the posterior_state, vjp, or jvp has an incorrect type.

Source code in laplax/eval/pushforward.py
def lin_setup(
    results: dict[str, Array],
    aux: dict[str, Any],
    input: InputArray,
    dist_state: DistState,
    **kwargs: Kwargs,
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Prepare linearized pushforward functions for uncertainty propagation.

    This function sets up matrix-vector product functions for the output covariance
    and scale matrices in a linearized pushforward framework. It verifies the
    validity of input components (posterior state, JVP, and VJP) and stores the
    resulting functions in the auxiliary dictionary.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data to store matrix-vector product functions.
        input: Input data for the model.
        dist_state: Distribution state containing posterior state, JVP, and VJP
            functions.
        **kwargs: Additional arguments (ignored).

    Returns:
        Updated `results` and `aux`.

    Raises:
        TypeError: When the posterior_state, vjp, or jvp has an incorrect type.
    """
    del kwargs

    jvp = dist_state["jvp"]
    vjp = dist_state["vjp"]
    posterior_state = dist_state["posterior_state"]

    # Check types (mainly needed for type checker)
    if not isinstance(posterior_state, Posterior):
        msg = "posterior state is not a Posterior type"
        raise TypeError(msg)

    if not isinstance(jvp, Callable):
        msg = "JVP is not a JVPType"
        raise TypeError(msg)

    if not isinstance(vjp, Callable):
        msg = "VJP is not a VJPType"
        raise TypeError(msg)

    mv = set_output_mv(posterior_state, input, jvp, vjp)
    aux["cov_mv"] = mv["cov_mv"]
    aux["jac_mv"] = mv["jac_mv"]

    return results, aux

lin_pred_mean

lin_pred_mean(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Restore the linearized predictions.

This function extracts the prediction from the results dictionary and stores it.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data (ignored).

required
**kwargs Kwargs

Additional arguments (ignored).

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Note

This function is used for the linearized mean prediction.

Source code in laplax/eval/pushforward.py
def lin_pred_mean(
    results: dict[str, Array],
    aux: dict[str, Any],
    **kwargs: Kwargs,
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Restore the linearized predictions.

    This function extracts the prediction from the results dictionary and
    stores it.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data (ignored).
        **kwargs: Additional arguments (ignored).

    Returns:
        Updated `results` and `aux`.

    Note:
        This function is used for the linearized mean prediction.
    """
    del kwargs

    results["pred_mean"] = results["map"]
    return results, aux

lin_pred_var

lin_pred_var(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Compute and store the variance of the linearized predictions.

This function calculates the variance of predictions by extracting the diagonal of the output covariance matrix.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary containing computed results.

required
aux dict[str, Any]

Auxiliary data, including covariance matrix functions.

required
**kwargs Kwargs

Additional arguments (ignored).

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def lin_pred_var(
    results: dict[str, Array],
    aux: dict[str, Any],
    **kwargs: Kwargs,
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Compute and store the variance of the linearized predictions.

    This function calculates the variance of predictions by extracting the diagonal
    of the output covariance matrix.

    Args:
        results: Dictionary containing computed results.
        aux: Auxiliary data, including covariance matrix functions.
        **kwargs: Additional arguments (ignored).

    Returns:
        Updated `results` and `aux`.
    """
    cov = results.get("pred_cov", aux["cov_mv"])

    if "pred_mean" not in results:
        results, aux = lin_pred_mean(results, aux, **kwargs)

    pred_mean = results["pred_mean"]

    # Compute diagonal as variance
    results["pred_var"] = util.mv.diagonal(cov, layout=math.prod(pred_mean.shape))
    return results, aux

lin_pred_std

lin_pred_std(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Compute and store the standard deviation of the linearized predictions.

This function calculates the standard deviation by taking the square root of the predicted variance.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary containing computed results.

required
aux dict[str, Any]

Auxiliary data (ignored).

required
**kwargs Kwargs

Additional arguments.

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def lin_pred_std(
    results: dict[str, Array],
    aux: dict[str, Any],
    **kwargs: Kwargs,
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Compute and store the standard deviation of the linearized predictions.

    This function calculates the standard deviation by taking the square root
    of the predicted variance.

    Args:
        results: Dictionary containing computed results.
        aux: Auxiliary data (ignored).
        **kwargs: Additional arguments.

    Returns:
        Updated `results` and `aux`.
    """
    if "pred_var" not in results:  # Fall back to `lin_pred_var`
        results, aux = lin_pred_var(results, aux, **kwargs)

    var = results["pred_var"]
    results["pred_std"] = util.tree.sqrt(var)
    return results, aux

lin_pred_cov

lin_pred_cov(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Compute and store the covariance of the linearized predictions.

This function computes the full output covariance matrix in dense form using the covariance matrix-vector product function.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary containing computed results.

required
aux dict[str, Any]

Auxiliary data containing covariance matrix-vector product functions.

required
**kwargs Kwargs

Additional arguments (ignored).

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def lin_pred_cov(
    results: dict[str, Array],
    aux: dict[str, Any],
    **kwargs: Kwargs,
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Compute and store the covariance of the linearized predictions.

    This function computes the full output covariance matrix in dense form
    using the covariance matrix-vector product function.

    Args:
        results: Dictionary containing computed results.
        aux: Auxiliary data containing covariance matrix-vector product functions.
        **kwargs: Additional arguments (ignored).

    Returns:
        Updated `results` and `aux`.
    """
    if "pred_mean" not in results:
        results, aux = lin_pred_mean(results, aux, **kwargs)

    pred_mean = results["pred_mean"]
    cov_mv = aux["cov_mv"]

    results["pred_cov"] = util.mv.to_dense(cov_mv, layout=pred_mean)
    return results, aux

lin_samples

lin_samples(results: dict[str, Array], aux: dict[str, Any], dist_state: DistState, **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Generate and store samples from the linearized distribution.

This function computes samples in the output space by applying the scale matrix to weight samples generated from the posterior distribution.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data containing the scale matrix function.

required
dist_state DistState

Distribution state containing sampling functions and sample count.

required
**kwargs Kwargs

Additional arguments, including:

  • lin_samples_batch_size: Batch size for computing samples.
{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def lin_samples(
    results: dict[str, Array],
    aux: dict[str, Any],
    dist_state: DistState,
    **kwargs: Kwargs,
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Generate and store samples from the linearized distribution.

    This function computes samples in the output space by applying the scale
    matrix to weight samples generated from the posterior distribution.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data containing the scale matrix function.
        dist_state: Distribution state containing sampling functions and sample count.
        **kwargs: Additional arguments, including:

            - `lin_samples_batch_size`: Batch size for computing samples.

    Returns:
        Updated `results` and `aux`.
    """
    if "pred_mean" not in results:
        results, aux = lin_pred_mean(results, aux, **kwargs)

    # Unpack arguments
    jac_mv = aux["jac_mv"]
    get_weight_samples = dist_state["get_weight_samples"]
    num_samples = dist_state["num_samples"]

    # Compute samples
    results["samples"] = jax.lax.map(
        lambda i: add(results["pred_mean"], jac_mv(get_weight_samples(i))),
        jnp.arange(num_samples),
        batch_size=kwargs.get(
            "lin_samples_batch_size", kwargs.get("weight_batch_size")
        ),
    )
    return results, aux

lin_special_pred_act

lin_special_pred_act(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Apply special predictive methods to linearized Laplace for classification.

This function applies special predictive methods (Laplace Bridge, Mean Field-0, Mean Field-1, or Mean Field-2) to linearized Laplace for classification. These methods transform the predictions into probability space using specific formulations rather than Monte Carlo sampling.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data containing prediction information.

required
**kwargs Kwargs

Additional arguments, including:

  • special_pred_type: Type of special prediction ("laplace_bridge", "mean_field_0", "mean_field_1", or "mean_field_2")
  • use_correction: Whether to apply correction term for applicable methods.
{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def lin_special_pred_act(
    results: dict[str, Array],
    aux: dict[str, Any],
    **kwargs: Kwargs,
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Apply special predictive methods to linearized Laplace for classification.

    This function applies special predictive methods (Laplace Bridge, Mean Field-0,
    Mean Field-1, or Mean Field-2) to linearized Laplace for classification. These
    methods transform the predictions into probability space using specific formulations
    rather than Monte Carlo sampling.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data containing prediction information.
        **kwargs: Additional arguments, including:

            - `special_pred_type`: Type of special prediction ("laplace_bridge",
                "mean_field_0", "mean_field_1", or "mean_field_2")
            - `use_correction`: Whether to apply correction term for applicable methods.

    Returns:
        Updated `results` and `aux`.
    """
    return special_pred_act(results, aux, linearized=True, **kwargs)

lin_mc_pred_act

lin_mc_pred_act(results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs) -> tuple[dict[str, Array], dict[str, Any]]

Compute Monte Carlo predictions for linear Laplace classification.

This function generates Monte Carlo predictions for classification by averaging softmax probabilities across different weight samples. If samples are not already available, it generates them first.

Parameters:

Name Type Description Default
results dict[str, Array]

Dictionary to store computed results.

required
aux dict[str, Any]

Auxiliary data containing prediction information.

required
**kwargs Kwargs

Additional arguments passed to sample generation.

{}

Returns:

Type Description
tuple[dict[str, Array], dict[str, Any]]

Updated results and aux.

Source code in laplax/eval/pushforward.py
def lin_mc_pred_act(
    results: dict[str, Array], aux: dict[str, Any], **kwargs: Kwargs
) -> tuple[dict[str, Array], dict[str, Any]]:
    """Compute Monte Carlo predictions for linear Laplace classification.

    This function generates Monte Carlo predictions for classification by averaging
    softmax probabilities across different weight samples. If samples are not already
    available, it generates them first.

    Args:
        results: Dictionary to store computed results.
        aux: Auxiliary data containing prediction information.
        **kwargs: Additional arguments passed to sample generation.

    Returns:
        Updated `results` and `aux`.
    """
    if "samples" not in results:
        results, aux = lin_samples(results=results, aux=aux, **kwargs)

    results["mc_pred_act"] = jnp.mean(
        jax.nn.softmax(results["samples"], axis=1), axis=0
    )

    return results, aux

set_prob_predictive

set_prob_predictive(model_fn: ModelFn, mean_params: Params, dist_state: DistState, pushforward_fns: list[Callable], **kwargs: Kwargs) -> Callable[[InputArray], dict[str, Array]]

Create a probabilistic predictive function.

This function generates a predictive callable that computes uncertainty-aware predictions using a set of pushforward functions. The generated function can evaluate mean predictions and propagate uncertainty from the posterior over weights to output space.

Parameters:

Name Type Description Default
model_fn ModelFn

The model function to evaluate, which takes input and parameters.

required
mean_params Params

The mean of the posterior distribution over model parameters.

required
dist_state DistState

The distribution state for uncertainty propagation, containing functions and parameters related to the posterior.

required
pushforward_fns list[Callable]

A list of pushforward functions, such as mean, variance, and covariance.

required
**kwargs Kwargs

Additional arguments passed to the pushforward functions.

{}

Returns:

Type Description
Callable[[InputArray], dict[str, Array]]

A function that takes an input array and returns a dictionary of predictions and uncertainty metrics.

Source code in laplax/eval/pushforward.py
def set_prob_predictive(
    model_fn: ModelFn,
    mean_params: Params,
    dist_state: DistState,
    pushforward_fns: list[Callable],
    **kwargs: Kwargs,
) -> Callable[[InputArray], dict[str, Array]]:
    """Create a probabilistic predictive function.

    This function generates a predictive callable that computes uncertainty-aware
    predictions using a set of pushforward functions. The generated function can
    evaluate mean predictions and propagate uncertainty from the posterior over
    weights to output space.

    Args:
        model_fn: The model function to evaluate, which takes input and parameters.
        mean_params: The mean of the posterior distribution over model parameters.
        dist_state: The distribution state for uncertainty propagation, containing
            functions and parameters related to the posterior.
        pushforward_fns: A list of pushforward functions, such as mean, variance, and
            covariance.
        **kwargs: Additional arguments passed to the pushforward functions.

    Returns:
        A function that takes an input array and returns a dictionary
            of predictions and uncertainty metrics.
    """

    def prob_predictive(input: InputArray) -> dict[str, Array]:
        # MAP prediction
        pred_map = model_fn(input=input, params=mean_params)
        aux = {"model_fn": model_fn, "mean_params": mean_params}
        results = {"map": pred_map}

        # Compute prediction
        return finalize_fns(
            fns=pushforward_fns,
            results=results,
            dist_state=dist_state,
            aux=aux,
            input=input,
            **kwargs,
        )

    return prob_predictive

set_nonlin_pushforward

set_nonlin_pushforward(model_fn: ModelFn, mean_params: Params, posterior_fn: Callable[[PriorArguments, Int], Posterior], prior_arguments: PriorArguments, *, key: KeyType, loss_scaling_factor: Float = 1.0, pushforward_fns: list = DEFAULT_NONLIN_FINALIZE_FNS, num_samples: int = 100, **kwargs: Kwargs) -> Callable

Construct a Monte Carlo pushforward predictive function.

This function creates a probabilistic predictive callable that computes ensemble-based Monte Carlo (MC) predictions and propagates uncertainty from weight space to output space using sampling.

Parameters:

Name Type Description Default
model_fn ModelFn

The model function to evaluate, which takes input and parameters.

required
mean_params Params

The mean of the posterior distribution over model parameters.

required
posterior_fn Callable[[PriorArguments, Int], Posterior]

A callable that generates the posterior state from prior arguments.

required
prior_arguments PriorArguments

Arguments for defining the prior distribution.

required
key KeyType

PRNG key for generating random samples.

required
loss_scaling_factor Float

Factor by which the user-provided loss function is scaled. Defaults to 1.0.

1.0
pushforward_fns list

A list of Monte Carlo pushforward functions (default: DEFAULT_MC_FUNCTIONS).

DEFAULT_NONLIN_FINALIZE_FNS
num_samples int

Number of weight samples for Monte Carlo predictions.

100
**kwargs Kwargs

Additional arguments passed to the pushforward functions.

{}

Returns:

Type Description
Callable

A probabilistic predictive function that computes predictions and uncertainty metrics using Monte Carlo sampling.

Source code in laplax/eval/pushforward.py
def set_nonlin_pushforward(
    model_fn: ModelFn,
    mean_params: Params,
    posterior_fn: Callable[[PriorArguments, Int], Posterior],
    prior_arguments: PriorArguments,
    *,
    key: KeyType,
    loss_scaling_factor: Float = 1.0,
    pushforward_fns: list = DEFAULT_NONLIN_FINALIZE_FNS,
    num_samples: int = 100,
    **kwargs: Kwargs,
) -> Callable:
    """Construct a Monte Carlo pushforward predictive function.

    This function creates a probabilistic predictive callable that computes
    ensemble-based Monte Carlo (MC) predictions and propagates uncertainty
    from weight space to output space using sampling.

    Args:
        model_fn: The model function to evaluate, which takes input and parameters.
        mean_params: The mean of the posterior distribution over model parameters.
        posterior_fn: A callable that generates the posterior state from prior
            arguments.
        prior_arguments: Arguments for defining the prior distribution.
        key: PRNG key for generating random samples.
        loss_scaling_factor: Factor by which the user-provided loss function is scaled.
            Defaults to 1.0.
        pushforward_fns: A list of Monte Carlo pushforward functions
            (default: `DEFAULT_MC_FUNCTIONS`).
        num_samples: Number of weight samples for Monte Carlo predictions.
        **kwargs: Additional arguments passed to the pushforward functions.

    Returns:
        A probabilistic predictive function that computes predictions
            and uncertainty metrics using Monte Carlo sampling.
    """
    # Create weight sample function
    posterior_state = posterior_fn(prior_arguments, loss_scaling_factor)

    # Posterior state to dist_state
    dist_state = get_dist_state(
        mean_params,
        model_fn,
        posterior_state,
        linearized=False,
        num_samples=num_samples,
        key=key,
        **kwargs,
    )

    # Set prob predictive
    prob_predictive = set_prob_predictive(
        model_fn=model_fn,
        mean_params=mean_params,
        dist_state=dist_state,
        pushforward_fns=pushforward_fns,
        **kwargs,
    )

    return prob_predictive

set_lin_pushforward

set_lin_pushforward(model_fn: ModelFn, mean_params: Params, posterior_fn: Callable[[PriorArguments, Int], Posterior], prior_arguments: PriorArguments, loss_scaling_factor: Float = 1.0, pushforward_fns: list = DEFAULT_LIN_FINALIZE_FNS, **kwargs: Kwargs) -> Callable

Construct a linearized pushforward predictive function.

This function generates a probabilistic predictive callable that computes predictions and propagates uncertainty using a linearized approximation of the model function.

Parameters:

Name Type Description Default
model_fn ModelFn

The model function to evaluate, which takes input and parameters.

required
mean_params Params

The mean of the posterior distribution over model parameters.

required
posterior_fn Callable[[PriorArguments, Int], Posterior]

A callable that generates the posterior state from prior arguments.

required
prior_arguments PriorArguments

Arguments for defining the prior distribution.

required
loss_scaling_factor Float

Factor by which the user-provided loss function is scaled. Defaults to 1.0.

1.0
pushforward_fns list

A list of linearized pushforward functions (default: DEFAULT_LIN_FINALIZE).

DEFAULT_LIN_FINALIZE_FNS
**kwargs Kwargs

Additional arguments passed to the pushforward functions, including:

  • n_samples: Number of samples for approximating uncertainty metrics.
  • key: PRNG key for generating random samples.
{}

Returns:

Type Description
Callable

A probabilistic predictive function that computes predictions and uncertainty metrics using a linearized approximation.

Source code in laplax/eval/pushforward.py
def set_lin_pushforward(
    model_fn: ModelFn,
    mean_params: Params,
    posterior_fn: Callable[[PriorArguments, Int], Posterior],
    prior_arguments: PriorArguments,
    loss_scaling_factor: Float = 1.0,
    pushforward_fns: list = DEFAULT_LIN_FINALIZE_FNS,
    **kwargs: Kwargs,
) -> Callable:
    """Construct a linearized pushforward predictive function.

    This function generates a probabilistic predictive callable that computes
    predictions and propagates uncertainty using a linearized approximation of
    the model function.

    Args:
        model_fn: The model function to evaluate, which takes input and parameters.
        mean_params: The mean of the posterior distribution over model parameters.
        posterior_fn: A callable that generates the posterior state from prior
            arguments.
        prior_arguments: Arguments for defining the prior distribution.
        loss_scaling_factor: Factor by which the user-provided loss function is scaled.
            Defaults to 1.0.
        pushforward_fns: A list of linearized pushforward functions
            (default: `DEFAULT_LIN_FINALIZE`).
        **kwargs: Additional arguments passed to the pushforward functions, including:

            - `n_samples`: Number of samples for approximating uncertainty metrics.
            - `key`: PRNG key for generating random samples.

    Returns:
        A probabilistic predictive function that computes predictions
            and uncertainty metrics using a linearized approximation.
    """
    # Create posterior state
    posterior_state = posterior_fn(prior_arguments, loss_scaling_factor)

    # Posterior state to dist_state
    dist_state = get_dist_state(
        mean_params,
        model_fn,
        posterior_state,
        linearized=True,
        **kwargs,
    )

    # Set prob predictive
    prob_predictive = set_prob_predictive(
        model_fn=model_fn,
        mean_params=mean_params,
        dist_state=dist_state,
        pushforward_fns=pushforward_fns,
        **kwargs,
    )

    return prob_predictive

set_posterior_gp_kernel

set_posterior_gp_kernel(model_fn: ModelFn, mean: Params, posterior_fn: Callable[[PriorArguments, Int], Posterior], prior_arguments: PriorArguments, loss_scaling_factor: Float = 1.0, **kwargs: Kwargs) -> tuple[Callable, DistState]

Construct a kernel matrix-vector product function for a posterior GP.

This function generates a callable for the kernel matrix-vector product (MVP) in a posterior GP framework. The kernel MVP is constructed using the posterior state and propagates uncertainty in weight space to output space via linearization. The resulting kernel MVP can optionally return a dense matrix representation.

Parameters:

Name Type Description Default
model_fn ModelFn

The model function to evaluate, which takes input and parameters.

required
mean Params

The mean of the posterior distribution over model parameters.

required
posterior_fn Callable[[PriorArguments, Int], Posterior]

A callable that generates the posterior state from prior arguments.

required
prior_arguments PriorArguments

Arguments for defining the prior distribution.

required
loss_scaling_factor Float

Factor by which the user-provided loss function is scaled. Defaults to 1.0.

1.0
**kwargs Kwargs

Additional arguments, including:

  • dense: Whether to return a dense kernel matrix instead of the MVP.
  • output_layout: The layout of the dense kernel matrix (required if dense is True).
{}

Returns:

Type Description
tuple[Callable, DistState]

A kernel MVP callable or a dense kernel matrix function, and the distribution state containing posterior information.

Raises:

Type Description
ValueError

If dense is True but output_layout is not specified.

Source code in laplax/eval/pushforward.py
def set_posterior_gp_kernel(
    model_fn: ModelFn,
    mean: Params,
    posterior_fn: Callable[[PriorArguments, Int], Posterior],
    prior_arguments: PriorArguments,
    loss_scaling_factor: Float = 1.0,
    **kwargs: Kwargs,
) -> tuple[Callable, DistState]:
    """Construct a kernel matrix-vector product function for a posterior GP.

    This function generates a callable for the kernel matrix-vector product (MVP)
    in a posterior GP framework. The kernel MVP is constructed using the posterior
    state and propagates uncertainty in weight space to output space via linearization.
    The resulting kernel MVP can optionally return a dense matrix representation.

    Args:
        model_fn: The model function to evaluate, which takes input and parameters.
        mean: The mean of the posterior distribution over model parameters.
        posterior_fn: A callable that generates the posterior state from prior
            arguments.
        prior_arguments: Arguments for defining the prior distribution.
        loss_scaling_factor: Factor by which the user-provided loss function is scaled.
            Defaults to 1.0.
        **kwargs: Additional arguments, including:

            - `dense`: Whether to return a dense kernel matrix instead of the MVP.
            - `output_layout`: The layout of the dense kernel matrix (required if
                `dense` is True).

    Returns:
        A kernel MVP callable or a dense kernel matrix function, and the
            distribution state containing posterior information.

    Raises:
        ValueError: If `dense` is True but `output_layout` is not specified.
    """
    # Create posterior state
    posterior_state = posterior_fn(prior_arguments, loss_scaling_factor)

    # Posterior state to dist_state
    dist_state = get_dist_state(
        mean,
        model_fn,
        posterior_state,
        linearized=True,
        num_samples=0,
        **kwargs,
    )

    # Kernel mv
    def kernel_mv(
        vec: PredArray, x1: InputArray, x2: InputArray, dist_state: dict[str, Any]
    ) -> PredArray:
        cov_mv = dist_state["posterior_state"].cov_mv(
            dist_state["posterior_state"].state
        )
        return dist_state["jvp"](x1, cov_mv(dist_state["vjp"](x2, vec)[0]))

    if kwargs.get("dense"):
        output_layout = kwargs.get("output_layout")
        if output_layout:
            return lambda x1, x2: util.mv.to_dense(
                lambda v: kernel_mv(v, x1, x2, dist_state), layout=output_layout
            ), dist_state
        msg = "function should return a dense matrix, but no output layout is specified"
        raise ValueError(msg)

    return kernel_mv, dist_state