laplax.curv.full

Full curvature approximation.

create_full_curvature

create_full_curvature(mv: CurvatureMV, layout: Layout, **kwargs: Kwargs) -> Num[Array, 'P P']

Generate a full curvature approximation.

The curvature is materialized as a dense 2D array whose rows and columns follow the flattened parameter layout.

Parameters:

Name Type Description Default
mv CurvatureMV

Matrix-vector product function representing the curvature.

required
layout Layout

Structure defining the parameter layout that is assumed by the matrix-vector product function. If None or an integer, no flattening/unflattening is used.

required
**kwargs Kwargs

Additional arguments (unused).

{}

Returns:

Type Description
Num[Array, 'P P']

A dense matrix representing the full curvature approximation.

Source code in laplax/curv/full.py
def create_full_curvature(
    mv: CurvatureMV,
    layout: Layout,
    **kwargs: Kwargs,
) -> Num[Array, "P P"]:
    """Generate a full curvature approximation.

    The curvature is materialized as a dense 2D array whose rows and columns follow
    the flattened parameter layout.

    Args:
        mv: Matrix-vector product function representing the curvature.
        layout: Structure defining the parameter layout that is assumed by the
            matrix-vector product function. If `None` or an integer, no
            flattening/unflattening is used.
        **kwargs: Additional arguments (unused).

    Returns:
        A dense matrix representing the full curvature approximation.
    """
    del kwargs
    if isinstance(layout, int):
        msg = (
            "Full curvature assumes parameter dictionary as input, "
            f"got type {type(layout)} instead. Proceeding without wrapper."
        )
        logger.warning(msg)
        mv_wrapped = mv
    else:
        flatten, unflatten = create_pytree_flattener(layout)
        mv_wrapped = wrap_function(mv, input_fn=unflatten, output_fn=flatten)
    curv_estimate = to_dense(mv_wrapped, layout=get_size(layout))
    return curv_estimate
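
The dense matrix can be recovered from a matrix-vector product by applying it to each standard basis vector. The following is a minimal sketch of that idea under stated assumptions, not the library's implementation: `densify`, `mv_tree`, `mv_flat`, and the toy parameters are hypothetical, and `jax.flatten_util.ravel_pytree` stands in for the `create_pytree_flattener` and `wrap_function` helpers, whose internals are not shown on this page.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter pytree that plays the role of `layout`.
params = {"w": jnp.ones((2,)), "b": jnp.zeros((1,))}
flat_params, unravel = ravel_pytree(params)


def mv_tree(tree):
    """Toy curvature matvec on pytrees: scales each leaf by a fixed factor."""
    return {"w": 2.0 * tree["w"], "b": 5.0 * tree["b"]}


def mv_flat(vec):
    """Wrap the pytree matvec so it maps flat vectors to flat vectors."""
    return ravel_pytree(mv_tree(unravel(vec)))[0]


def densify(mv_fn, size):
    """Build the dense (size, size) matrix by applying mv_fn to basis vectors."""
    basis = jnp.eye(size)
    # Row i of the vmapped result is mv_fn(e_i), which is column i of the matrix.
    return jax.vmap(mv_fn)(basis).T


curv_dense = densify(mv_flat, flat_params.size)
print(curv_dense.shape)  # (3, 3)
```

Materializing the matrix costs P matrix-vector products and P x P memory, so this approach is only practical for small parameter counts.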

full_curvature_to_precision

full_curvature_to_precision(curv_estimate: Num[Array, 'P P'], prior_arguments: PriorArguments, loss_scaling_factor: Float = 1.0) -> Num[Array, 'P P']

Add prior precision to the curvature estimate.

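The prior precision (of an isotropic Gaussian prior) is read from the prior_arguments dictionary and added to the curvature estimate. The curvature is scaled by the \(\sigma^2\) parameter, read from the optional 'sigma_squared' entry (1.0 if absent), and the sum is divided by loss_scaling_factor.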

Parameters:

Name Type Description Default
curv_estimate Num[Array, 'P P']

Full curvature estimate matrix.

required
prior_arguments PriorArguments

Dictionary containing prior precision as 'prior_prec' and, optionally, 'sigma_squared' (defaults to 1.0).

required
loss_scaling_factor Float

Factor by which the user-provided loss function is scaled. Defaults to 1.0.

1.0

Returns:

Type Description
Num[Array, 'P P']

Updated curvature matrix with added prior precision.

Source code in laplax/curv/full.py
def full_curvature_to_precision(
    curv_estimate: Num[Array, "P P"],
    prior_arguments: PriorArguments,
    loss_scaling_factor: Float = 1.0,
) -> Num[Array, "P P"]:
    r"""Add prior precision to the curvature estimate.

    The prior precision (of an isotropic Gaussian prior) is read from the
    prior_arguments dictionary and added to the curvature estimate. The curvature is
    scaled by the $\sigma^2$ parameter.

    Args:
        curv_estimate: Full curvature estimate matrix.
        prior_arguments: Dictionary containing prior precision as 'prior_prec'
            and, optionally, 'sigma_squared' (defaults to 1.0).
        loss_scaling_factor: Factor by which the user-provided loss function is
            scaled. Defaults to 1.0.

    Returns:
        Updated curvature matrix with added prior precision.
    """
    prior_prec = prior_arguments["prior_prec"]
    sigma_squared = prior_arguments.get("sigma_squared", 1.0)

    return (
        sigma_squared * curv_estimate + prior_prec * jnp.eye(curv_estimate.shape[-1])
    ) / loss_scaling_factor
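
As a worked example of the formula in the source above, `(sigma_squared * curv_estimate + prior_prec * I) / loss_scaling_factor`, the sketch below re-implements the same arithmetic locally so that it runs without laplax. The name `to_precision` and the numeric values are illustrative assumptions.

```python
import jax.numpy as jnp

# Hypothetical 2x2 curvature estimate.
curv = jnp.array([[2.0, 0.5],
                  [0.5, 1.0]])


def to_precision(curv_estimate, prior_arguments, loss_scaling_factor=1.0):
    """Local copy of the arithmetic shown in the source listing."""
    prior_prec = prior_arguments["prior_prec"]
    # 'sigma_squared' is optional and falls back to 1.0.
    sigma_squared = prior_arguments.get("sigma_squared", 1.0)
    eye = jnp.eye(curv_estimate.shape[-1])
    return (sigma_squared * curv_estimate + prior_prec * eye) / loss_scaling_factor


# Only the prior precision is given: the diagonal grows by 0.1.
prec = to_precision(curv, {"prior_prec": 0.1})

# With sigma_squared = 2 and loss_scaling_factor = 2, the curvature term is
# unchanged overall and the prior term is halved.
prec_scaled = to_precision(
    curv, {"prior_prec": 0.1, "sigma_squared": 2.0}, loss_scaling_factor=2.0
)
```

Only the diagonal is affected by the prior term, because an isotropic Gaussian prior contributes a multiple of the identity.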

full_prec_to_scale

full_prec_to_scale(prec: Num[Array, 'P P']) -> Num[Array, 'P P']

Convert precision matrix to scale matrix using Cholesky decomposition.

This converts a precision matrix to a scale matrix using a Cholesky decomposition. The scale matrix is the lower triangular matrix L such that L @ L.T is the covariance matrix.

Parameters:

Name Type Description Default
prec Num[Array, 'P P']

Precision matrix to convert.

required

Returns:

Type Description
Num[Array, 'P P']

Scale matrix L where L @ L.T is the covariance matrix.

Source code in laplax/curv/full.py
def full_prec_to_scale(
    prec: Num[Array, "P P"],
) -> Num[Array, "P P"]:
    """Convert precision matrix to scale matrix using Cholesky decomposition.

    This converts a precision matrix to a scale matrix using a Cholesky decomposition.
    The scale matrix is the lower triangular matrix L such that L @ L.T is the
    covariance matrix.

    Args:
        prec: Precision matrix to convert.

    Returns:
        Scale matrix L where L @ L.T is the covariance matrix.
    """
    Lf = jnp.linalg.cholesky(jnp.flip(prec, axis=(-2, -1)))
    L_inv = jnp.transpose(jnp.flip(Lf, axis=(-2, -1)), axes=(-2, -1))
    Id = jnp.eye(prec.shape[-1], dtype=prec.dtype)
    L = jax.scipy.linalg.solve_triangular(L_inv, Id, trans="T")
    return L
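
The defining property of the scale matrix, that it is lower triangular with `L @ L.T` equal to the inverse of the precision, can be checked numerically. The sketch below spells out the flip-based Cholesky steps with explicit intermediate names; `prec_to_scale` and the example matrix are assumptions made for illustration, and the code is meant as a check of the property rather than a replacement for the library function.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Hypothetical symmetric positive definite precision matrix.
prec = jnp.array([[4.0, 1.0, 0.0],
                  [1.0, 3.0, 0.5],
                  [0.0, 0.5, 2.0]])


def prec_to_scale(prec):
    """Return lower-triangular L with L @ L.T == inv(prec)."""
    # Reverse rows and columns, then factor: Lf @ Lf.T == flipped precision.
    Lf = jnp.linalg.cholesky(jnp.flip(prec, axis=(-2, -1)))
    # Flipping back turns the lower factor into an upper one: U @ U.T == prec.
    U = jnp.flip(Lf, axis=(-2, -1))
    eye = jnp.eye(prec.shape[-1], dtype=prec.dtype)
    # Solve U.T @ L = I, so L = inv(U).T, which is lower triangular.
    return solve_triangular(U, eye, trans="T", lower=False)


L = prec_to_scale(prec)
cov = L @ L.T
```

Factoring the flipped matrix is what makes the resulting scale lower triangular; a plain Cholesky factor of `prec` followed by an inverse transpose would give an upper-triangular factor of the covariance instead.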

full_prec_to_posterior_state

full_prec_to_posterior_state(prec: Num[Array, 'P P']) -> dict[str, Num[Array, 'P P']]

Convert precision matrix to scale matrix.

The provided precision matrix is converted with full_prec_to_scale to a scale matrix, the lower triangular matrix L such that L @ L.T is the covariance matrix. The result is returned in a dictionary under the key 'scale'.

Parameters:

Name Type Description Default
prec Num[Array, 'P P']

Precision matrix to convert.

required

Returns:

Type Description
dict[str, Num[Array, 'P P']]

Dictionary with key 'scale' holding the scale matrix L, where L @ L.T is the covariance matrix.

Source code in laplax/curv/full.py
def full_prec_to_posterior_state(
    prec: Num[Array, "P P"],
) -> dict[str, Num[Array, "P P"]]:
    """Convert precision matrix to scale matrix.

    The provided precision matrix is converted to a scale matrix, which is the lower
    triangular matrix L such that L @ L.T is the covariance matrix, using
    :func:`full_prec_to_scale`.

    Args:
        prec: Precision matrix to convert.

    Returns:
        Dictionary with key "scale" holding the scale matrix L, where L @ L.T is
        the covariance matrix.
    """
    scale = full_prec_to_scale(prec)

    return {"scale": scale}

full_posterior_state_to_scale

full_posterior_state_to_scale(state: dict[str, Num[Array, 'P P']]) -> Callable[[FlatParams], FlatParams]

Create a scale matrix-vector product function.

The scale matrix is read from the state dictionary and is used to create a corresponding matrix-vector product function representing the action of the scale matrix on a vector.

Parameters:

Name Type Description Default
state dict[str, Num[Array, 'P P']]

Dictionary containing the scale matrix.

required

Returns:

Type Description
Callable[[FlatParams], FlatParams]

A function that computes the scale matrix-vector product.

Source code in laplax/curv/full.py
def full_posterior_state_to_scale(
    state: dict[str, Num[Array, "P P"]],
) -> Callable[[FlatParams], FlatParams]:
    """Create a scale matrix-vector product function.

    The scale matrix is read from the state dictionary and is used to create a
    corresponding matrix-vector product function representing the action of the scale
    matrix on a vector.

    Args:
        state: Dictionary containing the scale matrix.

    Returns:
        A function that computes the scale matrix-vector product.
    """

    def scale_mv(vec: FlatParams) -> FlatParams:
        return state["scale"] @ vec

    return scale_mv
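
One common use of a scale matrix-vector product is drawing Gaussian samples: since `L @ L.T` is the covariance, applying the scale to standard normal noise gives a draw with that covariance. The sketch below illustrates this with a local copy of the closure shown in the source listing; the state values, the zero `mean`, and the sampling step itself are assumptions for illustration and are not described on this page.

```python
import jax
import jax.numpy as jnp

# Hypothetical posterior state with a 2x2 lower-triangular scale matrix.
state = {"scale": jnp.array([[1.0, 0.0],
                             [0.5, 2.0]])}


def posterior_state_to_scale(state):
    """Local copy of the closure pattern from the source listing."""

    def scale_mv(vec):
        return state["scale"] @ vec

    return scale_mv


scale_mv = posterior_state_to_scale(state)

# Applying the matvec to a basis vector returns the matching column of L.
first_column = scale_mv(jnp.array([1.0, 0.0]))

# Illustrative sampling step: mean + L @ eps with eps ~ N(0, I).
mean = jnp.zeros(2)  # assumed posterior mean
eps = jax.random.normal(jax.random.PRNGKey(0), (2,))
sample = mean + scale_mv(eps)
```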

full_posterior_state_to_cov

full_posterior_state_to_cov(state: dict[str, Num[Array, 'P P']]) -> Callable[[FlatParams], FlatParams]

Create a covariance matrix-vector product function.

The scale matrix is read from the state dictionary and is used to create a corresponding matrix-vector product function representing the action of the covariance matrix on a vector. The covariance matrix is computed once as the product of the scale matrix and its transpose, and is reused for every call.

Parameters:

Name Type Description Default
state dict[str, Num[Array, 'P P']]

Dictionary containing the scale matrix.

required

Returns:

Type Description
Callable[[FlatParams], FlatParams]

A function that computes the covariance matrix-vector product.

Source code in laplax/curv/full.py
def full_posterior_state_to_cov(
    state: dict[str, Num[Array, "P P"]],
) -> Callable[[FlatParams], FlatParams]:
    """Create a covariance matrix-vector product function.

    The scale matrix is read from the state dictionary and is used to create a
    corresponding matrix-vector product function representing the action of the
    covariance matrix on a vector. The covariance matrix is computed as the product
    of the scale matrix and its transpose.

    Args:
        state: Dictionary containing the scale matrix.

    Returns:
        A function that computes the covariance matrix-vector product.
    """
    cov = state["scale"] @ state["scale"].T

    def cov_mv(vec: FlatParams) -> FlatParams:
        return cov @ vec

    return cov_mv
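
The covariance matrix-vector product can be checked against the scale matrix directly, since `cov @ v` equals `L @ (L.T @ v)`. The sketch below uses a local copy of the closure from the source listing; the state values and the name `posterior_state_to_cov` are illustrative assumptions.

```python
import jax.numpy as jnp

# Hypothetical posterior state with a 2x2 lower-triangular scale matrix.
state = {"scale": jnp.array([[1.0, 0.0],
                             [0.5, 2.0]])}


def posterior_state_to_cov(state):
    """Local copy of the closure pattern from the source listing."""
    # The dense covariance is formed once, outside the matvec.
    cov = state["scale"] @ state["scale"].T

    def cov_mv(vec):
        return cov @ vec

    return cov_mv


cov_mv = posterior_state_to_cov(state)
vec = jnp.array([1.0, 1.0])

result = cov_mv(vec)
# Same product without forming the covariance: two matvecs with the scale.
result_two_step = state["scale"] @ (state["scale"].T @ vec)
```

Precomputing the covariance trades one P x P matrix product up front for a single matrix-vector product per call; the two-step form avoids storing a second dense matrix at the cost of two products per call.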