laplax.curv.ggn
Generalized Gauss-Newton matrix-vector product.
create_ggn_mv_without_data
create_ggn_mv_without_data(model_fn: ModelFn, params: Params, loss_fn: LossFn | str | Callable | None, factor: Float, *, vmap_over_data: bool = True, loss_hessian_mv: Callable | None = None) -> Callable[[Params, Data], Params]
Create Generalized Gauss-Newton (GGN) matrix-vector product without fixed data.
The GGN matrix is computed using the Jacobian of the model and the Hessian of the loss function. The resulting product is given by:
\[
\text{factor} \cdot \sum_i J_i^\top H_{L, i} J_i \, v
\]
where \(J_i\) is the Jacobian of the model at data point \(i\), \(H_{L, i}\) is the Hessian of the loss with respect to the model output at data point \(i\), and \(v\) is the vector. The factor argument is a scalar that scales the GGN matrix.
This function computes the above expression efficiently without hardcoding the dataset, making it suitable for distributed or batched computations.
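The matrix-free structure of this product can be sketched in plain JAX: a forward-mode pass gives \(J_i v\), a Hessian-vector product of the loss is applied to it, and a reverse-mode pass applies \(J_i^\top\). This is a minimal illustration, not the laplax implementation. The names `make_ggn_mv`, `model_fn`, and `loss_fn`, the tiny tanh model, and the `"input"`/`"target"` keys of the data dictionary are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def model_fn(params, x):
    # Hypothetical tiny model acting on a single (unbatched) input x.
    return jnp.tanh(params["W"] @ x + params["b"])


def loss_fn(f, y):
    # Sum-of-squares loss; its Hessian w.r.t. the model output f is 2 * I.
    return jnp.sum((f - y) ** 2)


def make_ggn_mv(model_fn, params, loss_fn, factor):
    """Return ggn_mv(vec, data) without binding any dataset (sketch only)."""

    def ggn_mv(vec, data):
        def per_sample(x, y):
            fwd = lambda p: model_fn(p, x)
            # Forward mode: model output f and J_i v.
            f, jv = jax.jvp(fwd, (params,), (vec,))
            # Loss Hessian-vector product H_{L,i} (J_i v) w.r.t. the output.
            grad_loss = jax.grad(lambda z: loss_fn(z, y))
            hjv = jax.jvp(grad_loss, (f,), (jv,))[1]
            # Reverse mode: J_i^T (H_{L,i} J_i v).
            _, vjp_fn = jax.vjp(fwd, params)
            return vjp_fn(hjv)[0]

        # Map over the leading batch axis, then sum the per-sample terms.
        per = jax.vmap(per_sample)(data["input"], data["target"])
        return jax.tree_util.tree_map(lambda t: factor * t.sum(axis=0), per)

    return ggn_mv


params = {"W": jnp.array([[0.5, -0.3], [0.1, 0.8]]), "b": jnp.array([0.0, 0.1])}
data = {
    "input": jnp.array([[1.0, 2.0], [0.5, -1.0], [0.0, 0.3]]),
    "target": jnp.array([[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]),
}
vec = {"W": jnp.ones((2, 2)), "b": jnp.ones(2)}

ggn_mv = make_ggn_mv(model_fn, params, loss_fn, factor=1.0)
out = ggn_mv(vec, data)  # same pytree structure as params
```

Because the data is passed at call time, the same `ggn_mv` can be reused across mini-batches, and the full GGN matrix is never materialized.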
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model_fn` | `ModelFn` | The model's forward pass function. | required |
| `params` | `Params` | Model parameters. | required |
| `loss_fn` | `LossFn \| str \| Callable \| None` | Loss function to use for the GGN computation. | required |
| `factor` | `Float` | Scaling factor for the GGN computation. | required |
| `vmap_over_data` | `bool` | Whether to vmap over the data. Defaults to True. | `True` |
| `loss_hessian_mv` | `Callable \| None` | The loss Hessian matrix-vector product. | `None` |
Returns:
| Type | Description |
|---|---|
| `Callable[[Params, Data], Params]` | A function that takes a vector and a batch of data, and computes the GGN matrix-vector product. |
Note
By default, the function assumes that the data has a batch dimension.
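To illustrate what a loss Hessian matrix-vector product such as `loss_hessian_mv` computes, the sketch below compares a closed-form product for softmax cross-entropy with one obtained by automatic differentiation. The function names and their signatures (`ce_hessian_mv(jv, pred)` and `autodiff_hessian_mv(jv, pred, target)`) are hypothetical. The calling convention laplax actually expects is not stated here and should be checked in the source.

```python
import jax
import jax.numpy as jnp


def cross_entropy(logits, target):
    # Softmax cross-entropy for a single example with an integer class label.
    return -jax.nn.log_softmax(logits)[target]


def ce_hessian_mv(jv, pred):
    # Closed form: H = diag(p) - p p^T with p = softmax(pred),
    # so H @ jv = p * jv - p * (p . jv). The Hessian does not depend on the target.
    p = jax.nn.softmax(pred)
    return p * jv - p * jnp.vdot(p, jv)


def autodiff_hessian_mv(jv, pred, target):
    # Generic fallback: forward-over-reverse Hessian-vector product.
    grad_fn = jax.grad(lambda z: cross_entropy(z, target))
    return jax.jvp(grad_fn, (pred,), (jv,))[1]


logits = jnp.array([1.0, -0.5, 2.0])
jv = jnp.array([0.3, -1.0, 0.7])

closed = ce_hessian_mv(jv, logits)
auto = autodiff_hessian_mv(jv, logits, target=2)
```

A closed-form product avoids one extra differentiation pass through the loss, which is the usual motivation for supplying it explicitly.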
Source code in laplax/curv/ggn.py
create_ggn_mv
create_ggn_mv(model_fn: ModelFn, params: Params, data: Data, loss_fn: LossFn | str | Callable | None = None, *, num_curv_samples: Int | None = None, num_total_samples: Int | None = None, vmap_over_data: bool = True, loss_hessian_mv: Callable | None = None) -> Callable[[Params], Params]
Computes the Generalized Gauss-Newton (GGN) matrix-vector product with data.
The GGN matrix is computed using the Jacobian of the model and the Hessian of the loss function. For a given dataset, the GGN matrix-vector product is computed as:
\[
\text{factor} \cdot \sum_{i=1}^{N} J_i^\top \, \nabla^2_{f(x_i, \theta), f(x_i, \theta)} \mathcal{L}_i(f(x_i, \theta), y_i) \, J_i \, v
\]
where \(J_i\) is the Jacobian of the model for the \(i\)-th data point, \(\nabla^2_{f(x_i, \theta), f(x_i, \theta)}\mathcal{L}_i(f(x_i, \theta), y_i)\) is the Hessian of the loss with respect to the model output for the \(i\)-th data point, \(v\) is the vector, and \(N\) is the number of data points. The factor is a scalar that scales the GGN matrix.
This function hardcodes the dataset, making it ideal for scenarios where the dataset remains fixed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model_fn` | `ModelFn` | The model's forward pass function. | required |
| `params` | `Params` | Model parameters. | required |
| `data` | `Data` | A batch of input and target data. | required |
| `loss_fn` | `LossFn \| str \| Callable \| None` | Loss function to use for the GGN computation. | `None` |
| `num_curv_samples` | `Int \| None` | Number of samples used to calculate the GGN. Defaults to None, in which case it is inferred from | `None` |
| `num_total_samples` | `Int \| None` | Number of total samples the model was trained on. See the remark in | `None` |
| `vmap_over_data` | `bool` | Whether to vmap over the data. Defaults to True. | `True` |
| `loss_hessian_mv` | `Callable \| None` | The loss Hessian matrix-vector product. If not provided, it is computed using the | `None` |
Returns:
| Type | Description |
|---|---|
| `Callable[[Params], Params]` | A function that takes a vector and computes the GGN matrix-vector product for the given data. |
Note
By default, the function assumes that the data has a batch dimension.
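The difference between the data-free and the data-bound variant can be sketched as a simple partial application. The example uses a linear model with a squared-error loss so that the product reduces to \(X^\top X v\). The helper `bind_data`, the scaling `num_total_samples / num_curv_samples`, and the way the `None` defaults are resolved are assumptions of this sketch, not confirmed behavior of `create_ggn_mv`.

```python
import functools

import jax.numpy as jnp


def ggn_mv_without_data(vec, data, *, factor):
    # Linear model f(x) = x @ w with loss 0.5 * (f - y)^2:
    # J_i = x_i^T and H_{L,i} = 1, so the GGN product is factor * X^T (X v).
    X = data["input"]
    return factor * (X.T @ (X @ vec))


def bind_data(data, *, num_curv_samples=None, num_total_samples=None):
    # Assumption: missing sample counts fall back to the batch size.
    if num_curv_samples is None:
        num_curv_samples = data["input"].shape[0]
    if num_total_samples is None:
        num_total_samples = num_curv_samples
    # Assumption: rescale the batch GGN to the size of the full training set.
    factor = num_total_samples / num_curv_samples
    # Fix the dataset and factor so that only the vector remains as an argument.
    return functools.partial(ggn_mv_without_data, data=data, factor=factor)


data = {
    "input": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    "target": jnp.zeros(3),
}
ggn_mv = bind_data(data, num_total_samples=6)
out = ggn_mv(jnp.array([1.0, -1.0]))
```

Binding the data once yields a callable of a single vector argument, which is the form typically expected by iterative solvers or low-rank approximation routines.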