laplax.curv.ggn
Generalized Gauss-Newton matrix-vector product and loss Hessian.
create_loss_hessian_mv ¶
create_loss_hessian_mv(loss_fn: LossFn | str | Callable[[PredArray, TargetArray], Num[Array, ...]] | None, **kwargs: Kwargs) -> Callable
Create a function to compute the Hessian-vector product for a specified loss function.
For predefined loss functions like cross-entropy and mean squared error, the function computes their corresponding Hessian-vector products using efficient formulations. For custom loss functions, the Hessian-vector product is computed via automatic differentiation.
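As a rough illustration of the two routes described above (not laplax's actual implementation), the sketch below computes a loss Hessian-vector product with respect to the predictions in plain JAX: once via automatic differentiation, and once via the closed form for softmax cross-entropy. The helper names `cross_entropy`, `hvp_autodiff`, and `hvp_cross_entropy_closed_form` are hypothetical and exist only for this example.

```python
import jax
import jax.numpy as jnp


def cross_entropy(logits, target):
    # Negative log-likelihood of the integer class `target` under softmax(logits).
    return -jax.nn.log_softmax(logits)[target]


def hvp_autodiff(loss, pred, target, v):
    # Forward-over-reverse: differentiate grad(loss) w.r.t. pred along direction v.
    grad_fn = lambda p: jax.grad(loss)(p, target)
    return jax.jvp(grad_fn, (pred,), (v,))[1]


def hvp_cross_entropy_closed_form(logits, v):
    # The Hessian of softmax cross-entropy w.r.t. the logits is diag(p) - p p^T,
    # so the product with v needs no explicit matrix.
    p = jax.nn.softmax(logits)
    return p * v - p * jnp.vdot(p, v)


logits = jnp.array([1.0, -0.5, 2.0])
v = jnp.array([0.3, 1.0, -2.0])

hv_ad = hvp_autodiff(cross_entropy, logits, 2, v)
hv_cf = hvp_cross_entropy_closed_form(logits, v)
```

Both routes should agree up to floating-point error. The closed form avoids tracing through the loss a second time, which is presumably why predefined losses are handled separately.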
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`loss_fn` | `LossFn \| str \| Callable[[PredArray, TargetArray], Num[Array, ...]] \| None` | Loss function to compute the Hessian-vector product for. Supported options are: | required |
`**kwargs` | `Kwargs` | Unused keyword arguments. | `{}` |
Returns:
Type | Description |
---|---|
`Callable` | A function that computes the Hessian-vector product for the given loss function. |
Raises:
Type | Description |
---|---|
`ValueError` | When |
`ValueError` | When an unsupported loss function (not of type: |
Source code in laplax/curv/ggn.py
create_ggn_mv_without_data ¶
create_ggn_mv_without_data(model_fn: ModelFn, params: Params, loss_fn: LossFn | str | Callable | None, factor: Float, *, vmap_over_data: bool = True, loss_hessian_mv: Callable | None = None) -> Callable[[Params, Data], Params]
Create a Generalized Gauss-Newton (GGN) matrix-vector product without fixed data.
The GGN matrix is computed using the Jacobian of the model and the Hessian of the loss function. The resulting product is given by:
\[
\text{factor} \cdot \sum_i J_i^\top H_{L, i} J_i \, v
\]
where \(J_i\) is the Jacobian of the model at data point \(i\), \(H_{L, i}\) is the Hessian of the loss, and \(v\) is the vector. The `factor` is a scaling factor applied to the GGN matrix.
This function computes the above expression efficiently without hardcoding the dataset, making it suitable for distributed or batched computations.
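To make the structure of this product concrete, here is a minimal sketch in plain JAX that forms `factor` times the sum over data points of \(J_i^\top H_{L, i} J_i v\), using a Jacobian-vector product followed by a vector-Jacobian product, with the data passed at call time. It is not laplax's implementation. The toy `model_fn`, the `loss_hessian_mv` signature, the `make_ggn_mv` helper, and the `"input"`/`"target"` dictionary keys are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def model_fn(params, x):
    # Hypothetical toy model: one linear layer producing 2 logits.
    return params["w"] @ x + params["b"]


def loss_hessian_mv(v, pred):
    # Assumed signature. Closed-form Hessian-vector product of softmax
    # cross-entropy w.r.t. the logits (it does not depend on the target).
    p = jax.nn.softmax(pred)
    return p * v - p * jnp.vdot(p, v)


def make_ggn_mv(model_fn, params, loss_hessian_mv, factor):
    def per_example(vec, x):
        f = lambda p: model_fn(p, x)
        pred, jv = jax.jvp(f, (params,), (vec,))  # J_i v
        hjv = loss_hessian_mv(jv, pred)           # H_{L,i} J_i v
        _, vjp_fn = jax.vjp(f, params)
        return vjp_fn(hjv)[0]                     # J_i^T H_{L,i} J_i v

    def ggn_mv(vec, data):
        # vmap over the batch dimension of the inputs, then sum and scale.
        per = jax.vmap(per_example, in_axes=(None, 0))(vec, data["input"])
        return jax.tree_util.tree_map(lambda g: factor * g.sum(axis=0), per)

    return ggn_mv


params = {"w": jnp.array([[0.5, -1.0, 0.2], [0.1, 0.3, -0.7]]),
          "b": jnp.array([0.0, 0.1])}
vec = {"w": jnp.ones((2, 3)), "b": jnp.array([1.0, -1.0])}
data = {"input": jnp.array([[1.0, 0.0, 2.0], [0.5, -1.0, 1.0]]),
        "target": jnp.array([0, 1])}
factor = 0.5

ggn_mv = make_ggn_mv(model_fn, params, loss_hessian_mv, factor)
out = ggn_mv(vec, data)  # same pytree structure as params
```

Because `data` is an argument of `ggn_mv` rather than a captured value, the same function can be applied to different batches.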
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`model_fn` | `ModelFn` | The model's forward pass function. | required |
`params` | `Params` | Model parameters. | required |
`loss_fn` | `LossFn \| str \| Callable \| None` | Loss function to use for the GGN computation. | required |
`factor` | `Float` | Scaling factor for the GGN computation. | required |
`vmap_over_data` | `bool` | Whether to vmap over the data. Defaults to True. | `True` |
`loss_hessian_mv` | `Callable \| None` | The loss Hessian matrix-vector product. | `None` |
Returns:
Type | Description |
---|---|
`Callable[[Params, Data], Params]` | A function that takes a vector and a batch of data, and computes the GGN matrix-vector product. |
Note
By default, the function assumes that the data has a batch dimension.
Source code in laplax/curv/ggn.py
create_ggn_mv ¶
create_ggn_mv(model_fn: ModelFn, params: Params, data: Data, loss_fn: LossFn | str | Callable | None = None, *, num_curv_samples: Int | None = None, num_total_samples: Int | None = None, vmap_over_data: bool = True, loss_hessian_mv: Callable | None = None) -> Callable[[Params], Params]
Create a Generalized Gauss-Newton (GGN) matrix-vector product with fixed data.
The GGN matrix is computed using the Jacobian of the model and the Hessian of the loss function. For a given dataset, the GGN matrix-vector product is computed as:
\[
G(\theta) \, v = \text{factor} \cdot \sum_{i=1}^{N} J_i^\top \nabla^2_{f(x_i, \theta), f(x_i, \theta)} \mathcal{L}_i(f(x_i, \theta), y_i) \, J_i \, v
\]
where \(J_i\) is the Jacobian of the model for the \(i\)-th data point, \(\nabla^2_{f(x_i, \theta), f(x_i, \theta)}\mathcal{L}_i(f(x_i, \theta), y_i)\) is the Hessian of the loss for the \(i\)-th data point, \(N\) is the number of data points, and \(v\) is the vector being multiplied. The `factor` is a scaling factor applied to the GGN matrix.
This function hardcodes the dataset, making it ideal for scenarios where the dataset remains fixed.
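The difference between the data-free and fixed-data variants can be sketched as a simple closure. The example below is only an illustration under stated assumptions: `ggn_mv_without_data` is a toy stand-in (a linear model with loss `0.5 * (f - y)**2`, so the GGN reduces to `factor * X^T X`), and `fix_data` is a hypothetical helper, not part of laplax.

```python
import jax
import jax.numpy as jnp


def ggn_mv_without_data(vec, data, factor=1.0):
    # Toy stand-in: for f(x) = x @ theta with loss 0.5 * (f - y)^2 the loss
    # Hessian is 1, so the GGN-vector product is factor * X^T (X vec).
    x = data["input"]
    return factor * (x.T @ (x @ vec))


def fix_data(mv, data):
    # Close over the dataset so the returned function only takes the vector.
    def mv_fixed(vec):
        return mv(vec, data)

    return mv_fixed


data = {"input": jnp.array([[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]),
        "target": jnp.array([1.0, 0.0, 2.0])}

ggn_mv = fix_data(ggn_mv_without_data, data)

# With the data fixed, the dense 2x2 GGN can be materialised by applying the
# product to each standard basis vector (row i of the vmap output is G e_i).
dense_ggn = jax.vmap(ggn_mv)(jnp.eye(2)).T
```

Fixing the data up front gives a function of the vector alone, which is the shape that iterative solvers and dense materialisation routines typically expect.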
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`model_fn` | `ModelFn` | The model's forward pass function. | required |
`params` | `Params` | Model parameters. | required |
`data` | `Data` | A batch of input and target data. | required |
`loss_fn` | `LossFn \| str \| Callable \| None` | Loss function to use for the GGN computation. | `None` |
`num_curv_samples` | `Int \| None` | Number of samples used to calculate the GGN. Defaults to None, in which case it is inferred from | `None` |
`num_total_samples` | `Int \| None` | Number of total samples the model was trained on. See the remark in | `None` |
`vmap_over_data` | `bool` | Whether to vmap over the data. Defaults to True. | `True` |
`loss_hessian_mv` | `Callable \| None` | The loss Hessian matrix-vector product. If not provided, it is computed using the | `None` |
Returns:
Type | Description |
---|---|
`Callable[[Params], Params]` | A function that takes a vector and computes the GGN matrix-vector product for the given data. |
Raises:
Type | Description |
---|---|
`ValueError` | If both |
`ValueError` | If neither |
Note
By default, the function assumes that the data has a batch dimension.
Source code in laplax/curv/ggn.py