laplax.curv.loss
Loss Gradients and Hessians.
fetch_loss_gradient_fn
fetch_loss_gradient_fn(loss_fn: LossFn | str | Callable[[PredArray, TargetArray], Num[Array, ...]] | None, loss_gradient_fn: Callable | None, *, handle_batches: bool = False, **kwargs: Kwargs) -> Callable[[PredArray, TargetArray], Num[Array, ...]]
Fetch a loss gradient function from the given arguments.
If `loss_gradient_fn` is passed, it is returned unchanged. If a known `LossFn` is passed, its analytic gradient is returned. If a custom `Callable` is passed, the gradient is computed via automatic differentiation.
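The three-way dispatch above can be sketched in plain JAX as follows. This is a minimal illustration, not the laplax implementation: `fetch_grad_sketch`, the string key `"mse"` and the sum-of-squares normalisation of `mse_loss` are all hypothetical choices made for the example, and the validation listed under Raises is not reproduced. The point is that the analytic gradient of a known loss and the autodiff gradient of the same loss written as a custom callable agree.

```python
import jax
import jax.numpy as jnp


def mse_loss(pred, target):
    # Sum-of-squares error; this normalisation is an assumption of the sketch.
    return jnp.sum((pred - target) ** 2)


def mse_grad(pred, target):
    # Closed-form gradient of mse_loss with respect to the predictions.
    return 2.0 * (pred - target)


def fetch_grad_sketch(loss_fn, loss_gradient_fn=None):
    """Hypothetical stand-in for the dispatch described above."""
    if loss_gradient_fn is not None:
        return loss_gradient_fn  # 1. user-supplied gradient is used as-is
    if loss_fn == "mse":  # hypothetical key for a known loss
        return mse_grad  # 2. known loss: analytic gradient
    if callable(loss_fn):
        # 3. custom loss: autodiff with respect to the predictions (argument 0)
        return jax.grad(loss_fn, argnums=0)
    raise ValueError("unsupported loss function")


pred = jnp.array([0.5, -1.0, 2.0])
target = jnp.array([0.0, 0.0, 1.0])
analytic = fetch_grad_sketch("mse")(pred, target)
autodiff = fetch_grad_sketch(mse_loss)(pred, target)
print(analytic, autodiff)  # both equal 2 * (pred - target)
```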
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loss_fn` | `LossFn \| str \| Callable[[PredArray, TargetArray], Num[Array, ...]] \| None` | Loss function to compute the gradient for. Supported options are: | *required* |
| `loss_gradient_fn` | `Callable \| None` | Custom precomputed loss gradient to use. | *required* |
| `handle_batches` | `bool` | Whether the loss gradient function should handle batches. | `False` |
| `**kwargs` | `Kwargs` | Unused keyword arguments. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Callable[[PredArray, TargetArray], Num[Array, ...]]` | A function that computes the loss gradient given predictions and targets. If `handle_batches=True`, it takes batches of predictions and targets and returns a batch of gradients. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If both |
| `ValueError` | If neither |
| `ValueError` | When an unsupported loss function is provided. |
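The `handle_batches=True` behaviour described under Returns can be illustrated by mapping a per-example gradient over a leading batch axis. This sketch assumes `jax.vmap` as the batching mechanism; whether laplax batches this way internally is an assumption, and `cross_entropy` is a hypothetical single-example loss defined only for the example. The check uses the known identity that the cross-entropy gradient with respect to the logits is `softmax(logits) - one_hot(target)`.

```python
import jax
import jax.numpy as jnp


def cross_entropy(logits, target):
    # Hypothetical single-example loss: logits of shape (C,), integer class target.
    return -jax.nn.log_softmax(logits)[target]


# Gradient for a single (prediction, target) pair, w.r.t. the logits only.
per_example_grad = jax.grad(cross_entropy, argnums=0)

# Batched analogue (assumed to mirror handle_batches=True): map over axis 0 of both inputs.
batched_grad = jax.vmap(per_example_grad, in_axes=(0, 0))

logits = jnp.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])  # shape (B=2, C=3)
targets = jnp.array([0, 2])  # shape (B,)
grads = batched_grad(logits, targets)  # shape (B, C)

# Sanity check against the closed form softmax(logits) - one_hot(target).
expected = jax.nn.softmax(logits, axis=-1) - jax.nn.one_hot(targets, 3)
print(grads.shape, jnp.allclose(grads, expected, atol=1e-6))
```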
create_loss_hessian_mv
create_loss_hessian_mv(loss_fn: LossFn | str | Callable[[PredArray, TargetArray], Num[Array, ...]] | None, **kwargs: Kwargs) -> Callable
Create a function to compute the Hessian-vector product for a specified loss function.
For predefined loss functions like cross-entropy and mean squared error, the function computes their corresponding Hessian-vector products using efficient formulations. For custom loss functions, the Hessian-vector product is computed via automatic differentiation.
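The difference between an efficient closed-form Hessian-vector product and the automatic-differentiation fallback can be sketched as below. The helper names are hypothetical and the sum-of-squares form of the MSE loss is an assumption; only the mathematics is standard. For cross-entropy on logits with `p = softmax(logits)` the Hessian is `diag(p) - p p^T`, so `H v = p * v - p * <p, v>`, which never forms the matrix. The generic fallback uses forward-over-reverse differentiation (`jax.jvp` of `jax.grad`), which also avoids materialising the Hessian.

```python
import jax
import jax.numpy as jnp


def cross_entropy(logits, target):
    # Hypothetical single-example cross-entropy with an integer class label.
    return -jax.nn.log_softmax(logits)[target]


def ce_hessian_mv(logits, v):
    # Closed form: H = diag(p) - p p^T with p = softmax(logits), hence
    # H v = p * v - p * <p, v>. The result does not depend on the target.
    p = jax.nn.softmax(logits)
    return p * v - p * jnp.dot(p, v)


def mse_loss(pred, target):
    # Sum-of-squares loss (normalisation assumed), whose Hessian is 2 * I.
    return jnp.sum((pred - target) ** 2)


def mse_hessian_mv(v):
    return 2.0 * v


def autodiff_hessian_mv(loss_fn, pred, target, v):
    # Generic fallback: forward-over-reverse HVP w.r.t. the prediction.
    def grad_fn(x):
        return jax.grad(loss_fn)(x, target)

    return jax.jvp(grad_fn, (pred,), (v,))[1]


logits = jnp.array([1.0, -0.5, 2.0])
v = jnp.array([0.3, 1.0, -2.0])
ce_closed = ce_hessian_mv(logits, v)
ce_auto = autodiff_hessian_mv(cross_entropy, logits, 1, v)
mse_closed = mse_hessian_mv(v)
mse_auto = autodiff_hessian_mv(mse_loss, logits, jnp.zeros(3), v)
print(jnp.allclose(ce_closed, ce_auto, atol=1e-5), jnp.allclose(mse_closed, mse_auto))
```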
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loss_fn` | `LossFn \| str \| Callable[[PredArray, TargetArray], Num[Array, ...]] \| None` | Loss function to compute the Hessian-vector product for. Supported options are: | *required* |
| `**kwargs` | `Kwargs` | Unused keyword arguments. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Callable` | A function that computes the Hessian-vector product for the given loss function. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | When |
| `ValueError` | When an unsupported loss function (not of type: |
fetch_loss_hessian_mv
fetch_loss_hessian_mv(loss_fn: LossFn | str | Callable[[PredArray, TargetArray], Num[Array, ...]] | None, loss_hessian_mv: Callable | None, *, vmap_over_data: bool = False, **kwargs: Kwargs) -> Callable
Fetch the loss Hessian-vector product function given either `loss_fn` or `loss_hessian_mv`.
For predefined loss functions like cross-entropy and mean squared error, the function computes their corresponding Hessian-vector products using efficient formulations. For custom loss functions, the Hessian-vector product is computed via automatic differentiation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loss_fn` | `LossFn \| str \| Callable[[PredArray, TargetArray], Num[Array, ...]] \| None` | Loss function to compute the Hessian-vector product for. Supported options are: | *required* |
| `loss_hessian_mv` | `Callable \| None` | Precomputed loss Hessian-vector product function to use. | *required* |
| `vmap_over_data` | `bool` | Whether to `vmap` over the data. | `False` |
| `**kwargs` | `Kwargs` | Unused keyword arguments. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Callable` | A function that computes the Hessian-vector product for the given loss function. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If both |
| `ValueError` | If neither |
| `ValueError` | When an unsupported loss function is provided. |
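The fetch logic and the `vmap_over_data` flag can be sketched as follows. Everything here is hypothetical: `fetch_hessian_mv_sketch` is not the laplax function, the `(v, pred, target)` argument order of the returned callable is an assumption, and the validation listed under Raises is only partially mirrored (the sketch rejects the case where neither argument is given). It shows a precomputed product being passed through, an autodiff product being built from a custom loss otherwise, and `jax.vmap` lifting the per-example product to a batch.

```python
import jax
import jax.numpy as jnp


def mse(pred, target):
    # Single-example sum-of-squares loss (normalisation is an assumption).
    return jnp.sum((pred - target) ** 2)


def fetch_hessian_mv_sketch(loss_fn=None, loss_hessian_mv=None, *, vmap_over_data=False):
    """Hypothetical stand-in for the fetch logic; argument order (v, pred, target) is assumed."""
    hvp = loss_hessian_mv  # a precomputed product is used unchanged
    if hvp is None:
        if loss_fn is None:
            raise ValueError("pass either loss_fn or loss_hessian_mv")

        def hvp(v, pred, target):
            # Forward-over-reverse HVP of the loss w.r.t. the prediction.
            def grad_fn(x):
                return jax.grad(loss_fn)(x, target)

            return jax.jvp(grad_fn, (pred,), (v,))[1]

    if vmap_over_data:
        # Map over a leading batch axis of vectors, predictions and targets.
        return jax.vmap(hvp, in_axes=(0, 0, 0))
    return hvp


preds = jnp.array([[1.0, 2.0], [0.0, -1.0], [3.0, 3.0]])  # shape (B=3, D=2)
targets = jnp.zeros((3, 2))
vs = jnp.ones((3, 2))
batched = fetch_hessian_mv_sketch(mse, vmap_over_data=True)
out = batched(vs, preds, targets)
print(out)  # every entry equals 2.0, since the MSE Hessian here is 2 * I
```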