Skip to content

laplax.eval.predictives

laplace_bridge

laplace_bridge(mean: Array, var: Array, *, use_correction: bool) -> Array

Laplace bridge approximation.

Returns:

| Type  | Description     |
|-------|-----------------|
| Array | The predictive. |

Source code in laplax/eval/predictives.py
def laplace_bridge(
    mean: jax.Array,
    var: jax.Array,
    *,
    use_correction: bool,
) -> jax.Array:
    """Laplace bridge approximation.

    Converts per-class logit means and (diagonal) variances into Dirichlet
    concentration parameters and returns the mean of that Dirichlet.

    Args:
        mean: Logit means with the class dimension on axis 0 (``[C, ...]``).
        var: Logit variances, same shape as ``mean``.
        use_correction: If True, first rescale ``mean`` by ``1 / sqrt(c)`` and
            ``var`` by ``1 / c``, where ``c = sum(var) / sqrt(C / 2)`` is
            computed over the class axis.

    Returns:
        The predictive, shape ``[C, ...]``.
    """
    # All reductions below run over axis 0, so the class count must be read
    # from axis 0 as well (axis 1 is a batch axis or does not exist).
    num_classes = mean.shape[0]

    if use_correction:
        c = jnp.sum(var, axis=0, keepdims=True) / math.sqrt(
            num_classes / 2
        )  # [1, ...]
        mean = mean / jnp.sqrt(c)  # [C, ...]
        var = var / c  # [C, ...]

    # Laplace bridge:
    # alpha_k = (1 - 2/C + exp(mu_k) * sum_l exp(-mu_l) / C^2) / var_k
    sum_exp_neg_mean = jnp.sum(jnp.exp(-mean), axis=0, keepdims=True)  # [1, ...]
    dirichlet_params = (
        1
        - 2 / num_classes
        + jnp.exp(mean) * sum_exp_neg_mean / (num_classes**2)
    ) / var  # [C, ...]

    return dirichlet_predictive(dirichlet_params)

dirichlet_predictive

dirichlet_predictive(dirichlet_params: Array) -> Array

Predictive mean of Dirichlet distributions.

Returns:

| Type  | Description     |
|-------|-----------------|
| Array | The predictive. |

Source code in laplax/eval/predictives.py
def dirichlet_predictive(dirichlet_params: jax.Array) -> jax.Array:
    """Predictive mean of Dirichlet distributions.

    The mean of ``Dirichlet(alpha)`` is ``alpha_k / sum_j alpha_j``. It is
    computed independently for every index along the trailing axes.

    Args:
        dirichlet_params: Concentration parameters with the class dimension
            on axis 0 (``[C, ...]``).

    Returns:
        The predictive, same shape as ``dirichlet_params``. Every slice along
        axis 0 sums to one.
    """
    # Normalize over the class axis only. Summing over every axis would mix
    # the parameters of different distributions when batch axes are present.
    return dirichlet_params / jnp.sum(
        dirichlet_params, axis=0, keepdims=True
    )  # [C, ...]