laplax.curv.utils
Utility functions for curvature estimation.
LowRankTerms dataclass ¶
Components of the low-rank curvature approximation.
This dataclass encapsulates the results of the low-rank approximation, including the eigenvectors, eigenvalues, and a scalar factor which can be used for the prior.
Attributes:
Name | Type | Description |
---|---|---|
U | Num[Array, 'P R'] | Matrix of eigenvectors, where each column corresponds to an eigenvector. |
S | Num[Array, ' R'] | Array of eigenvalues associated with the eigenvectors. |
scalar | Float[Array, ''] | Scalar factor added to the matrix during the approximation. |
Source code in laplax/curv/utils.py
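To make the three fields concrete, here is a minimal sketch of how such terms could be applied as an operator. It assumes the terms represent \(U \operatorname{diag}(S) U^\top + \text{scalar} \cdot I\), which this page does not state explicitly, and it uses a hypothetical stand-in class `LowRankSketch` rather than the library's `LowRankTerms`. The helper `low_rank_matvec` is likewise illustrative.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LowRankSketch(NamedTuple):
    """Hypothetical stand-in with the same three fields as LowRankTerms."""

    U: jax.Array  # shape (P, R): eigenvectors as columns
    S: jax.Array  # shape (R,): eigenvalues
    scalar: jax.Array  # shape (): scalar factor


def low_rank_matvec(terms, v):
    # Assumed reading: (U diag(S) U^T + scalar * I) @ v,
    # evaluated without ever forming the P x P matrix.
    return terms.U @ (terms.S * (terms.U.T @ v)) + terms.scalar * v


# Build a small symmetric PSD matrix of rank 3 and keep its top-3 eigenpairs.
key = jax.random.PRNGKey(0)
B = jax.random.normal(key, (6, 3))
A = B @ B.T
eigvals, eigvecs = jnp.linalg.eigh(A)  # eigenvalues in ascending order
R = 3
terms = LowRankSketch(U=eigvecs[:, -R:], S=eigvals[-R:], scalar=jnp.asarray(0.0))

v = jnp.arange(6.0)
approx = low_rank_matvec(terms, v)
exact = A @ v
```

Because `A` has rank 3 here, the rank-3 terms reproduce `A @ v` up to floating-point error. For a matrix of higher rank the result would only be an approximation.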
get_matvec ¶
get_matvec(A: Callable | Array, *, layout: Layout | None = None, jit: bool = True) -> tuple[Callable[[Array], Array], int]
Returns a function that computes the matrix-vector product.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
A | Callable \| Array | Either a jnp.ndarray or a callable performing the operation. | required |
layout | Layout \| None | Required if A is a callable. | None |
jit | bool | Whether to jit-compile the operator. | True |
Returns:
Type | Description |
---|---|
tuple[Callable[[Array], Array], int] | A tuple (matvec, input_dim) where matvec is the callable operator. |
Raises:
Type | Description |
---|---|
TypeError | When A is a callable and no layout is provided. |
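The contract described above (an array or a callable goes in, a `(matvec, input_dim)` pair comes out) can be sketched as follows. This is a hypothetical stand-in, not the library's implementation: `make_matvec_sketch` is an invented name, and it replaces the `layout` argument with a plain integer `input_dim`, on the assumption that a callable cannot reveal its own input size.

```python
import jax
import jax.numpy as jnp


def make_matvec_sketch(A, *, input_dim=None, jit=True):
    """Hypothetical stand-in mirroring the documented return contract."""
    if callable(A):
        # Assumption: the input size must be supplied for callables.
        if input_dim is None:
            raise TypeError("input_dim is required when A is a callable")
        matvec = A
    else:
        A = jnp.asarray(A)
        input_dim = A.shape[-1]  # number of columns fixes the input size

        def matvec(v):
            return A @ v

    return (jax.jit(matvec) if jit else matvec), input_dim


# Dense-matrix case: the dimension is read off the array.
dense = jnp.array([[2.0, 0.0], [0.0, 3.0]])
mv_dense, n_dense = make_matvec_sketch(dense)
out_dense = mv_dense(jnp.ones(2))


# Matrix-free case: the operator is only available as a function.
def double(v):
    return 2.0 * v


mv_free, n_free = make_matvec_sketch(double, input_dim=3)
out_free = mv_free(jnp.ones(3))
```

Returning the dimension alongside the operator lets downstream code, such as an iterative eigensolver, allocate vectors of the right size without inspecting `A` again.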
log_sigmoid_cross_entropy ¶
Computes log sigmoid cross entropy given logits and targets.
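This function computes the cross entropy loss between the sigmoid of the logits and the target values. The formula implemented is:
\[
\mathcal{L}(z, y) = -y \log \sigma(z) - (1 - y) \log\bigl(1 - \sigma(z)\bigr)
\]
where \(z\) denotes the logits, \(y\) the targets, and \(\sigma\) the sigmoid function.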
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logits | Num[Array, ...] | The predicted logits before sigmoid activation. | required |
targets | Num[Array, ...] | The target values (0 or 1). | required |
Returns:
Type | Description |
---|---|
Num[Array, ...] | The computed loss value. |
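As an illustration of the loss described above, the following sketch evaluates the usual binary cross entropy on logits in a numerically stable way. It is not the library's code: `sigmoid_cross_entropy_sketch` is a hypothetical name, and the sketch relies on the identity \(1 - \sigma(z) = \sigma(-z)\) together with `jax.nn.log_sigmoid`.

```python
import jax
import jax.numpy as jnp


def sigmoid_cross_entropy_sketch(logits, targets):
    # -y * log(sigmoid(z)) - (1 - y) * log(sigmoid(-z)),
    # using 1 - sigmoid(z) == sigmoid(-z) to avoid log(0) for large |z|.
    return (
        -targets * jax.nn.log_sigmoid(logits)
        - (1.0 - targets) * jax.nn.log_sigmoid(-logits)
    )


logits = jnp.array([-2.0, 0.0, 3.0])
targets = jnp.array([0.0, 1.0, 1.0])
loss = sigmoid_cross_entropy_sketch(logits, targets)

# Naive reference that takes the log of explicit probabilities.
p = jax.nn.sigmoid(logits)
naive = -targets * jnp.log(p) - (1.0 - targets) * jnp.log(1.0 - p)

# A confidently wrong prediction stays finite in the stable form.
extreme = sigmoid_cross_entropy_sketch(jnp.array([100.0]), jnp.array([0.0]))
```

The loss is elementwise, so the output has the same shape as the inputs. Any reduction over the batch is left to the caller.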
concatenate_model_and_loss_fn ¶
concatenate_model_and_loss_fn(model_fn: ModelFn, loss_fn: LossFn | str | Callable, *, vmap_over_data: bool = False) -> Callable[[InputArray, TargetArray, Params], Num[Array, ...]]
Combine a model function and a loss function into a single callable.
This creates a new function that evaluates the model and applies the specified loss function. If vmap_over_data is True, the model function is vectorized over the batch dimension using jax.vmap.
Mathematically, the combined function computes:
\[
(x, y, \theta) \mapsto \mathcal{L}(f(x, \theta), y)
\]
where \(f\) is the model function, \(\theta\) are the model parameters, \(x\) is the input, \(y\) is the target, and \(\mathcal{L}\) is the loss function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_fn | ModelFn | The model function to evaluate. | required |
loss_fn | LossFn \| str \| Callable | The loss function to apply, given as a LossFn member, a string, or a callable. | required |
vmap_over_data | bool | Whether the model function should be vectorized over the data. | False |
Returns:
Type | Description |
---|---|
Callable[[InputArray, TargetArray, Params], Num[Array, ...]] | A combined function that computes the loss for given inputs, targets, and parameters. |
Raises:
Type | Description |
---|---|
ValueError | When the loss function is unknown. |
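The composition described above can be sketched in a few lines. This is a hypothetical stand-in, not the library's implementation: `combine_model_and_loss` is an invented name, it accepts only a callable loss, and it assumes the calling conventions `model_fn(x, params)` and `loss_fn(prediction, target)`. The toy model and loss are likewise illustrative.

```python
import jax
import jax.numpy as jnp


def combine_model_and_loss(model_fn, loss_fn, *, vmap_over_data=False):
    """Hypothetical stand-in; assumes model_fn(x, params), loss_fn(pred, y)."""
    if vmap_over_data:
        # Map over the leading axis of the inputs; parameters are shared.
        model_fn = jax.vmap(model_fn, in_axes=(0, None))

    def combined(x, y, params):
        return loss_fn(model_fn(x, params), y)

    return combined


def linear_model(x, params):
    # Single-example model: x has shape (D,), the output is a scalar.
    return jnp.dot(x, params["w"]) + params["b"]


def sum_squared_error(pred, target):
    return jnp.sum((pred - target) ** 2)


params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.5)}
x = jnp.array([[1.0, 2.0], [3.0, 1.0]])  # batch of two inputs
y = jnp.array([0.0, 2.0])

loss_fn = combine_model_and_loss(linear_model, sum_squared_error, vmap_over_data=True)
loss = loss_fn(x, y, params)

# The argument order (x, y, params) makes params argument 2 for jax.grad.
grads = jax.grad(loss_fn, argnums=2)(x, y, params)
```

Having a single callable of inputs, targets, and parameters is convenient for curvature estimation, since derivatives with respect to the parameters can be taken directly on it.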