@copybara-service

[JAX SC] required_buffer_size_per_sc is being undercounted when minibatching is enabled. The initial calculation does not account for the alignment padding added at the end of each minibatch.

This potentially leads to the following sequence of events:

  1. A smaller-than-needed buffer size is calculated and reported, and FDO uses this incorrect value to reconfigure the buffer for the next run.
  2. The buffer then inevitably overflows and drops IDs.
  3. This can cause a persistent FDO loop.

With this change, the required buffer size is recomputed after the minibatches have been merged and their final memory layout is known.
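
The undercount can be illustrated with a minimal sketch. It is not the library's implementation: the function names, the alignment value of 8, and the assumption that each minibatch is padded up to a multiple of a fixed alignment are all hypothetical, chosen only to show why a size computed before merging can be smaller than the size of the final layout.

```python
def round_up(n: int, alignment: int) -> int:
    """Round n up to the nearest multiple of alignment."""
    return -(-n // alignment) * alignment


def size_before_merge(ids_per_minibatch: list[int]) -> int:
    """Hypothetical initial estimate: counts IDs only, ignores padding."""
    return sum(ids_per_minibatch)


def size_after_merge(ids_per_minibatch: list[int], alignment: int) -> int:
    """Hypothetical recomputed size: each minibatch is padded to the
    alignment boundary (an assumed padding rule), then all are summed."""
    return sum(round_up(n, alignment) for n in ids_per_minibatch)


# Illustrative numbers only; the real alignment is not stated in this PR.
ALIGNMENT = 8
ids_per_minibatch = [5, 9, 3]

initial_estimate = size_before_merge(ids_per_minibatch)
recomputed = size_after_merge(ids_per_minibatch, ALIGNMENT)
undercount = recomputed - initial_estimate
```

Under these assumptions, a buffer sized from `initial_estimate` is too small by `undercount` slots, which is the gap that the recomputation after merging is meant to close.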

PiperOrigin-RevId: 824850862