
Batch dependence of jax.numpy.linalg.solve #20047

@mfschubert

Contributor

Description

There seems to be some batch dependence in linalg.solve. Specifically, solving a single linear system and solving a batch of identical linear systems give different results.

The code below reproduces the problem, both on my Mac and in a Colab notebook.

import jax
import jax.numpy as jnp
import numpy as onp


a, b = jax.random.normal(jax.random.PRNGKey(0), (2, 100, 100))
assert a.shape == b.shape == (100, 100)

a_batch = jnp.stack([a, a], axis=0)
b_batch = jnp.stack([b, b], axis=0)


# Solve with jax, with and without batch dimensions.
sol_jax = jnp.linalg.solve(a, b)
sol_jax_with_batch = jnp.linalg.solve(a_batch, b_batch)

onp.testing.assert_array_equal(sol_jax_with_batch[0, ...], sol_jax_with_batch[1, ...])  # passes
onp.testing.assert_array_equal(sol_jax, sol_jax_with_batch[0, ...])  # FAILS
onp.testing.assert_array_equal(sol_jax, sol_jax_with_batch[1, ...])  # FAILS


# Solve with numpy, with and without batch dimensions.
sol_numpy = onp.linalg.solve(a, b)
sol_numpy_with_batch = onp.linalg.solve(a_batch, b_batch)

onp.testing.assert_array_equal(sol_numpy_with_batch[0, ...], sol_numpy_with_batch[1, ...])  # passes
onp.testing.assert_array_equal(sol_numpy, sol_numpy_with_batch[0, ...])  # passes
onp.testing.assert_array_equal(sol_numpy, sol_numpy_with_batch[1, ...])  # passes
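
As a complement to the bitwise checks above, the following is a minimal sketch of a tolerance-based comparison. It reuses the same inputs as the repro and reports a norm-wise relative difference instead of requiring exact equality. The name `rel_diff` is illustrative, and what counts as an acceptable value is an assumption that depends on the dtype and on how well conditioned the system is.

```python
import jax
import jax.numpy as jnp

# Same inputs as the repro above.
a, b = jax.random.normal(jax.random.PRNGKey(0), (2, 100, 100))

# Solve once without a batch dimension and once as a batch of two copies.
sol_single = jnp.linalg.solve(a, b)
sol_batched = jnp.linalg.solve(jnp.stack([a, a]), jnp.stack([b, b]))[0]

# Norm-wise relative difference between the two solutions (Frobenius norm).
rel_diff = float(
    jnp.linalg.norm(sol_single - sol_batched) / jnp.linalg.norm(sol_single)
)
print(f"relative difference: {rel_diff:.3e}")
```

A value of zero means the two code paths agree bitwise; a small nonzero value indicates rounding-level disagreement rather than a wrong answer.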

System info (python version, jaxlib version, accelerator, etc.)

Colab notebook, jaxlib 0.4.23

Activity

mfschubert (Contributor, Author) commented on Mar 2, 2024

Here are the visualized errors, A @ x - b

image

import matplotlib.pyplot as plt
plt.figure(figsize=(12, 4))
plt.subplot(131)
plt.imshow((a_batch @ sol_jax_with_batch - b_batch)[0, ...])
plt.colorbar()
plt.title("error\nbatch solve, batch_idx=0")
plt.subplot(132)
plt.imshow((a_batch @ sol_jax_with_batch - b_batch)[1, ...])
plt.colorbar()
plt.title("error\nbatch solve, batch_idx=1")
plt.subplot(133)
plt.imshow(a @ sol_jax - b)
plt.colorbar()
plt.title("error\nnon-batch solve")
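
The plotted residual `A @ x - b` can also be reduced to a single number. The sketch below computes a norm-wise relative residual for one linear system. The helper name `relative_residual` and the random test data are illustrative assumptions, not part of the original report. It uses NumPy, and JAX arrays can be passed in because they convert through `np.asarray`.

```python
import numpy as np


def relative_residual(a, x, b):
    """Return ||a @ x - b|| / (||a|| * ||x|| + ||b||) using Frobenius norms."""
    a, x, b = (np.asarray(m) for m in (a, x, b))
    residual = np.linalg.norm(a @ x - b)
    scale = np.linalg.norm(a) * np.linalg.norm(x) + np.linalg.norm(b)
    return residual / scale


# Illustrative data: a random, reasonably well-conditioned float64 system.
rng = np.random.default_rng(0)
a = rng.standard_normal((50, 50))
b = rng.standard_normal((50, 50))
x = np.linalg.solve(a, b)

print(relative_residual(a, x, b))
```

A residual near machine epsilon suggests the solver did its job for that input, even if two solutions of an ill-conditioned system differ noticeably from each other.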

self-assigned this on Mar 2, 2024

jakevdp (Collaborator) commented on Mar 2, 2024
Thanks for the report - if this is on a CPU backend, then the issue may be related to openxla/xla#3891

mfschubert (Contributor, Author) commented on Mar 2, 2024

> Thanks for the report - if this is on a CPU backend, then the issue may be related to openxla/xla#3891

Sure thing! I have tried this on CPU and GPU (T4) colab, and it also fails on my M2 macbook.

hawkinsp (Collaborator) commented on Mar 4, 2024

My gut feeling is that this is working as intended. Adding a batch dimension will allow and encourage XLA to change the order of operations, and floating point computations will not produce bitwise exact results if you change the order of operations. Indeed: that's sort of the point of vmap: we can and will compute things in a different and possibly more efficient order if there's a batch dimension.

(The most common place this surprises people is that adding a batch dimension to a matrix-vector multiplication makes it a matrix-matrix multiplication, which triggers different and sometimes lower precision matmul algorithms especially on GPUs and TPUs. However in this case I don't think that applies.)

The differences seem to be very small here: 1e-13 absolute difference in 64-bit mode. Is there a reason to think that's an unreasonable difference, allowing for different floating point orders of operations?

What do you think?
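
The point about order of operations can be illustrated without any linear algebra. The following minimal sketch, in plain Python with float64 arithmetic, sums the same three numbers in two different orders. The variable names are illustrative.

```python
import math

# The same three numbers, added in two different orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

# Bitwise equality fails, but the results agree to within rounding error.
bitwise_equal = left_to_right == right_to_left
close_enough = math.isclose(left_to_right, right_to_left, rel_tol=1e-12)

print(bitwise_equal, close_enough)
```

A solver that reorders its internal operations for a batched input is exposed to the same effect, at a scale that grows with the conditioning of the problem.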

mfschubert (Contributor, Author) commented on Mar 4, 2024

Thanks @hawkinsp for weighing in.

While the error of 1e-13 in the basic example I provided does not seem significant to me, I do have scenarios where the error is far more significant.

In the case where I encountered the issue, the error exceeds 100. Below is the output of a test comparing the result of jnp.linalg.solve with that of a custom solve that uses a Python for loop to handle the batch dimension.

AssertionError: 
Arrays are not equal

Mismatched elements: 964801 / 1929612 (50%)
Max absolute difference: 588.2706909
Max relative difference: 315.16777597
 x: array([[[ 3.206806e-04-1.520511e-02j,  1.619071e-03-1.980731e-03j,
         -2.531086e-03-1.379434e-03j, ...,  0.000000e+00+0.000000e+00j,
          0.000000e+00+0.000000e+00j,  0.000000e+00+0.000000e+00j],...
 y: array([[[ 3.206804e-04-1.520511e-02j,  1.619071e-03-1.980731e-03j,
         -2.531086e-03-1.379434e-03j, ...,  0.000000e+00+0.000000e+00j,
          0.000000e+00+0.000000e+00j,  0.000000e+00+0.000000e+00j],...

The custom solve (with inserted code to perform the comparison to directly using jnp.linalg.solve) is as follows:

def solve(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """A limited version of `linalg.solve` that has no batch dependency."""
    assert a.shape == b.shape
    m = a.shape[-1]
    a_flat = a.reshape((-1, m, m))
    b_flat = b.reshape((-1, m, m))
    results = [jnp.linalg.solve(af, bf) for af, bf in zip(a_flat, b_flat, strict=True)]
    result = jnp.asarray(results).reshape(a.shape)

    # Lines below compare result to `jnp.linalg.solve`
    result_linalg = jnp.linalg.solve(a, b)
    print(jnp.amax(jnp.abs(a)), jnp.amax(jnp.abs(b)))
    print(a.shape, b.shape)
    onp.testing.assert_array_equal(result_linalg, result)

    return result

The printed information about e.g. matrix size is as follows:

# 2654.1815475773055 39.304841915302305
# (3, 802, 802) (3, 802, 802)

For some background, I uncovered the issue when running batch simulations in fmmax: without a batch dimension, results were accurate, while running the same simulation twice (in a batch) yielded completely unphysical results. Using the custom solve above remedied the issue.

If it is of interest, I could probably write some code that generates such a problematic input to jnp.linalg.solve, but it may depend on fmmax.

mfschubert (Contributor, Author) commented on Mar 4, 2024

In case it's helpful, here are plots of the magnitude of elements in a and b in the example above.

image

hawkinsp (Collaborator) commented on Mar 4, 2024

A repro of the real problem would be helpful. You should be able to demonstrate the problem with a problematic pair of input matrices?

mfschubert (Contributor, Author) commented on Mar 5, 2024

Sure. The only trick is generating matrices known to be problematic, which (at the moment) I can do using fmmax. Let me know if this is workable or if we need an even more stripped-down repro. Here is a colab that generates problematic matrices and compares solve results.

https://colab.research.google.com/drive/12eB02zNt3pA2chhVCmIfv7hb0CvBFhtB

Using the generate_problematic_matrices function defined in that colab, we can run the following:

# Generate problematic matrices, with batch size 1 and batch size 2.

# problematic `a, b` with batch size 1
a1, b1 = generate_problematic_matrices(batch_size=1)

# problematic `a, b` with batch size 2
a2, b2 = generate_problematic_matrices(batch_size=2)


# Solve with batch size 1 and batch size 2
x1 = jnp.linalg.solve(a1, b1)
x2 = jnp.linalg.solve(a2, b2)


# Check that input matrices are identical
onp.testing.assert_array_equal(a2[jnp.newaxis, 0, ...], a1)  # passes
onp.testing.assert_array_equal(a2[jnp.newaxis, 1, ...], a1)  # passes
onp.testing.assert_array_equal(b2[jnp.newaxis, 0, ...], b1)  # passes
onp.testing.assert_array_equal(b2[jnp.newaxis, 1, ...], b1)  # passes

# Check that results within a batch are identical
onp.testing.assert_array_equal(x2[jnp.newaxis, 0, ...], x2[jnp.newaxis, 1, ...])  # passes

# Results with batch size 2 do not match those with batch size 1
onp.testing.assert_array_equal(x1, x2[jnp.newaxis, 0, ...])  # FAILS

The failing test has the following message:

AssertionError: 
Arrays are not equal

Mismatched elements: 80802 / 161604 (50%)
Max absolute difference: 61.684
Max relative difference: 240.05663
 x: array([[[-2.009601e-02-1.928231e-02j,  9.348132e-04-1.573229e-02j,
          9.222653e-05-1.584100e-04j, ...,  0.000000e+00+0.000000e+00j,
          0.000000e+00+0.000000e+00j,  0.000000e+00+0.000000e+00j],...
 y: array([[[-2.009643e-02-1.928224e-02j,  9.347852e-04-1.573259e-02j,
          9.184785e-05-1.591966e-04j, ...,  0.000000e+00+0.000000e+00j,
          0.000000e+00+0.000000e+00j,  0.000000e+00+0.000000e+00j],...

And for reference, the generate_problematic_matrices code is as follows:

!pip install fmmax

from typing import Tuple

import jax.numpy as jnp
import numpy as onp

from fmmax import basis, fmm, pml, scattering, utils

def generate_problematic_matrices(batch_size, approximate_num_terms=400):
    """Generates problematic matrices encountered during an fmmax simulation."""
    width = 20.0  # Simulation unit cell width
    width_pml = 2.0  # Width of perfectly matched layers on edges of unit cell.
    grid_spacing = 0.02

    wavelength = jnp.asarray([0.45] * batch_size)
    in_plane_wavevector = jnp.zeros((2,))
    primitive_lattice_vectors = basis.LatticeVectors(u=width * basis.X, v=basis.Y)
    grid_shape = (int(width / grid_spacing), 1)
    formulation = fmm.Formulation.FFT

    # Generate the expansion, the terms in the Fourier representation (1D here).
    nmax = approximate_num_terms // 2
    ix = onp.zeros((2 * nmax + 1,), dtype=int)
    ix[1::2] = -jnp.arange(1, nmax + 1, dtype=int)
    ix[2::2] = jnp.arange(1, nmax + 1, dtype=int)
    assert tuple(ix[:5].tolist()) == (0, -1, 1, -2, 2)
    expansion = basis.Expansion(
        basis_coefficients=onp.stack([ix, onp.zeros_like(ix)], axis=-1)
    )

    permittivities = [
        jnp.full(grid_shape, 1.0 + 0.0001j),
        jnp.full(grid_shape, 2.4 + 0.0001j),
    ]
    layer_thicknesses = [jnp.ones(()), jnp.ones(())]   

    def eigensolve_fn(permittivity: jnp.ndarray) -> fmm.LayerSolveResult:
        permittivities_pml, permeabilities_pml = pml.apply_uniaxial_pml(
            permittivity=permittivity,
            pml_params=pml.PMLParams(num_x=int(width_pml / grid_spacing), num_y=0),
        )
        return fmm.eigensolve_general_anisotropic_media(
            wavelength,
            in_plane_wavevector,
            primitive_lattice_vectors,
            *permittivities_pml,
            *permeabilities_pml,
            expansion=expansion,
            formulation=formulation,
            vector_field_source=jnp.mean(jnp.asarray(permittivities_pml), axis=0),
        )

    layer_solve_results = [eigensolve_fn(permittivity=p) for p in permittivities]

    # Problematic matrices are encountered in the generation of scattering
    # matrices. Code below generates these matrices directly.
    eye = utils.diag(jnp.ones_like(layer_solve_results[0].eigenvalues))
    s_matrix = scattering.ScatteringMatrix(
        s11=eye,
        s12=jnp.zeros_like(eye),
        s21=jnp.zeros_like(eye),
        s22=eye,
        start_layer_solve_result=layer_solve_results[0],
        start_layer_thickness=layer_thicknesses[0],
        end_layer_solve_result=layer_solve_results[0],
        end_layer_thickness=layer_thicknesses[0],
    )

    return problematic_matrices_from_extend_s_matrix(
        s_matrix=s_matrix,
        next_layer_solve_result=layer_solve_results[1],
        next_layer_thickness=layer_thicknesses[1],
    )


def problematic_matrices_from_extend_s_matrix(
    s_matrix: scattering.ScatteringMatrix,
    next_layer_solve_result: fmm.LayerSolveResult,
    next_layer_thickness: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return problematic matrices encountered in `scattering.extend_s_matrix`."""
    # https://github.com/facebookresearch/fmmax/blob/main/src/fmmax/scattering.py#L251
    s_matrix_blocks = (s_matrix.s11, s_matrix.s12, s_matrix.s21, s_matrix.s22)
    layer_solve_result = s_matrix.end_layer_solve_result
    layer_thickness = s_matrix.end_layer_thickness

    # Alias for brevity: eigenvalues, eigenvectors, and omega-k matrix.
    q = layer_solve_result.eigenvalues
    phi = layer_solve_result.eigenvectors
    omega_k = layer_solve_result.omega_script_k_matrix

    next_q = next_layer_solve_result.eigenvalues
    next_phi = next_layer_solve_result.eigenvectors
    next_omega_k = next_layer_solve_result.omega_script_k_matrix

    a = omega_k @ phi
    b = next_omega_k @ next_phi @ utils.diag(1 / next_q)
    return a, b

jakevdp (Collaborator) commented on Mar 5, 2024

It looks like the matrices you are using are mildly ill-conditioned for float32 operations:

>>> print(np.linalg.cond(a1[0]))
14766922.0

This condition number is larger than 2 ** 23 = 8388608, which is roughly the reciprocal of float32 machine epsilon, i.e. the limit of float32 relative precision. Given this, it's not entirely surprising that different routes to solving the same system of equations would lead to different accumulations of floating point error when operating in float32, which JAX does by default. If you need more precision, you might try enabling 64-bit computation (see JAX Sharp Bits: Double Precision).
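
As a rough illustration of this argument, the sketch below applies the common rule of thumb that the relative error of a linear solve is on the order of the condition number times machine epsilon. The helper name `relative_error_estimate` is hypothetical, and the estimate is a heuristic, not a guaranteed bound.

```python
import numpy as np


def relative_error_estimate(condition_number, dtype):
    """Heuristic estimate: condition number times machine epsilon of dtype."""
    return float(condition_number * np.finfo(dtype).eps)


# Condition number reported above for a1[0].
cond = 14766922.0

est32 = relative_error_estimate(cond, np.float32)
est64 = relative_error_estimate(cond, np.float64)

print(f"float32 estimate: {est32:.2e}")
print(f"float64 estimate: {est64:.2e}")
```

An estimate above 1 in float32 means no digits of the solution can be trusted in the worst case, while the float64 estimate leaves several digits of headroom. In JAX, 64-bit computation can be enabled at startup with `jax.config.update("jax_enable_x64", True)`.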

mfschubert (Contributor, Author) commented on Mar 5, 2024

I see, I was a bit surprised by the magnitude of difference between the two methods, even knowing that the details of a calculation may differ e.g. when jitting. Thanks for the tips and for looking into this (and for all the great work on jax).

Labels: bug

Participants: @hawkinsp, @jakevdp, @mfschubert

Batch dependence of `jax.numpy.linalg.solve` · Issue #20047 · jax-ml/jax