Determinism in Gemma model #82

@daralthus

Hello!
I am trying to understand whether there is a particular reason for this behaviour:
image

From the code it seems like the model doesn't take an rng.
image

The model here is gemma, loaded as in the tutorial.

danieldjohnson (Collaborator) commented on Aug 20, 2024

I believe this is a side effect of how JAX (and XLA) compute matrix-matrix products vs matrix-vector products. In short, batched matrix multiplication can cause XLA to use a different algorithm to compute the products, which can slightly change results due to the non-associativity of floating point operations.
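To illustrate the non-associativity point, here is a minimal sketch in plain Python, independent of JAX and XLA. It only shows that regrouping the same additions changes the rounded result; the specific numbers are chosen for illustration and have nothing to do with Gemma.

```python
# Floating-point addition is not associative: the same three numbers,
# grouped differently, are rounded at different points.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left == right)  # False
print(left, right)
```

A compiler that reorders a reduction is effectively choosing between groupings like these, so bitwise-identical results are not guaranteed.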

Here's a related comment that seems to describe the same issue: jax-ml/jax#20047 (comment)

Relevant bit:

My gut feeling is that this is working as intended. Adding a batch dimension will allow and encourage XLA to change the order of operations, and floating point computations will not produce bitwise exact results if you change the order of operations. Indeed: that's sort of the point of vmap: we can and will compute things in a different and possibly more efficient order if there's a batch dimension.

The most common place this surprises people is that adding a batch dimension to a matrix-vector multiplication makes it a matrix-matrix multiplication, which triggers different and sometimes lower precision matmul algorithms especially on GPUs and TPUs.
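As a rough sketch of the effect described in that quote, the snippet below computes the same matrix-vector products one input at a time and then with a batch dimension via `jax.vmap`. The shapes, seed, and variable names are arbitrary assumptions for illustration. Whether the two results differ, and by how much, depends on the backend and the matmul algorithm XLA picks, so the difference may well be exactly zero on some machines.

```python
import jax
import jax.numpy as jnp

# Arbitrary illustrative weights and inputs (not related to Gemma).
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (256, 256), dtype=jnp.float32)
xs = jax.random.normal(key_x, (8, 256), dtype=jnp.float32)

# One matrix-vector product per input, each computed separately.
one_at_a_time = jnp.stack([w @ x for x in xs])

# The same products with a batch dimension added by vmap, which XLA is
# free to lower to a single matrix-matrix product.
batched = jax.vmap(lambda x: w @ x)(xs)

max_abs_diff = float(jnp.max(jnp.abs(one_at_a_time - batched)))
print("max abs difference:", max_abs_diff)
```

Both results are valid approximations of the same mathematical quantity; they should agree to within floating-point tolerance, but not necessarily bit for bit.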

The size of the difference here is probably because you are using bfloat16, which is a very low precision format and can lead to cascading differences in model outputs. You could try initializing the model with upcast_activations_to_float32=True, which will use a higher precision to store the model activations and should lead to smaller discrepancies (at the cost of somewhat slower speed and higher memory usage).
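To give a sense of how coarse bfloat16 is compared to float32, here is a small sketch using `jnp.finfo`. The values 1.0 and 0.001 are arbitrary choices for illustration. It does not touch the model or the `upcast_activations_to_float32` option itself.

```python
import jax.numpy as jnp

# Machine epsilon: the gap between 1.0 and the next representable value.
bf16_eps = float(jnp.finfo(jnp.bfloat16).eps)  # 2**-7
f32_eps = float(jnp.finfo(jnp.float32).eps)    # 2**-23
print("bfloat16 eps:", bf16_eps)
print("float32 eps:", f32_eps)

# An update of 0.001 is smaller than half the bfloat16 spacing near 1.0,
# so it is rounded away entirely in bfloat16 but survives in float32.
one_bf16 = jnp.asarray(1.0, dtype=jnp.bfloat16)
bf16_sum = one_bf16 + jnp.asarray(0.001, dtype=jnp.bfloat16)
f32_sum = jnp.asarray(1.0, dtype=jnp.float32) + jnp.asarray(0.001, dtype=jnp.float32)

print(bool(bf16_sum == one_bf16))  # True
print(bool(f32_sum > 1.0))         # True
```

With rounding this coarse, a small reordering difference early in the network can be amplified by later layers, which is consistent with the larger discrepancies seen in bfloat16.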

Determinism in Gemma model · Issue #82 · google-deepmind/penzai