laplax.curv.hessian
Hessian vector product for curvature estimation.
hvp
Compute the Hessian-vector product (HVP) for a given function.
The Hessian-vector product is computed by differentiating the gradient of the function. This avoids explicitly constructing the Hessian matrix, making the computation efficient.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func | Callable | The scalar function for which the HVP is computed. | required |
primals | PyTree | The point at which the gradient and Hessian are evaluated. | required |
tangents | PyTree | The vector to multiply with the Hessian. | required |
Returns:
Type | Description |
---|---|
PyTree | The Hessian-vector product. |
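The description above (differentiating the gradient without ever forming the Hessian) can be illustrated with the standard forward-over-reverse pattern in JAX. This is a minimal sketch, not necessarily the implementation in laplax/curv/hessian.py; the names `hvp_sketch` and `quadratic` are hypothetical and exist only for this example.

```python
import jax
import jax.numpy as jnp


def hvp_sketch(func, primals, tangents):
    """Forward-over-reverse HVP: push `tangents` through the gradient of `func`."""
    # jax.jvp returns (grad(func)(primals), J_grad(primals) @ tangents);
    # the second entry is the Hessian-vector product.
    return jax.jvp(jax.grad(func), (primals,), (tangents,))[1]


def quadratic(x):
    # f(x) = 0.5 * x^T A x with symmetric A, so the Hessian is exactly A.
    A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
    return 0.5 * x @ A @ x


x = jnp.array([1.0, -1.0])
v = jnp.array([1.0, 0.0])
hv = hvp_sketch(quadratic, x, v)  # equals A @ v, the first column of A
```

Because only a gradient and one forward-mode pass are evaluated, the cost is a small multiple of a single gradient evaluation, and the full Hessian is never stored.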
Source code in laplax/curv/hessian.py
create_hessian_mv_without_data
create_hessian_mv_without_data(model_fn: ModelFn, params: Params, loss_fn: LossFn | str | Callable, factor: Float, *, vmap_over_data: bool = True, **kwargs: Kwargs) -> Callable[[Params, Data], Params]
Computes the Hessian-vector product (HVP) for a model and loss function.
This function computes the HVP by combining the model and loss functions into a single callable. It evaluates the Hessian at the provided model parameters, with respect to the model and loss function.
Mathematically:
\[
H v = \nabla_\theta^2 \mathcal{L}\big(f(x; \theta), y\big)\, v,
\]
where \(\mathcal{L}\) is the combined loss function, \(f\) is the model function, \(x\) is the input, \(y\) is the target, \(\theta\) are the parameters, and \(v\) is the input vector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_fn | ModelFn | The model function to evaluate. | required |
params | Params | The parameters of the model. | required |
loss_fn | LossFn \| str \| Callable | The loss function to apply. Supported options are: | required |
factor | Float | Scaling factor for the Hessian computation. | required |
vmap_over_data | bool | Whether the model function should be vectorized over the data. | True |
**kwargs | Kwargs | Additional arguments (ignored). | {} |
Returns:
Type | Description |
---|---|
Callable[[Params, Data], Params] | A function that computes the HVP for a given vector and batch of data. |
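To make the idea of combining the model and loss into a single callable concrete, here is a minimal sketch. It rests on several assumptions that the text above does not confirm: that the model is called as `model_fn(x, params)` on a single example, that the data is a dict with `"input"` and `"target"` keys, and that `factor` simply multiplies the result. The helper `make_hessian_mv_sketch` and the toy model and loss are hypothetical, not part of laplax.

```python
import jax
import jax.numpy as jnp


def make_hessian_mv_sketch(model_fn, params, loss_fn, factor=1.0):
    """Hypothetical helper: returns mv(vector, data) = factor * H(params; data) @ vector."""

    def mv(vector, data):
        def total_loss(p):
            # Assumption: model_fn acts on one example, so vmap over the batch axis.
            preds = jax.vmap(lambda x: model_fn(x, p))(data["input"])
            return loss_fn(preds, data["target"])

        # Forward-over-reverse HVP with respect to the parameters.
        hv = jax.jvp(jax.grad(total_loss), (params,), (vector,))[1]
        return jax.tree_util.tree_map(lambda leaf: factor * leaf, hv)

    return mv


def linear_model(x, p):
    # Toy scalar-output model, for illustration only.
    return p["w"] @ x


def half_sse(preds, targets):
    # Toy loss: half the sum of squared errors.
    return 0.5 * jnp.sum((preds - targets) ** 2)


params = {"w": jnp.zeros(2)}
data = {"input": jnp.array([[1.0, 0.0], [0.0, 2.0]]), "target": jnp.array([1.0, 1.0])}

mv = make_hessian_mv_sketch(linear_model, params, half_sse, factor=0.5)
# For this toy problem the Hessian is X^T X = diag(1, 4).
hv_scaled = mv({"w": jnp.array([1.0, 1.0])}, data)
```

The vector passed to `mv` must have the same PyTree structure as `params`, which matches the `Callable[[Params, Data], Params]` return type documented above.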
Source code in laplax/curv/hessian.py
create_hessian_mv
create_hessian_mv(model_fn: ModelFn, params: Params, data: Data, loss_fn: LossFn | str | Callable, *, num_curv_samples: Int | None = None, num_total_samples: Int | None = None, vmap_over_data: bool = True, **kwargs: Kwargs) -> Callable[[Params], Params]
Computes the Hessian-vector product (HVP) for a model and loss function with fixed data.
This function wraps create_hessian_mv_without_data, fixing the dataset to produce a function that computes the HVP for the specified data.
Mathematically:
\[
H v = \nabla_\theta^2 \mathcal{L}\big(f(x; \theta), y\big)\, v,
\]
where \(\mathcal{L}\) is the combined loss function, \(f\) is the model function, \(x\) is the input, \(y\) is the target, \(\theta\) are the parameters, and \(v\) is the input vector of the HVP.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_fn | ModelFn | The model function to evaluate. | required |
params | Params | The parameters of the model. | required |
data | Data | A batch of input and target data. | required |
loss_fn | LossFn \| str \| Callable | The loss function to apply. Supported options are: | required |
num_curv_samples | Int \| None | Number of samples used to calculate the Hessian. Defaults to None, in which case it is inferred from | None |
num_total_samples | Int \| None | Number of total samples the model was trained on. See the remark in | None |
vmap_over_data | bool | Whether to vmap over the data. Defaults to True. | True |
**kwargs | Kwargs | Additional arguments. | {} |
Returns:
Type | Description |
---|---|
Callable[[Params], Params] | A function that computes the HVP for a given vector and the fixed dataset. |
Note
By default, the function assumes that the data has a batch dimension.
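Fixing the dataset, as described above, amounts to closing over the data so that the result takes only a vector. The sketch below shows that pattern with `functools.partial` on a toy problem. The function `hessian_mv_with_data`, the dict layout of `data`, and the toy loss are assumptions made for illustration; the sketch leaves out any scaling by `num_curv_samples` or `num_total_samples`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def hessian_mv_with_data(vector, data, *, params):
    """Hypothetical two-argument product: H(params; data) @ vector for a toy loss."""

    def loss(p):
        # Toy linear model with squared error; stands in for model_fn + loss_fn.
        preds = data["input"] @ p
        return 0.5 * jnp.sum((preds - data["target"]) ** 2)

    return jax.jvp(jax.grad(loss), (params,), (vector,))[1]


params = jnp.zeros(2)
data = {"input": jnp.array([[1.0, 0.0], [0.0, 2.0]]), "target": jnp.array([1.0, 1.0])}

# Fix the dataset (and parameters) once; the result maps a vector to H @ vector.
hessian_mv = partial(hessian_mv_with_data, data=data, params=params)

# For this toy problem the Hessian is X^T X = diag(1, 4).
hv_fixed = hessian_mv(jnp.array([1.0, 1.0]))
```

A one-argument matrix-vector product of this shape is convenient for iterative routines that only ever ask for products with the Hessian.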
Source code in laplax/curv/hessian.py