# laplax.eval.likelihood
Compute the marginal log-likelihood for different curvature approximations.
Implemented according to: Immer, A., Bauer, M., Fortuin, V., Rätsch, G., & Khan, M. E. (2021): Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning. Proceedings of the 38th International Conference on Machine Learning (ICML), PMLR 139.
It includes functions to calculate the marginal log-likelihood based on various curvature approximations (a short sketch contrasting their log-determinant computations follows the list):
- full
- diagonal
- low-rank
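The three variants differ mainly in how the log-determinant of the posterior precision is evaluated. Below is a minimal, self-contained JAX sketch contrasting the three computations; it is illustrative only and does not reproduce the laplax implementation (all sizes and values are made up):

```python
import jax.numpy as jnp

P, R = 5, 2     # parameter count and low-rank dimension (toy sizes)
tau2_inv = 2.0  # scalar prior precision tau^{-2} (assumed isotropic)

# Full: dense P x P posterior precision, O(P^3) log-determinant.
H = jnp.eye(P) + 0.1 * jnp.ones((P, P))  # stand-in curvature
_, logdet_full = jnp.linalg.slogdet(H + tau2_inv * jnp.eye(P))

# Diagonal: only the diagonal d_i is kept; the log-det is a sum of logs.
d = jnp.diag(H) + tau2_inv
logdet_diag = jnp.sum(jnp.log(d))

# Low-rank: U @ diag(lam) @ U.T + tau2_inv * I with orthonormal U; the
# matrix determinant lemma reduces the log-det to O(P + R) arithmetic.
lam = jnp.array([3.0, 1.5])  # top-R curvature eigenvalues
logdet_low_rank = P * jnp.log(tau2_inv) + jnp.sum(jnp.log1p(lam / tau2_inv))
```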
## joint_log_likelihood

`joint_log_likelihood(full_fn: Callable, prior_arguments: PriorArguments, params: Params, data: Data) -> Float`
Computes the joint log-likelihood for a model.
This function computes the joint log-likelihood for a model, which is given by:

\[
\log p(\mathcal{D}, \theta) = \log p(\mathcal{D} \mid \theta) + \log p(\theta),
\]

where the log-likelihood \(\log p(\mathcal{D} \mid \theta)\) is the negative loss. If we assume a Gaussian prior on the parameters with precision \(\tau^{-2}\), then the log-prior is given by:

\[
\log p(\theta) = -\frac{1}{2} \tau^{-2} \lVert \theta \rVert_2^2 + \frac{P}{2} \log\left(\tau^{-2}\right) - \frac{P}{2} \log(2\pi),
\]

where \(P\) is the number of parameters.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`full_fn` | `Callable` | Model loss function that takes the parameters and the data as input and returns the loss. | *required* |
`prior_arguments` | `PriorArguments` | Prior arguments. | *required* |
`params` | `Params` | Model parameters. | *required* |
`data` | `Data` | Training data. | *required* |
Returns:

Type | Description |
---|---|
`Float` | The joint log-likelihood. |
Source code in laplax/eval/likelihood.py
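A minimal sketch of the quantity this function computes, written directly from the two formulas above. It is not the laplax source; the scalar `prior_prec` stands in for the \(\tau^{-2}\) entry of `PriorArguments` (the key name there is not shown on this page):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def joint_log_likelihood_sketch(full_fn, prior_prec, params, data):
    theta, _ = ravel_pytree(params)          # flatten the parameter PyTree
    num_params = theta.shape[0]              # P in the formula above
    log_likelihood = -full_fn(params, data)  # likelihood = negative loss
    log_prior = (
        -0.5 * prior_prec * jnp.sum(theta**2)
        + 0.5 * num_params * jnp.log(prior_prec)
        - 0.5 * num_params * jnp.log(2.0 * jnp.pi)
    )
    return log_likelihood + log_prior
```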
## full_marginal_log_likelihood

`full_marginal_log_likelihood(posterior_precision: Num[Array, 'P P'], prior_arguments: PriorArguments, full_fn: Callable, params: Params, data: Data) -> Float`
Computes the marginal log-likelihood for a full (dense) posterior precision.
The marginal log-likelihood is given by:

\[
\log p(\mathcal{D}) \approx \log p(\mathcal{D}, \theta_*) + \frac{P}{2} \log(2\pi) - \frac{1}{2} \log \det \left( H_{\theta_*} + \tau^{-2} I \right),
\]

where \(\theta_*\) are the trained (MAP) parameters, \(\log p(\mathcal{D}, \theta_*)\) is the joint log-likelihood from above, and \(H_{\theta_*} + \tau^{-2} I\) is the posterior precision passed as `posterior_precision`.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`posterior_precision` | `Num[Array, 'P P']` | Posterior precision. | *required* |
`prior_arguments` | `PriorArguments` | Prior arguments. | *required* |
`full_fn` | `Callable` | Model loss function that takes the parameters and the data as input and returns the loss. | *required* |
`params` | `Params` | Model parameters. | *required* |
`data` | `Data` | Training data. | *required* |
Returns:

Type | Description |
---|---|
`Float` | The marginal log-likelihood estimate. |
Source code in laplax/eval/likelihood.py
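A hypothetical usage sketch with a toy quadratic loss, so that `jax.hessian` gives the exact curvature. The construction of `posterior_precision` as curvature plus isotropic prior precision, and the `"prior_prec"` key in the prior arguments, are assumptions for illustration:

```python
import jax
import jax.numpy as jnp

def full_fn(params, data):
    x, y = data
    return 0.5 * jnp.sum((x @ params - y) ** 2)  # toy quadratic loss

x = jax.random.normal(jax.random.PRNGKey(0), (10, 3))
params = jnp.zeros(3)
data = (x, x @ jnp.ones(3))

prior_prec = 1.0                                   # assumed tau^{-2} value
H = jax.hessian(full_fn)(params, data)             # exact 3 x 3 curvature
posterior_precision = H + prior_prec * jnp.eye(3)  # dense posterior precision

mll = full_marginal_log_likelihood(
    posterior_precision, {"prior_prec": prior_prec}, full_fn, params, data
)
```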
## diagonal_marginal_log_likelihood

`diagonal_marginal_log_likelihood(posterior_precision: FlatParams, prior_arguments: PriorArguments, full_fn: Callable, params: Params, data: Data) -> Float`
Computes the marginal log-likelihood for a diagonal approximation.
The marginal log-likelihood is given by:

\[
\log p(\mathcal{D}) \approx \log p(\mathcal{D}, \theta_*) + \frac{P}{2} \log(2\pi) - \frac{1}{2} \log \det \left( H_{\theta_*} + \tau^{-2} I \right).
\]

Here the log-determinant of the (diagonal) posterior precision simplifies to:

\[
\log \det \left( H_{\theta_*} + \tau^{-2} I \right) = \sum_{i=1}^{P} \log d_i,
\]

where \(d_i\) is the \(i\)-th diagonal element of the posterior precision.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`posterior_precision` | `FlatParams` | Posterior precision. | *required* |
`prior_arguments` | `PriorArguments` | Prior arguments. | *required* |
`full_fn` | `Callable` | Model loss function that takes the parameters and the data as input and returns the loss. | *required* |
`params` | `Params` | Model parameters. | *required* |
`data` | `Data` | Training data. | *required* |
Returns:

Type | Description |
---|---|
`Float` | The marginal log-likelihood estimate. |
Source code in laplax/eval/likelihood.py
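A short sketch of the diagonal simplification stated above; the dense `slogdet` line is only there to verify the identity on a toy example:

```python
import jax.numpy as jnp

d = jnp.array([4.0, 2.5, 9.0])  # diagonal posterior precision (the d_i)
log_det = jnp.sum(jnp.log(d))   # O(P) instead of an O(P^3) factorization

# Agrees with the dense computation on the corresponding diagonal matrix:
_, log_det_dense = jnp.linalg.slogdet(jnp.diag(d))
```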
## low_rank_marginal_log_likelihood

`low_rank_marginal_log_likelihood(posterior_precision: LowRankTerms, prior_arguments: PriorArguments, full_fn: Callable, params: Params, data: Data) -> Float`
Computes the marginal log-likelihood for a low-rank approximation.
The marginal log-likelihood is given by:

\[
\log p(\mathcal{D}) \approx \log p(\mathcal{D}, \theta_*) + \frac{P}{2} \log(2\pi) - \frac{1}{2} \log \det \left( U \Lambda U^\top + D \right).
\]

Here the log-determinant of the posterior precision (with structure \(U \Lambda U^\top + D\), where \(U \Lambda U^\top\) is the rank-\(R\) curvature term with orthonormal \(U\) and \(D\) is the diagonal prior precision) simplifies via the matrix determinant lemma to:

\[
\log \det \left( U \Lambda U^\top + D \right) = \sum_{i=1}^{P} \log d_i + \sum_{i=1}^{R} \log\left(1 + \frac{\lambda_i}{d_i}\right),
\]

where \(d_i\) is the \(i\)-th diagonal element of the prior precision (for an isotropic prior, \(d_i = \tau^{-2}\) for all \(i\)) and \(\lambda_i\) is the \(i\)-th eigenvalue of the low-rank approximation.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`posterior_precision` | `LowRankTerms` | Posterior precision. | *required* |
`prior_arguments` | `PriorArguments` | Prior arguments. | *required* |
`full_fn` | `Callable` | Model loss function that takes the parameters and the data as input and returns the loss. | *required* |
`params` | `Params` | Model parameters. | *required* |
`data` | `Data` | Training data. | *required* |
Returns:

Type | Description |
---|---|
`Float` | The marginal log-likelihood estimate. |
Source code in laplax/eval/likelihood.py
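A sketch of the log-determinant identity above for an isotropic prior \(d_i = \tau^{-2}\); the `LowRankTerms` container itself is not reproduced here, only the arithmetic:

```python
import jax.numpy as jnp

P = 100                             # number of parameters
tau2_inv = 2.0                      # isotropic prior precision, d_i = tau^{-2}
lam = jnp.array([50.0, 10.0, 3.0])  # eigenvalues lambda_i of the rank-R term

# sum_i log d_i + sum_i log(1 + lambda_i / d_i), costing O(P + R)
log_det = P * jnp.log(tau2_inv) + jnp.sum(jnp.log1p(lam / tau2_inv))
```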
## marginal_log_likelihood

`marginal_log_likelihood(curv_estimate: PyTree, prior_arguments: PriorArguments, data: Data, model_fn: ModelFn, params: Params, loss_fn: LossFn | str | Callable, curv_type: CurvatureKeyType, *, vmap_over_data: bool = False, loss_scaling_factor: Float = 1.0) -> Float`
Compute the marginal log-likelihood for a given curvature approximation.
The marginal log-likelihood is given by:

\[
\log p(\mathcal{D}) \approx \log p(\mathcal{D} \mid \theta_*) + \log p(\theta_*) + \frac{P}{2} \log(2\pi) - \frac{1}{2} \log \det \left( H_{\theta_*} + \tau^{-2} I \right).
\]

Here \(H_{\theta_*}\) is the Hessian/GGN of the loss function evaluated at the model parameters \(\theta_*\), so that \(H_{\theta_*} + \tau^{-2} I\) is the posterior precision. The likelihood function is given by the negative loss function.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`curv_estimate` | `PyTree` | Curvature estimate. | *required* |
`prior_arguments` | `PriorArguments` | Prior arguments. | *required* |
`data` | `Data` | Training data. | *required* |
`model_fn` | `ModelFn` | Model function. | *required* |
`params` | `Params` | Model parameters. | *required* |
`loss_fn` | `LossFn \| str \| Callable` | Loss function. | *required* |
`curv_type` | `CurvatureKeyType` | Curvature type. | *required* |
`vmap_over_data` | `bool` | Whether the model has a batch dimension. | `False` |
`loss_scaling_factor` | `Float` | Loss scaling factor. | `1.0` |
Returns:

Type | Description |
---|---|
`Float` | The marginal log-likelihood. |
Source code in laplax/eval/likelihood.py
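A hypothetical end-to-end call. The `"input"`/`"target"` data keys, the `"prior_prec"` key, the model-function argument order, and the `loss_fn`/`curv_type` string options are all assumptions for illustration and should be checked against the laplax source:

```python
import jax
import jax.numpy as jnp

def model_fn(input, params):  # argument order assumed
    return input @ params

params = jnp.ones(3)
data = {                      # key names assumed
    "input": jax.random.normal(jax.random.PRNGKey(0), (20, 3)),
    "target": jnp.zeros(20),
}

def mse(params, data):
    return jnp.sum((model_fn(data["input"], params) - data["target"]) ** 2)

curv_estimate = jax.hessian(mse)(params, data)  # stand-in full curvature

mll = marginal_log_likelihood(
    curv_estimate=curv_estimate,
    prior_arguments={"prior_prec": 1.0},  # assumed key name
    data=data,
    model_fn=model_fn,
    params=params,
    loss_fn="mse",                        # assumed string option
    curv_type="full",                     # assumed curvature key
    vmap_over_data=True,                  # data has a leading batch dimension
)
```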