
laplax.curv.lobpcg

Mixed-precision LOBPCG wrapper for sparse linear algebra, with optional support for non-jittable operators.

This module provides an implementation of the Locally Optimal Block Preconditioned Conjugate Gradient (LOBPCG) method for finding eigenvalues and eigenvectors of large Hermitian matrices.

This is a Wrapper

The original source code can be found at: https://github.com/jax-ml/jax/blob/main/jax/experimental/sparse/linalg.py

What changes were made?

The implementation relies on the JAX experimental sparse linear algebra package but extends its functionality to support:

  • Mixed Precision Arithmetic

    • Computations inside the algorithm (such as orthonormalization, the Rayleigh-Ritz projection, and eigenvalue updates) can be performed in higher precision (e.g., float64) to maintain numerical stability in critical steps.
    • Matrix-vector products involving the operator A can be computed in lower precision (e.g., float32) to reduce memory usage and computation time.
  • Non-Jittable Operator Support

    • The implementation supports A as a non-jittable callable, enabling the use of external libraries such as scipy.sparse.linalg for matrix-vector products. This is essential for cases where A cannot be expressed using JAX primitives (e.g., external libraries or precompiled solvers).

Why this Wrapper?

The primary motivation for this implementation is to work around limitations in the JAX lax.while_loop and sparse linear algebra primitives, which require A to be jittable. By decoupling A from the main loop, we can support a broader range of operators while still leveraging the performance advantages of JAX-accelerated numerical routines where possible.
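
For illustration, here is a minimal sketch of the non-jittable, mixed-precision path. It assumes a SciPy CSR matrix as the external operator; the test matrix, shapes, and names (A_sp, matvec, X0) are made up for the example, and calc_dtype/a_dtype realize the precision split described above.

import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp

from laplax.curv.lobpcg import lobpcg_standard

jax.config.update("jax_enable_x64", True)  # float64 calc_dtype requires x64

# Illustrative symmetric sparse operator built with SciPy; it cannot be jitted.
A_sp = sp.random(1000, 1000, density=1e-3, format="csr", random_state=0)
A_sp = A_sp + A_sp.T

def matvec(x):
    # x arrives as a (P, R) block cast to a_dtype; SciPy multiplies on the host.
    return jnp.asarray(A_sp @ np.asarray(x))

X0 = jnp.asarray(np.random.default_rng(0).normal(size=(1000, 5)))

# A_jit=False keeps the outer loop in Python, so the non-jittable callable is allowed.
eigvals, eigvecs, n_iter = lobpcg_standard(
    matvec, X0, m=100, calc_dtype=jnp.float64, a_dtype=jnp.float32, A_jit=False
)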

lobpcg_standard

lobpcg_standard(A: Callable[[Array], Array], X: Array, m: int = 100, tol: Float | None = None, calc_dtype: DType = float64, a_dtype: DType = float32, *, A_jit: bool = True) -> tuple[Array, Array, int]

Compute top-k eigenvalues using LOBPCG with mixed precision.

Parameters:

  • A (Callable[[Array], Array], required): callable representing the Hermitian matrix operation A @ x.
  • X (Array, required): initial guess \((P, R)\) array.
  • m (int, default 100): max iterations.
  • tol (Float | None, default None): tolerance for convergence.
  • calc_dtype (DType, default float64): dtype for internal calculations (float32 or float64).
  • a_dtype (DType, default float32): dtype for A calls (e.g., float64 for stable matrix-vector products).
  • A_jit (bool, default True): if True, pass the computation to jax.experimental.sparse.linalg.lobpcg_standard.

Returns:

  tuple[Array, Array, int]: Tuple containing:

  • Eigenvalues: Array of shape \((R,)\)
  • Eigenvectors: Array of shape \((P, R)\)
  • Iterations: Number of iterations performed
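
A minimal usage sketch of the default jittable path (the dense test matrix and names are illustrative):

import jax
import jax.numpy as jnp

from laplax.curv.lobpcg import lobpcg_standard

# Illustrative dense symmetric matrix; any Hermitian operator expressible in JAX works.
M = jax.random.normal(jax.random.key(0), (200, 200))
M = M @ M.T

X0 = jax.random.normal(jax.random.key(1), (200, 8))  # (P, R) = (200, 8) initial guess

# With the default A_jit=True, the call is forwarded to
# jax.experimental.sparse.linalg.lobpcg_standard.
eigvals, eigvecs, n_iter = lobpcg_standard(lambda x: M @ x, X0, m=100)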
Source code in laplax/curv/lobpcg.py
def lobpcg_standard(
    A: Callable[[Array], Array],
    X: Array,
    m: int = 100,
    tol: Float | None = None,
    calc_dtype: DType = jnp.float64,
    a_dtype: DType = jnp.float32,
    *,
    A_jit: bool = True,
) -> tuple[Array, Array, int]:
    """Compute top-k eigenvalues using LOBPCG with mixed precision.

    Args:
      A: callable representing the Hermitian matrix operation `A @ x`.
      X: initial guess $(P, R)$ array.
      m: max iterations
      tol: tolerance for convergence
      calc_dtype: dtype for internal calculations (`float32` or `float64`)
      a_dtype: dtype for A calls (e.g., `float64` for stable matrix-vector products)
      A_jit: If True, then pass the computation to
            `jax.experimental.sparse.linalg.lobpcg_standard`.

    Returns:
        Tuple containing:

            - Eigenvalues: Array of shape $(R,)$
            - Eigenvectors: Array of shape $(P, R)$
            - Iterations: Number of iterations performed
    """
    if A_jit:
        return linalg.lobpcg_standard(A, X, m, tol)

    n, k = X.shape
    _check_inputs(A, X)

    if tol is None:
        tol = jnp.finfo(calc_dtype).eps

    # Convert initial vectors to computation dtype
    X = X.astype(calc_dtype)

    X = _orthonormalize(X, calc_dtype=calc_dtype)
    P = _extend_basis(X, X.shape[1], calc_dtype=calc_dtype)

    # Precompute initial AX outside of jit
    # Cast to a_dtype before A and back to calc_dtype after
    AX = A(X.astype(a_dtype)).astype(calc_dtype)
    theta = jnp.sum(X * AX, axis=0, keepdims=True)
    R = AX - theta * X

    # JIT-ted iteration step that takes AX, AXPR, AS, etc. in calc_dtype
    @jax.jit
    def _iteration_first_step(X, P, R, AS):
        # Projected eigensolve
        XPR = jnp.concatenate((X, P, R), axis=1)
        theta, Q = _rayleigh_ritz_orth(AS, XPR)

        # Eigenvector X extraction
        B = Q[:, :k]
        normB = jnp.linalg.norm(B, ord=2, axis=0, keepdims=True)
        B /= normB
        X = _mm(XPR, B)
        normX = jnp.linalg.norm(X, ord=2, axis=0, keepdims=True)
        X /= normX

        # Difference terms P extraction
        q, _ = jnp.linalg.qr(Q[:k, k:].T)
        diff_rayleigh_ortho = _mm(Q[:, k:], q)
        P = _mm(XPR, diff_rayleigh_ortho)
        normP = jnp.linalg.norm(P, ord=2, axis=0, keepdims=True)
        P /= jnp.where(normP == 0, 1.0, normP)

        return X, P, R, theta

    @jax.jit
    def _iteration_second_step(X, R, theta, AX, n, tol):
        # Compute new residuals.
        # AX = A(X)
        R = AX - theta[jnp.newaxis, :k] * X
        resid_norms = jnp.linalg.norm(R, ord=2, axis=0)

        # Compute residual norms
        reltol = jnp.linalg.norm(AX, ord=2, axis=0) + theta[:k]
        reltol *= n
        # Allow some margin for a few element-wise operations.
        reltol *= 10
        res_converged = resid_norms < tol * reltol
        converged = jnp.sum(res_converged)

        return X, R, theta[jnp.newaxis, :k], converged

    @jax.jit
    def _projection_step(X, P, R):
        R = _project_out(jnp.concatenate((X, P), axis=1), R)
        return R, jnp.concatenate((X, P, R), axis=1)

    i = 0
    converged = 0
    while i < m and converged < k:
        # Residual basis selection
        R, XPR = _projection_step(X, P, R)

        # Compute AS = AXPR = A(XPR) outside JIT at a_dtype
        AS = A(XPR.astype(a_dtype)).astype(calc_dtype)

        # Call the first iteration step
        X, P, R, theta = _iteration_first_step(X, P, R, AS)

        # Calculate AX
        AX = A(X.astype(a_dtype)).astype(calc_dtype)

        # Call the second iteration step
        X, R, theta, converged = _iteration_second_step(X, R, theta, AX, n, tol)

        i += 1

    return theta[0, :], X, i

lobpcg_lowrank

lobpcg_lowrank(A: Callable[[Array], Array] | Array, *, key: KeyType | None = None, layout: Layout | None = None, rank: int = 20, tol: Float = 1e-06, mv_dtype: DType | None = None, calc_dtype: DType = float64, return_dtype: DType | None = None, mv_jit: bool = True, **kwargs: Kwargs) -> LowRankTerms

Compute a low-rank approximation using the LOBPCG algorithm.

This function computes the leading eigenvalues and eigenvectors of a matrix represented by a matrix-vector product function A, without explicitly forming the matrix. It uses the Locally Optimal Block Preconditioned Conjugate Gradient (LOBPCG) algorithm to achieve efficient low-rank approximation, with support for mixed-precision arithmetic and optional JIT compilation.

Mathematically, the low-rank approximation seeks to find the leading \(R\) eigenpairs \((\lambda_i, u_i)\) such that: \(A u_i = \lambda_i u_i \quad \text{for } i = 1, \ldots, R\), where \(A\) is the matrix represented by the matrix-vector product A.

Parameters:

  • A (Callable[[Array], Array] | Array, required): a callable that computes the matrix-vector product A @ x, or the matrix itself.
  • key (KeyType | None, default None): PRNG key for random initialization of the search directions.
  • layout (Layout | None, default None): dimension of the input/output space of the matrix.
  • rank (int, default 20): number of leading eigenpairs \(R\) to compute.
  • tol (Float, default 1e-06): convergence tolerance for the algorithm. If None, the machine epsilon for calc_dtype is used.
  • mv_dtype (DType | None, default None): data type for the matrix-vector product function.
  • calc_dtype (DType, default float64): data type for internal calculations during LOBPCG.
  • return_dtype (DType | None, default None): data type for the final results.
  • mv_jit (bool, default True): if True, enables JIT compilation for the matrix-vector product.
  • **kwargs (Kwargs): additional arguments (ignored).

Returns:

  LowRankTerms: A dataclass containing:

  • U: Eigenvectors as a matrix of shape \((P, R)\).
  • S: Eigenvalues as an array of length \((R,)\).
  • scalar: Scalar factor, initialized to 0.0.
Note
  • If mv_jit is True, the matrix-vector product is JIT-compiled and the computation is forwarded to jax.experimental.sparse.linalg.lobpcg_standard.
  • If the size of the matrix is small relative to rank, the rank is reduced (with a warning) to avoid over-computation.
  • Mixed precision can significantly reduce memory usage, especially for large matrices. If mv_dtype is None, the data type is automatically determined based on the jax_enable_x64 configuration.
Example
import jax
import jax.numpy as jnp

from laplax.curv.lobpcg import lobpcg_lowrank

low_rank_terms = lobpcg_lowrank(
    A=jnp.eye(1000),
    key=jax.random.key(42),
    rank=10,
    tol=1e-6,
)
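
Continuing this example, the fields of the returned LowRankTerms can be assembled into an explicit approximation. This hypothetical follow-up assumes the usual U @ diag(S) @ U.T + scalar * I reading of the dataclass:

U, S, scalar = low_rank_terms.U, low_rank_terms.S, low_rank_terms.scalar
A_approx = U @ jnp.diag(S) @ U.T + scalar * jnp.eye(U.shape[0])  # assumed interpretation
# For A = jnp.eye(1000), each entry of S should be close to 1.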
Source code in laplax/curv/lobpcg.py
def lobpcg_lowrank(
    A: Callable[[Array], Array] | Array,
    *,
    key: KeyType | None = None,
    layout: Layout | None = None,
    rank: int = 20,
    tol: Float = 1e-6,
    mv_dtype: DType | None = None,
    calc_dtype: DType = jnp.float64,
    return_dtype: DType | None = None,
    mv_jit: bool = True,
    **kwargs: Kwargs,
) -> LowRankTerms:
    r"""Compute a low-rank approximation using the LOBPCG algorithm.

    This function computes the leading eigenvalues and eigenvectors of a matrix
    represented by a matrix-vector product function `A`, without explicitly forming
    the matrix. It uses the Locally Optimal Block Preconditioned Conjugate Gradient
    (LOBPCG) algorithm to achieve efficient low-rank approximation, with support
    for mixed-precision arithmetic and optional JIT compilation.

    Mathematically, the low-rank approximation seeks to find the leading $R$
    eigenpairs $(\lambda_i, u_i)$ such that:
    $A u_i = \lambda_i u_i \quad \text{for } i = 1, \ldots, R$, where $A$ is the matrix
    represented by the matrix-vector product `A`.

    Args:
        A: A callable that computes the matrix-vector product, representing the matrix
            `A @ x`.
        key: PRNG key for random initialization of the search directions.
        layout: Dimension of the input/output space of the matrix.
        rank: Number of leading eigenpairs to compute. Defaults to $R=20$.
        tol: Convergence tolerance for the algorithm. If `None`, the machine epsilon
            for `calc_dtype` is used.
        mv_dtype: Data type for the matrix-vector product function.
        calc_dtype: Data type for internal calculations during LOBPCG.
        return_dtype: Data type for the final results.
        mv_jit: If `True`, enables JIT compilation for the matrix-vector product.
        **kwargs: Additional arguments (ignored).

    Returns:
        A dataclass containing:

            - `U`: Eigenvectors as a matrix of shape $(P, R)$.
            - `S`: Eigenvalues as an array of length $(R,)$.
            - `scalar`: Scalar factor, initialized to 0.0.

    Note:
        - If `mv_jit` is `True`, the matrix-vector product is JIT-compiled and the
            computation is forwarded to
            `jax.experimental.sparse.linalg.lobpcg_standard`.
        - If the size of the matrix is small relative to `rank`, the rank is
            reduced (with a warning) to avoid over-computation.
        - Mixed precision can significantly reduce memory usage, especially for large
            matrices. If `mv_dtype` is `None`, the data type is automatically determined
            based on the `jax_enable_x64` configuration.

    Example:
        ```python

        low_rank_terms = lobpcg_lowrank(
            A=jnp.eye(1000),
            key=jax.random.key(42),
            rank=10,
            tol=1e-6,
        )

        ```
    """
    del kwargs

    # Initialize handling mixed precision.
    original_float64_enabled = jax.config.read("jax_enable_x64")

    if mv_dtype is None:
        mv_dtype = jnp.float64 if original_float64_enabled else jnp.float32

    if return_dtype is None:
        return_dtype = jnp.float64 if original_float64_enabled else jnp.float32

    jax.config.update("jax_enable_x64", calc_dtype == jnp.float64)

    # Obtain a matrix-vector multiplication function.
    matvec, size = get_matvec(A, layout=layout, jit=mv_jit)

    # Obtain a matrix-matrix product function.
    matmat = jax.vmap(matvec, in_axes=-1, out_axes=-1)

    # Adjust rank if it's too large compared to problem size
    if size < rank * 5:
        rank = max(1, size // 5 - 1)
        msg = f"reduced rank to {rank} due to insufficient size"
        warnings.warn(msg, stacklevel=1)

    # Wrap to_dtype around mv if necessary.
    if mv_dtype != calc_dtype:
        matmat = wrap_function(
            matmat,
            input_fn=lambda x: jnp.asarray(x, dtype=mv_dtype),
            output_fn=lambda x: jnp.asarray(x, dtype=calc_dtype),
        )

    # Initialize random search directions
    if key is None:
        key = jax.random.key(0)
    X = jax.random.normal(key, (size, rank), dtype=calc_dtype)

    # Perform LOBPCG for eigenvalues and eigenvectors using the new wrapper
    eigenvals, eigenvecs, _ = lobpcg_standard(
        A=matmat,
        X=X,
        m=rank,
        tol=tol,
        calc_dtype=calc_dtype,
        a_dtype=mv_dtype,  # type: ignore  # noqa: PGH003
        A_jit=mv_jit,
    )

    # Prepare and convert the results
    low_rank_result = LowRankTerms(
        U=jnp.asarray(eigenvecs, dtype=return_dtype),
        S=jnp.asarray(eigenvals, dtype=return_dtype),
        scalar=jnp.asarray(0.0, dtype=return_dtype),
    )

    # Restore the original configuration dtype
    jax.config.update("jax_enable_x64", original_float64_enabled)

    return low_rank_result