
laplax.curv.loss

Loss Gradients and Hessians.

fetch_loss_gradient_fn

fetch_loss_gradient_fn(loss_fn: LossFn | str | Callable[[PredArray, TargetArray], Num[Array, ...]] | None, loss_gradient_fn: Callable | None, *, handle_batches: bool = False, **kwargs: Kwargs) -> Callable[[PredArray, TargetArray], Num[Array, ...]]

Fetch a loss gradient function from the given arguments.

If 'loss_gradient_fn' is passed, it is returned unchanged (handle_batches is not applied to it). If a known 'LossFn' is passed, the analytic gradient is returned. If a custom 'Callable' is passed, the gradient is computed with automatic differentiation.
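To illustrate the analytic and automatic-differentiation paths, the sketch below compares hand-derived gradients with jax.grad for cross-entropy and binary cross-entropy. It assumes the predictions are logits and the losses are summed over outputs. The functions defined here are hypothetical stand-ins; laplax's private helpers (_cross_entropy_gradient, _binary_cross_entropy_gradient) may use different conventions.

```python
import jax
import jax.numpy as jnp

# Hypothetical losses; assumption: predictions are logits, targets are one-hot
# (cross-entropy) or in {0, 1} (binary cross-entropy).
def cross_entropy(logits, target):
    return -jnp.sum(target * jax.nn.log_softmax(logits))

def binary_cross_entropy(logit, target):
    # Numerically stable form of -[y log s(z) + (1 - y) log(1 - s(z))].
    return -(target * jax.nn.log_sigmoid(logit)
             + (1.0 - target) * jax.nn.log_sigmoid(-logit))

# Hand-derived gradients with respect to the logits.
def cross_entropy_grad(logits, target):
    return jax.nn.softmax(logits) - target

def binary_cross_entropy_grad(logit, target):
    return jax.nn.sigmoid(logit) - target

logits = jnp.array([0.5, -1.0, 2.0])
onehot = jnp.array([0.0, 0.0, 1.0])
ce_autodiff = jax.grad(cross_entropy, argnums=0)(logits, onehot)
ce_match = bool(jnp.allclose(ce_autodiff, cross_entropy_grad(logits, onehot), atol=1e-6))

logit, label = jnp.array(0.3), jnp.array(1.0)
bce_autodiff = jax.grad(binary_cross_entropy, argnums=0)(logit, label)
bce_match = bool(jnp.allclose(bce_autodiff, binary_cross_entropy_grad(logit, label), atol=1e-6))

print(ce_match, bce_match)  # expected: True True
```

For arbitrary callables, jax.grad(loss_fn, argnums=0), as used in the source code of this function, is the general fallback; the closed forms only apply to the specific losses they were derived for.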

Parameters:

Name Type Description Default
loss_fn LossFn | str | Callable[[PredArray, TargetArray], Num[Array, ...]] | None

Loss function to compute the gradient for. Supported options are:

  • LossFn.BINARY_CROSS_ENTROPY for binary cross-entropy loss.
  • LossFn.CROSS_ENTROPY for cross-entropy loss.
  • LossFn.MSE for mean squared error loss.
  • A custom callable loss function that takes predictions and targets.
required
loss_gradient_fn Callable | None

Custom precomputed loss gradient to use.

required
handle_batches bool

Whether the returned gradient function should operate on batches of predictions and targets.

False
**kwargs Kwargs

Unused keyword arguments.

{}

Returns:

Type Description
Callable[[PredArray, TargetArray], Num[Array, ...]]

A function that computes the gradient of the loss with respect to the predictions, given predictions and targets. If handle_batches=True, it takes batches of predictions and targets and returns a batch of gradients.
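For a custom callable, handle_batches=True wraps the per-sample gradient in vmap, so the returned function maps over a leading batch axis of both predictions and targets. A minimal sketch of that behaviour, using a hypothetical summed squared-error loss (not one of laplax's built-in losses):

```python
import jax
import jax.numpy as jnp

def squared_error(pred, target):
    # Hypothetical per-sample loss: sum of squared residuals (scalar output).
    return jnp.sum((pred - target) ** 2)

per_sample_grad = jax.grad(squared_error, argnums=0)  # gradient w.r.t. pred
batched_grad = jax.vmap(per_sample_grad)               # maps axis 0 of pred and target

preds = jnp.arange(6.0).reshape(3, 2)  # batch of 3 predictions with 2 outputs each
targets = jnp.ones((3, 2))
grads = batched_grad(preds, targets)
print(grads.shape)  # (3, 2)

# Each row equals 2 * (pred - target) for the corresponding sample.
rows_match = bool(jnp.allclose(grads, 2.0 * (preds - targets)))
```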

Raises:

Type Description
ValueError

If both loss_fn and loss_gradient_fn are provided.

ValueError

If neither loss_fn nor loss_gradient_fn is provided.

ValueError

When an unsupported loss function is provided.

Source code in laplax/curv/loss.py
def fetch_loss_gradient_fn(
    loss_fn: LossFn
    | str
    | Callable[[PredArray, TargetArray], Num[Array, "..."]]
    | None,
    loss_gradient_fn: Callable | None,
    *,
    handle_batches: bool = False,
    **kwargs: Kwargs,
) -> Callable[[PredArray, TargetArray], Num[Array, "..."]]:
    r"""Fetch a loss gradient function from the given arguments.

    If `loss_gradient_fn` is passed, it is returned unchanged.
    If a known `LossFn` is passed, the analytic gradient is returned.
    If a custom `Callable` is passed, automatic differentiation is used.

    Args:
        loss_fn: Loss function to compute the gradient for.
            Supported options are:

            - `LossFn.BINARY_CROSS_ENTROPY` for binary cross-entropy loss.
            - `LossFn.CROSS_ENTROPY` for cross-entropy loss.
            - `LossFn.MSE` for mean squared error loss.
            - A custom callable loss function that takes predictions and targets.

        loss_gradient_fn: Custom precomputed loss gradient to use.
        handle_batches: Whether the returned gradient function should operate on
            batches of predictions and targets.
        **kwargs: Unused keyword arguments.

    Returns:
        A function that computes the gradient of the loss with respect to the
        predictions, given predictions and targets. If `handle_batches=True`, it
        takes batches of predictions and targets and returns a batch of gradients.

    Raises:
        ValueError: If both `loss_fn` and `loss_gradient_fn` are provided.
        ValueError: If neither `loss_fn` nor `loss_gradient_fn` is provided.
        ValueError: When an unsupported loss function is provided.
    """
    del kwargs

    if loss_gradient_fn is not None:
        if loss_fn is not None:
            msg = "Only one of loss_fn or loss_gradient_fn may be provided."
            raise ValueError(msg)
        return loss_gradient_fn

    if loss_fn is None:
        msg = "Either loss_fn or loss_gradient_fn must be provided."
        raise ValueError(msg)

    if isinstance(loss_fn, Callable):
        grad = jax.grad(loss_fn, argnums=0)
        if handle_batches:
            grad = vmap(grad)
        return grad

    if loss_fn == LossFn.BINARY_CROSS_ENTROPY:
        grad = _binary_cross_entropy_gradient

    elif loss_fn == LossFn.CROSS_ENTROPY:
        grad = _cross_entropy_gradient

    elif loss_fn == LossFn.MSE:
        grad = _mse_gradient

    # LossFn.NONE is not supported: the identity is not scalar-valued,
    # so it has no gradient.
    else:
        msg = f"Unsupported loss function '{loss_fn}' provided."
        raise ValueError(msg)

    loss_grad_fn = partial(grad, handle_batches=handle_batches)
    return loss_grad_fn

create_loss_hessian_mv

create_loss_hessian_mv(loss_fn: LossFn | str | Callable[[PredArray, TargetArray], Num[Array, ...]] | None, **kwargs: Kwargs) -> Callable

Create a function to compute the Hessian-vector product for a specified loss function.

For predefined loss functions like cross-entropy and mean squared error, the function computes their corresponding Hessian-vector products using efficient formulations. For custom loss functions, the Hessian-vector product is computed via automatic differentiation.
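As a sketch of the two strategies, the code below computes a Hessian-vector product by automatic differentiation (forward-over-reverse: jax.jvp applied to jax.grad) and checks it against the closed form for softmax cross-entropy on logits, H v = p * v - p (p . v) with p = softmax(logits). The logits convention, the hvp helper, and cross_entropy_hessian_mv here are assumptions for illustration; laplax's own hvp and _cross_entropy_hessian_mv may be implemented differently.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Forward-over-reverse HVP: directional derivative of grad(f) at x along v.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def cross_entropy(logits, target):
    # Assumption: logits input, one-hot target, summed over classes.
    return -jnp.sum(target * jax.nn.log_softmax(logits))

def cross_entropy_hessian_mv(v, logits):
    # Hypothetical closed form: H = diag(p) - p p^T, so H v = p * v - p * (p . v).
    p = jax.nn.softmax(logits)
    return p * v - p * jnp.dot(p, v)

logits = jnp.array([0.2, -0.7, 1.5])
target = jnp.array([1.0, 0.0, 0.0])
v = jnp.array([1.0, 2.0, -1.0])

# Close over the target, as custom_hessian_mv does with loss_fn_local.
autodiff_hv = hvp(lambda p: cross_entropy(p, target), logits, v)
closed_form_hv = cross_entropy_hessian_mv(v, logits)
hv_match = bool(jnp.allclose(autodiff_hv, closed_form_hv, atol=1e-6))
print(hv_match)  # expected: True
```

The closed form needs only a softmax and a dot product and never materializes the full Hessian, whereas the autodiff path traces a gradient and a forward-mode derivative through the loss.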

Parameters:

Name Type Description Default
loss_fn LossFn | str | Callable[[PredArray, TargetArray], Num[Array, ...]] | None

Loss function to compute the Hessian-vector product for. Supported options are:

  • LossFn.BINARY_CROSS_ENTROPY for binary cross-entropy loss.
  • LossFn.CROSS_ENTROPY for cross-entropy loss.
  • LossFn.MSE for mean squared error loss.
  • LossFn.NONE for no loss; the returned function passes the input vector through unchanged.
  • A custom callable loss function that takes predictions and targets.
required
**kwargs Kwargs

Unused keyword arguments.

{}

Returns:

Type Description
Callable

A function that computes the Hessian-vector product for the given loss function.

Raises:

Type Description
ValueError

When loss_fn is None.

ValueError

When loss_fn is neither a supported LossFn nor a Callable.

Source code in laplax/curv/loss.py
def create_loss_hessian_mv(
    loss_fn: LossFn
    | str
    | Callable[[PredArray, TargetArray], Num[Array, "..."]]
    | None,
    **kwargs: Kwargs,
) -> Callable:
    r"""Create a function to compute the Hessian-vector product for a specified loss function.

    For predefined loss functions like cross-entropy and mean squared error, the
    function computes their corresponding Hessian-vector products using efficient
    formulations. For custom loss functions, the Hessian-vector product is computed via
    automatic differentiation.

    Args:
        loss_fn: Loss function to compute the Hessian-vector product for. Supported
            options are:

            - `LossFn.BINARY_CROSS_ENTROPY` for binary cross-entropy loss.
            - `LossFn.CROSS_ENTROPY` for cross-entropy loss.
            - `LossFn.MSE` for mean squared error loss.
            - `LossFn.NONE` for no loss; the returned function passes the input
              vector through unchanged.
            - A custom callable loss function that takes predictions and targets.

        **kwargs: Unused keyword arguments.

    Returns:
        A function that computes the Hessian-vector product for the given loss function.

    Raises:
        ValueError: When `loss_fn` is `None`.
        ValueError: When `loss_fn` is neither a supported `LossFn` nor a
            `Callable`.
    """
    del kwargs

    if loss_fn is None:
        msg = "loss_fn cannot be None"
        raise ValueError(msg)

    if loss_fn == LossFn.BINARY_CROSS_ENTROPY:
        return _binary_cross_entropy_hessian_mv

    if loss_fn == LossFn.CROSS_ENTROPY:
        return _cross_entropy_hessian_mv

    if loss_fn == LossFn.MSE:
        return _mse_hessian_mv

    if loss_fn == LossFn.NONE:

        def _identity(
            jv: PredArray,
            pred: PredArray,
            target: TargetArray,
            **kwargs,
        ) -> Num[Array, "..."]:
            del pred, target, kwargs
            return jv

        return _identity

    if isinstance(loss_fn, Callable):

        def custom_hessian_mv(
            jv: PredArray,
            pred: PredArray,
            target: TargetArray,
            **kwargs,
        ) -> Num[Array, "..."]:
            del kwargs

            def loss_fn_local(p):
                return loss_fn(p, target)

            return hvp(loss_fn_local, pred, jv)

        return custom_hessian_mv

    msg = "Unsupported loss function provided"
    raise ValueError(msg)

fetch_loss_hessian_mv

fetch_loss_hessian_mv(loss_fn: LossFn | str | Callable[[PredArray, TargetArray], Num[Array, ...]] | None, loss_hessian_mv: Callable | None, *, vmap_over_data: bool = False, **kwargs: Kwargs) -> Callable

Fetch the loss Hessian-vector product from loss_fn or loss_hessian_mv.

For predefined loss functions like cross-entropy and mean squared error, the function computes their corresponding Hessian-vector products using efficient formulations. For custom loss functions, the Hessian-vector product is computed via automatic differentiation.

Parameters:

Name Type Description Default
loss_fn LossFn | str | Callable[[PredArray, TargetArray], Num[Array, ...]] | None

Loss function to compute the Hessian-vector product for. Supported options are:

  • LossFn.BINARY_CROSS_ENTROPY for binary cross-entropy loss.
  • LossFn.CROSS_ENTROPY for cross-entropy loss.
  • LossFn.MSE for mean squared error loss.
  • LossFn.NONE for no loss; the returned function passes the input vector through unchanged.
  • A custom callable loss function that takes predictions and targets.
required
loss_hessian_mv Callable | None

Precomputed loss Hessian-vector product function to use.

required
vmap_over_data bool

Whether to vectorize the returned function over the data with jax.vmap.

False
**kwargs Kwargs

Keyword arguments forwarded to create_loss_hessian_mv, which ignores them.

{}
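With vmap_over_data=True, the returned function is wrapped in jax.vmap, which by default maps over axis 0 of every positional argument, here jv, pred, and target. The sketch below reproduces that wrapping for a hypothetical per-sample Hessian-vector product of the summed squared error, whose Hessian is 2I, so each mapped output is simply 2 * jv. The function name and loss are assumptions, not laplax's _mse_hessian_mv.

```python
import jax
import jax.numpy as jnp

def squared_error_hessian_mv(jv, pred, target):
    # Hypothetical per-sample HVP for sum((pred - target)**2): the Hessian is
    # 2 * I, so the product does not depend on pred or target.
    del pred, target
    return 2.0 * jv

batched_hvp = jax.vmap(squared_error_hessian_mv)  # maps axis 0 of jv, pred, target

jv = jnp.ones((4, 3))  # one vector per data point
preds = jnp.zeros((4, 3))
targets = jnp.zeros((4, 3))
out = batched_hvp(jv, preds, targets)
print(out.shape)  # (4, 3)
```

All three mapped arguments must share the same leading size (here 4); jax.vmap raises an error otherwise.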

Returns:

Type Description
Callable

A function that computes the Hessian-vector product for the given loss function.

Raises:

Type Description
ValueError

If both loss_fn and loss_hessian_mv are provided.

ValueError

If neither loss_fn nor loss_hessian_mv is provided.

ValueError

When an unsupported loss function is provided.

Source code in laplax/curv/loss.py
def fetch_loss_hessian_mv(
    loss_fn: LossFn
    | str
    | Callable[[PredArray, TargetArray], Num[Array, "..."]]
    | None,
    loss_hessian_mv: Callable | None,
    *,
    vmap_over_data: bool = False,
    **kwargs: Kwargs,
) -> Callable:
    r"""Fetch the loss Hessian-vector product from `loss_fn` or `loss_hessian_mv`.

    For predefined loss functions like cross-entropy and mean squared error, the
    function computes their corresponding Hessian-vector products using efficient
    formulations. For custom loss functions, the Hessian-vector product is computed via
    automatic differentiation.

    Args:
        loss_fn: Loss function to compute the Hessian-vector product for. Supported
            options are:

            - `LossFn.BINARY_CROSS_ENTROPY` for binary cross-entropy loss.
            - `LossFn.CROSS_ENTROPY` for cross-entropy loss.
            - `LossFn.MSE` for mean squared error loss.
            - `LossFn.NONE` for no loss; the returned function passes the input
              vector through unchanged.
            - A custom callable loss function that takes predictions and targets.

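        loss_hessian_mv: Precomputed loss Hessian-vector product function to use.
        vmap_over_data: Whether to vectorize the returned function over the data
            with `jax.vmap`. Defaults to `False`.
        **kwargs: Forwarded to `create_loss_hessian_mv`, which ignores them.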

    Returns:
        A function that computes the Hessian-vector product for the given loss function.

    Raises:
        ValueError: If both `loss_fn` and `loss_hessian_mv` are provided.
        ValueError: If neither `loss_fn` nor `loss_hessian_mv` is provided.
        ValueError: When an unsupported loss function is provided.
    """
    # Require that at least one of loss_fn or loss_hessian_mv is provided:
    if loss_fn is None and loss_hessian_mv is None:
        msg = "Either loss_fn or loss_hessian_mv must be provided."
        raise ValueError(msg)

    # Require that loss_fn and loss_hessian_mv are not both provided:
    if loss_fn is not None and loss_hessian_mv is not None:
        msg = "Only one of loss_fn or loss_hessian_mv may be provided."
        raise ValueError(msg)

    loss_hessian_mv = loss_hessian_mv or create_loss_hessian_mv(loss_fn, **kwargs)
    if vmap_over_data:
        loss_hessian_mv = jax.vmap(loss_hessian_mv)

    return loss_hessian_mv