laplax.curv.lobpcg
Mixed-Precision Optional-Non-Jittable LOBPCG Wrapper for Sparse Linear Algebra.
This module provides an implementation of the Locally Optimal Block Preconditioned Conjugate Gradient (LOBPCG) method for finding eigenvalues and eigenvectors of large Hermitian matrices.
This is a Wrapper
The original source code can be found at: https://github.com/jax-ml/jax/blob/main/jax/experimental/sparse/linalg.py
What changes were made?
The implementation relies on the JAX experimental sparse linear algebra package but extends its functionality to support:
- Mixed Precision Arithmetic
  - Computations inside the algorithm (such as orthonormalization, matrix-vector products, and eigenvalue updates) can be performed using higher precision (e.g., `float64`) to maintain numerical stability in critical steps.
  - Matrix-vector products involving the operator `A` can be computed in lower precision (e.g., `float32`) to reduce memory usage and computation time.
- Non-Jittable Operator Support
  - The implementation supports `A` as a non-jittable callable, enabling the use of external libraries such as `scipy.sparse.linalg` for matrix-vector products. This is essential for cases where `A` cannot be expressed using JAX primitives (e.g., external libraries or precompiled solvers).
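The split between a low-precision operator and high-precision internals can be illustrated with a small NumPy sketch. This is not the module's code: `wrap_mixed_precision` is a hypothetical helper, and a plain Python callable stands in for a non-jittable `A`.

```python
import numpy as np

def wrap_mixed_precision(A, calc_dtype=np.float64, a_dtype=np.float32):
    """Return a callable that evaluates A in a_dtype and hands back calc_dtype.

    Hypothetical helper for illustration; not part of laplax.
    """
    def A_mixed(X):
        # Cast down for the (possibly expensive) operator call ...
        Y = A(np.asarray(X, dtype=a_dtype))
        # ... and cast back up for the numerically sensitive steps.
        return np.asarray(Y, dtype=calc_dtype)
    return A_mixed

# A plain Python callable standing in for an operator that cannot be jitted.
rng = np.random.default_rng(0)
M = rng.standard_normal((50, 50)).astype(np.float32)
M = (M + M.T) / 2  # symmetrize so the operator is Hermitian

A_mixed = wrap_mixed_precision(lambda X: M @ X)

X = rng.standard_normal((50, 4))  # float64 block of vectors
Y = A_mixed(X)                    # evaluated in float32, returned as float64
Q, _ = np.linalg.qr(Y)            # orthonormalization runs in float64
```

Only the operator call sees `float32` data; everything downstream of `A_mixed` works on `float64` arrays.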
Why this Wrapper?
The primary motivation for this implementation is to work around limitations in JAX's `lax.while_loop` and sparse linear algebra primitives, which require `A` to be jittable. By decoupling `A` from the main loop, we can support a broader range of operators while still leveraging the performance advantages of JAX-accelerated numerical routines where possible.
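To make the idea of decoupling the operator from the loop concrete, here is a minimal NumPy sketch. It uses plain subspace iteration with a Rayleigh-Ritz step, a much simpler method than LOBPCG, and the function name `subspace_iteration` is illustrative only. The point is that `A` is called from an ordinary Python loop, so it never needs to be traced or compiled.

```python
import numpy as np

def subspace_iteration(A, X, m=100, tol=1e-8):
    """Top-k eigenpairs of a symmetric operator via a plain Python loop.

    Simplified stand-in for LOBPCG: A is only ever called from Python,
    so it does not need to be traceable or jittable.
    """
    Q, _ = np.linalg.qr(X)
    for it in range(1, m + 1):
        AQ = A(Q)                            # opaque operator call
        theta, S = np.linalg.eigh(Q.T @ AQ)  # Rayleigh-Ritz on the subspace
        U = Q @ S                            # Ritz vectors
        R = AQ @ S - U * theta               # residuals A u - theta u
        if np.linalg.norm(R, axis=0).max() < tol:
            break
        Q, _ = np.linalg.qr(AQ)              # next subspace
    order = np.argsort(theta)[::-1]          # largest eigenvalues first
    return theta[order], U[:, order], it

# Diagonal test operator with known eigenvalues; any Python callable works.
d = np.array([10.0, 8.0, 6.0] + [1.0 / (i + 1) for i in range(17)])
A = lambda X: d[:, None] * X

rng = np.random.default_rng(0)
theta, U, iters = subspace_iteration(A, rng.standard_normal((20, 3)))
```

In this sketch the dense linear algebra inside the loop is the part that could be handed to accelerated routines, while the call to `A` stays outside them.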
lobpcg_standard
lobpcg_standard(A: Callable[[Array], Array], X: Array, m: int = 100, tol: Float | None = None, calc_dtype: DType = float64, a_dtype: DType = float32, *, A_jit: bool = True) -> tuple[Array, Array, int]
Compute top-k eigenvalues using LOBPCG with mixed precision.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`A` | `Callable[[Array], Array]` | callable representing the Hermitian matrix operation | required |
`X` | `Array` | initial guess \((P, R)\) array. | required |
`m` | `int` | max iterations | `100` |
`tol` | `Float \| None` | tolerance for convergence | `None` |
`calc_dtype` | `DType` | dtype for internal calculations (…) | `float64` |
`a_dtype` | `DType` | dtype for `A` calls (e.g., …) | `float32` |
`A_jit` | `bool` | If `True`, then pass the computation to … | `True` |
Returns:
Type | Description |
---|---|
`tuple[Array, Array, int]` | Tuple containing: … |
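Eigenpairs returned by any solver can be sanity-checked through their residuals. The sketch below is an illustration under stated assumptions: `numpy.linalg.eigh` stands in for the solver output, the names `H`, `theta`, `U`, and `residuals` are chosen here, and the shapes follow the \((P, R)\) convention of `X` above.

```python
import numpy as np

P, R = 30, 3                      # problem size and number of eigenpairs
rng = np.random.default_rng(1)
B = rng.standard_normal((P, P))
H = B @ B.T                       # symmetric positive semi-definite test matrix

def A(X):
    """Product of the Hermitian operator with a block of vectors."""
    return H @ X

# Stand-in for a solver result: the R largest eigenpairs from a dense solve.
w, V = np.linalg.eigh(H)          # eigenvalues in ascending order
theta, U = w[::-1][:R], V[:, ::-1][:, :R]

# Relative residual of each pair: ||A u_i - theta_i u_i|| / |theta_i|
residuals = np.linalg.norm(A(U) - U * theta, axis=0) / np.abs(theta)
```

Small residuals indicate that each column of `U` is close to an eigenvector for the matching entry of `theta`.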
Source code in laplax/curv/lobpcg.py
lobpcg_lowrank
lobpcg_lowrank(A: Callable[[Array], Array] | Array, *, key: KeyType | None = None, layout: Layout | None = None, rank: int = 20, tol: Float = 1e-06, mv_dtype: DType | None = None, calc_dtype: DType = float64, return_dtype: DType | None = None, mv_jit: bool = True, **kwargs: Kwargs) -> LowRankTerms
Compute a low-rank approximation using the LOBPCG algorithm.
This function computes the leading eigenvalues and eigenvectors of a matrix represented by a matrix-vector product function `mv` (the `A` argument), without explicitly forming the matrix. It uses the Locally Optimal Block Preconditioned Conjugate Gradient (LOBPCG) algorithm to achieve efficient low-rank approximation, with support for mixed-precision arithmetic and optional JIT compilation.
Mathematically, the low-rank approximation seeks to find the leading \(R\) eigenpairs \((\lambda_i, u_i)\) such that:
\(A u_i = \lambda_i u_i \quad \text{for } i = 1, \ldots, R\), where \(A\) is the matrix represented by the matrix-vector product `mv`.
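The link between leading eigenpairs and a low-rank approximation can be seen in a small NumPy sketch. It forms the matrix explicitly and uses a dense eigensolver purely for illustration, which is exactly what the matrix-free approach avoids; the variable names are chosen here. For a symmetric positive semi-definite matrix, the rank-\(R\) approximation \(U \Lambda U^\top\) built from the \(R\) leading eigenpairs has a spectral-norm error equal to the \((R+1)\)-th largest eigenvalue.

```python
import numpy as np

rng = np.random.default_rng(2)
B = rng.standard_normal((40, 40))
H = B @ B.T                               # symmetric PSD, so eigenvalues >= 0

R = 5
w, V = np.linalg.eigh(H)                  # eigenvalues in ascending order
lam, U = w[::-1][:R], V[:, ::-1][:, :R]   # R leading eigenpairs

H_R = (U * lam) @ U.T                     # rank-R approximation U diag(lam) U^T
err = np.linalg.norm(H - H_R, ord=2)      # spectral-norm error
next_eig = w[::-1][R]                     # (R+1)-th largest eigenvalue
```

Here `err` and `next_eig` agree up to rounding, so the quality of the approximation is governed by how fast the spectrum decays.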
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`A` | `Callable[[Array], Array] \| Array` | A callable that computes the matrix-vector product, representing the matrix | required |
`key` | `KeyType \| None` | PRNG key for random initialization of the search directions. | `None` |
`layout` | `Layout \| None` | Dimension of the input/output space of the matrix. | `None` |
`rank` | `int` | Number of leading eigenpairs to compute. Defaults to \(R=20\). | `20` |
`tol` | `Float` | Convergence tolerance for the algorithm. If … | `1e-06` |
`mv_dtype` | `DType \| None` | Data type for the matrix-vector product function. | `None` |
`calc_dtype` | `DType` | Data type for internal calculations during LOBPCG. | `float64` |
`return_dtype` | `DType \| None` | Data type for the final results. | `None` |
`mv_jit` | `bool` | If … | `True` |
`**kwargs` | `Kwargs` | Additional arguments (ignored). | `{}` |
Returns:
Type | Description |
---|---|
`LowRankTerms` | A dataclass containing: … |
Note
- If `mv_jit` is `True`, the function will be vectorized over the data.
- If the size of the matrix is small relative to `rank`, the number of iterations is reduced to avoid over-computation.
- Mixed precision can significantly reduce memory usage, especially for large matrices. If `mv_dtype` is `None`, the data type is automatically determined based on the `jax_enable_x64` configuration.
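The dtype fallback and the memory effect mentioned in the notes can be sketched as follows. This is an assumption-laden illustration, not the library's logic: `resolve_mv_dtype` is a hypothetical helper, the 64-bit flag is passed in as a plain boolean instead of being read from the JAX configuration, and the rule "`float64` when 64-bit mode is on, otherwise `float32`" is assumed here.

```python
import numpy as np

def resolve_mv_dtype(mv_dtype=None, x64_enabled=False):
    """Pick the dtype for matrix-vector products.

    Assumption for illustration: float64 when 64-bit mode is on, else float32.
    """
    if mv_dtype is not None:
        return np.dtype(mv_dtype)   # an explicit choice always wins
    return np.dtype(np.float64 if x64_enabled else np.float32)

# Memory for one block of `rank` vectors of length n, in each precision.
n, rank = 1_000_000, 20
bytes_f32 = n * rank * np.dtype(np.float32).itemsize
bytes_f64 = n * rank * np.dtype(np.float64).itemsize
```

Under these assumptions, a block of vectors stored in `float32` takes half the memory of the same block in `float64`.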
Example
Source code in laplax/curv/lobpcg.py