Description
There seems to be some batch-dependence in linalg.solve. Specifically, solving a linear system and solving a batch of identical copies of that system give different results.
The code below reproduces the problem, both on my mac and in a colab notebook.
import jax
import jax.numpy as jnp
import numpy as onp
a, b = jax.random.normal(jax.random.PRNGKey(0), (2, 100, 100))
assert a.shape == b.shape == (100, 100)
a_batch = jnp.stack([a, a], axis=0)
b_batch = jnp.stack([b, b], axis=0)
# Solve with jax, with and without batch dimensions.
sol_jax = jnp.linalg.solve(a, b)
sol_jax_with_batch = jnp.linalg.solve(a_batch, b_batch)
onp.testing.assert_array_equal(sol_jax_with_batch[0, ...], sol_jax_with_batch[1, ...]) # passes
onp.testing.assert_array_equal(sol_jax, sol_jax_with_batch[0, ...]) # FAILS
onp.testing.assert_array_equal(sol_jax, sol_jax_with_batch[1, ...]) # FAILS
# Solve with numpy, with and without batch dimensions.
sol_numpy = onp.linalg.solve(a, b)
sol_numpy_with_batch = onp.linalg.solve(a_batch, b_batch)
onp.testing.assert_array_equal(sol_numpy_with_batch[0, ...], sol_numpy_with_batch[1, ...]) # passes
onp.testing.assert_array_equal(sol_numpy, sol_numpy_with_batch[0, ...]) # passes
onp.testing.assert_array_equal(sol_numpy, sol_numpy_with_batch[1, ...])  # passes
System info (python version, jaxlib version, accelerator, etc.)
Colab notebook, jaxlib 0.4.23
mfschubert commented on Mar 2, 2024
Here are the visualized errors, A @ x - b (plots not captured).
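The plots themselves are not captured in this archive; as an editorial sketch (not part of the original comment), the residuals A @ x - b for the repro above can be computed like so:

```python
import jax
import jax.numpy as jnp

# Same setup as the repro in the issue description.
a, b = jax.random.normal(jax.random.PRNGKey(0), (2, 100, 100))
a_batch = jnp.stack([a, a], axis=0)
b_batch = jnp.stack([b, b], axis=0)

x = jnp.linalg.solve(a, b)
x_batch = jnp.linalg.solve(a_batch, b_batch)

# Residuals for the unbatched and batched solves; both should be small
# (floating-point roundoff scale), even though they differ bitwise.
res = a @ x - b
res_batch = a_batch @ x_batch - b_batch
print(float(jnp.max(jnp.abs(res))), float(jnp.max(jnp.abs(res_batch))))
```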
jakevdp commented on Mar 2, 2024
Thanks for the report - if this is on a CPU backend, then the issue may be related to openxla/xla#3891
mfschubert commented on Mar 2, 2024
Sure thing! I have tried this on CPU and GPU (T4) colab, and it also fails on my M2 macbook.
hawkinsp commented on Mar 4, 2024
My gut feeling is that this is working as intended. Adding a batch dimension will allow and encourage XLA to change the order of operations, and floating point computations will not produce bitwise exact results if you change the order of operations. Indeed, that's sort of the point of vmap: we can and will compute things in a different and possibly more efficient order if there's a batch dimension.
(The most common place this surprises people is that adding a batch dimension to a matrix-vector multiplication makes it a matrix-matrix multiplication, which triggers different and sometimes lower-precision matmul algorithms, especially on GPUs and TPUs. However, in this case I don't think that applies.)
The differences seem to be very small here: a 1e-13 absolute difference in 64-bit mode. Is there a reason to think that's an unreasonable difference, allowing for different floating point orders of operations? What do you think?
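For reference, a sketch (not from the original thread) showing that the batched and unbatched results agree once a floating-point tolerance is allowed, rather than demanding bitwise equality:

```python
import jax
import jax.numpy as jnp
import numpy as onp

# Same setup as the repro in the issue description.
a, b = jax.random.normal(jax.random.PRNGKey(0), (2, 100, 100))
a_batch = jnp.stack([a, a], axis=0)
b_batch = jnp.stack([b, b], axis=0)

sol = jnp.linalg.solve(a, b)
sol_batch = jnp.linalg.solve(a_batch, b_batch)

# Bitwise equality may fail, but a tolerance sized for reordered
# floating-point operations passes for this well-behaved system.
onp.testing.assert_allclose(sol, sol_batch[0], rtol=1e-2, atol=1e-2)
max_diff = float(jnp.max(jnp.abs(sol - sol_batch[0])))
```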
mfschubert commented on Mar 4, 2024
Thanks @hawkinsp for weighing in.
While the error of 1e-13 in the basic example I provided does not seem significant to me, I do have scenarios where the error is far more significant. In the case where I encountered the issue, this error is > 100. Below is the output of a test comparing the result of jnp.linalg.solve with a custom solve which uses a Python for loop to handle the batch dimension.
The custom solve (with inserted code to perform the comparison against directly using jnp.linalg.solve) is as follows: (code block not captured). The printed information about e.g. matrix size is as follows: (output not captured).
For some background, I uncovered the issue when running batch simulations in fmmax; without a batch dimension, results would be accurate, while running the same simulation twice (in a batch) yielded completely unphysical results. Using the custom solve above remedied the issue.
If it is of interest, I could probably write some code that generates such a problematic input to jnp.linalg.solve, but it may depend on fmmax.
mfschubert commented on Mar 4, 2024
In case it's helpful, here are plots of the magnitude of elements in a and b in the example above (plots not captured).
hawkinsp commented on Mar 4, 2024
A repro of the real problem would be helpful. You should be able to demonstrate the problem with a problematic pair of input matrices?
mfschubert commented on Mar 5, 2024
Sure. The only trick is generating matrices known to be problematic, which (at the moment) I can do using fmmax. Let me know if this is workable or if we need an even more stripped-down repro. Here is a colab that generates problematic matrices and compares solve results.
https://colab.research.google.com/drive/12eB02zNt3pA2chhVCmIfv7hb0CvBFhtB
Using the generate_problematic_matrices function defined in that colab, we can run the following: (code not captured). The failing test has the following message: (message not captured).
And for reference, the generate_problematic_matrices code is as follows: (code not captured).
jakevdp commented on Mar 5, 2024
It looks like the matrices you are using are mildly ill-conditioned for float32 operations: (output not captured).
This condition number is larger than 2 ** 23 = 8388608, which is the approximate dynamic range of float32 values. Given this, it's not entirely surprising that different routes to solving the same system of equations would lead to different accumulations of floating point error when operating in float32, which JAX does by default. If you need more precision, you might try enabling 64-bit computation (see JAX Sharp Bits: Double Precision).
mfschubert commented on Mar 5, 2024
I see. I was a bit surprised by the magnitude of the difference between the two methods, even knowing that the details of a calculation may differ, e.g. when jitting. Thanks for the tips and for looking into this (and for all the great work on JAX).
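An editorial sketch of the condition-number point above: for a synthetic matrix whose condition number exceeds 2 ** 23, a float32 solve loses substantial accuracy relative to float64. The actual fmmax matrices are not reproduced here; the construction below is an illustrative stand-in.

```python
import numpy as onp
import jax

# Enable 64-bit mode, per the suggestion in the thread
# (see JAX Sharp Bits: Double Precision).
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# Build a synthetic symmetric matrix with condition number ~1e8 (> 2 ** 23)
# from an orthogonal basis and a logarithmically spaced spectrum.
n = 50
rng = onp.random.default_rng(0)
q, _ = onp.linalg.qr(rng.normal(size=(n, n)))
a = q @ onp.diag(onp.logspace(0.0, 8.0, n)) @ q.T
b = rng.normal(size=(n,))

cond_a = onp.linalg.cond(a)

x32 = jnp.linalg.solve(jnp.asarray(a, jnp.float32), jnp.asarray(b, jnp.float32))
x64 = jnp.linalg.solve(jnp.asarray(a), jnp.asarray(b))

# cond(a) * float32 eps is O(1) here, so the float32 solve can lose
# essentially all of its significant digits relative to float64.
rel_err = float(
    onp.max(onp.abs(onp.asarray(x32, onp.float64) - onp.asarray(x64)))
    / onp.max(onp.abs(onp.asarray(x64)))
)
```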