
Conversation

@hertschuh
Collaborator

There are two scenarios that currently cause the numbers to go down:

  1. The numbers went up for one stack but slightly down for another. Because we pass all the stacks to `update_preprocessing_parameters`, it updates the specs for all the stacks, which writes the lower values into the other stacks. We could filter out the stacks that went down, but that still doesn't cover the next case.

  2. For one stack, `max_ids_per_partition` went up, but `max_unique_ids_per_partition` went down for a batch for whatever reason. The update then writes the higher value for `max_ids_per_partition` but also the lower value for `max_unique_ids_per_partition`.

So instead of converging to their upper bounds, `max_ids_per_partition` and `max_unique_ids_per_partition` keep going up and down batch after batch.

The fix is to take the max between the values in the current `StackedTableSpec` and the incoming max values.
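
As a rough illustration of scenario 2 and of the fix (a minimal sketch with hypothetical names, not the layer's actual update path):

    from dataclasses import dataclass, replace


    @dataclass(frozen=True)
    class StackSpec:
        """Stand-in for a stacked table spec (hypothetical, for illustration)."""

        max_ids_per_partition: int
        max_unique_ids_per_partition: int


    def update_spec(spec, batch_max_ids, batch_max_unique_ids):
        """Monotonic update: keep the max of the stored spec and the batch stats."""
        return replace(
            spec,
            max_ids_per_partition=max(spec.max_ids_per_partition, batch_max_ids),
            max_unique_ids_per_partition=max(
                spec.max_unique_ids_per_partition, batch_max_unique_ids
            ),
        )


    # Scenario 2: one stat goes up while the other goes down in the same batch.
    spec = StackSpec(max_ids_per_partition=100, max_unique_ids_per_partition=80)
    spec = update_spec(spec, batch_max_ids=120, batch_max_unique_ids=60)
    # A plain overwrite would drop max_unique_ids_per_partition from 80 to 60;
    # with the max, both fields only ever move toward their upper bounds.
    assert spec == StackSpec(max_ids_per_partition=120, max_unique_ids_per_partition=80)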

@gemini-code-assist
Contributor

Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue in the DistributedEmbedding layer where key preprocessing parameters, specifically max_ids_per_partition and max_unique_ids_per_partition, were exhibiting undesirable fluctuations. Instead of monotonically increasing to reflect true maximums, these values could decrease due to interactions between different stacks or batch-specific anomalies. The implemented solution ensures that these parameters are always updated by taking the maximum of their current state and newly computed statistics, thereby guaranteeing a stable and accurate representation of the required resource allocation.

Highlights

  • Addressing fluctuating ID counts: This change resolves an issue where max_ids_per_partition and max_unique_ids_per_partition would fluctuate downwards instead of consistently increasing towards their true maximums.
  • Root cause of fluctuation: The fluctuation was caused by two scenarios: either one stack's higher values inadvertently lowering stats for other stacks during updates, or max_unique_ids_per_partition decreasing for a batch while max_ids_per_partition increased.
  • Ensuring monotonic increase: The fix ensures that max_ids_per_partition, max_unique_ids_per_partition, and required_buffer_size_per_sc always take the maximum value between the currently stored StackedTableSpec and the newly aggregated statistics, preventing any downward adjustments.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request correctly addresses the issue of fluctuating max_ids_per_partition and max_unique_ids_per_partition values by ensuring they are monotonically increasing. The fix is to take the maximum of the current batch's statistics and the existing stored statistics. The changes are logical and well-implemented. I've added one suggestion to improve code readability and use more idiomatic Python.

Comment on lines +649 to +689
                for stack_name, spec in stacked_table_specs.items():
                    all_stats.max_ids_per_partition[stack_name] = np.max(
                        [
                            all_stats.max_ids_per_partition[stack_name],
                            spec.max_ids_per_partition,
                        ]
                    )
                    all_stats.max_unique_ids_per_partition[stack_name] = np.max(
                        [
                            all_stats.max_unique_ids_per_partition[stack_name],
                            spec.max_unique_ids_per_partition,
                        ]
                    )
                    all_stats.required_buffer_size_per_sc[stack_name] = np.max(
                        [
                            all_stats.required_buffer_size_per_sc[stack_name],
                            (
                                (spec.suggested_coo_buffer_size_per_device or 0)
                                + (num_sc_per_device - 1)
                            )
                            // num_sc_per_device,
                        ]
                    )

Severity: medium

For improved readability and to use more idiomatic Python, you can refactor this loop. Using the built-in `max()` is more direct for comparing two scalar values than `np.max()` on a list of two elements. Additionally, extracting the calculation for the old `required_buffer_size_per_sc` into a separate variable makes the logic clearer.

                for stack_name, spec in stacked_table_specs.items():
                    all_stats.max_ids_per_partition[stack_name] = max(
                        all_stats.max_ids_per_partition[stack_name],
                        spec.max_ids_per_partition,
                    )
                    all_stats.max_unique_ids_per_partition[stack_name] = max(
                        all_stats.max_unique_ids_per_partition[stack_name],
                        spec.max_unique_ids_per_partition,
                    )
                    old_required_buffer_size_per_sc = (
                        (spec.suggested_coo_buffer_size_per_device or 0)
                        + (num_sc_per_device - 1)
                    ) // num_sc_per_device
                    all_stats.required_buffer_size_per_sc[stack_name] = max(
                        all_stats.required_buffer_size_per_sc[stack_name],
                        old_required_buffer_size_per_sc,
                    )
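
For what it's worth, the `(x + num_sc_per_device - 1) // num_sc_per_device` expression in both versions is the standard integer ceiling-division idiom: it splits the per-device buffer size across SparseCores and rounds up. A standalone check with made-up numbers (the helper name and values are illustrative, not taken from the PR):

    def ceil_div(x, n):
        """Integer ceiling division: ceil(x / n) using only integer arithmetic."""
        return (x + n - 1) // n


    # A per-device buffer of 10 entries split over 4 SparseCores needs 3 per core;
    # the `or 0` fallback in the original expression keeps the missing-value case at 0.
    assert ceil_div(10, 4) == 3
    assert ceil_div(8, 4) == 2
    assert ceil_div(0, 4) == 0

Using the built-in `max()` here also avoids building an intermediate list and running a NumPy reduction just to compare two scalars.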

