Prevent max ids per partition fluctuating by always taking the max. #168
base: main
Conversation
Code Review
This pull request correctly addresses the issue of fluctuating max_ids_per_partition and max_unique_ids_per_partition values by ensuring they are monotonically increasing. The fix is to take the maximum of the current batch's statistics and the existing stored statistics. The changes are logical and well-implemented. I've added one suggestion to improve code readability and use more idiomatic Python.
for stack_name, spec in stacked_table_specs.items():
    all_stats.max_ids_per_partition[stack_name] = np.max(
        [
            all_stats.max_ids_per_partition[stack_name],
            spec.max_ids_per_partition,
        ]
    )
    all_stats.max_unique_ids_per_partition[stack_name] = np.max(
        [
            all_stats.max_unique_ids_per_partition[stack_name],
            spec.max_unique_ids_per_partition,
        ]
    )
    all_stats.required_buffer_size_per_sc[stack_name] = np.max(
        [
            all_stats.required_buffer_size_per_sc[stack_name],
            (
                (spec.suggested_coo_buffer_size_per_device or 0)
                + (num_sc_per_device - 1)
            )
            // num_sc_per_device,
        ]
    )
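As an aside on the last update in that loop: `(x + (num_sc_per_device - 1)) // num_sc_per_device` is the standard integer ceiling-division idiom, so the value being compared is the suggested per-device buffer size split across the device's SparseCores and rounded up. A tiny worked example with made-up numbers (not values from the PR):

```python
# Made-up numbers for illustration only; not taken from the PR.
suggested_coo_buffer_size_per_device = 1001
num_sc_per_device = 4

# (x + n - 1) // n is integer ceiling division: ceil(x / n).
per_sc = (
    (suggested_coo_buffer_size_per_device or 0) + (num_sc_per_device - 1)
) // num_sc_per_device

print(per_sc)  # 251, i.e. ceil(1001 / 4); plain 1001 // 4 would give 250
```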
For improved readability and to use more idiomatic Python, you can refactor this loop. Using the built-in max() is more direct for comparing two scalar values than np.max() on a list of two elements. Additionally, extracting the calculation for the old required_buffer_size_per_sc into a separate variable makes the logic clearer.
for stack_name, spec in stacked_table_specs.items():
    all_stats.max_ids_per_partition[stack_name] = max(
        all_stats.max_ids_per_partition[stack_name],
        spec.max_ids_per_partition,
    )
    all_stats.max_unique_ids_per_partition[stack_name] = max(
        all_stats.max_unique_ids_per_partition[stack_name],
        spec.max_unique_ids_per_partition,
    )
    old_required_buffer_size_per_sc = (
        (spec.suggested_coo_buffer_size_per_device or 0)
        + (num_sc_per_device - 1)
    ) // num_sc_per_device
    all_stats.required_buffer_size_per_sc[stack_name] = max(
        all_stats.required_buffer_size_per_sc[stack_name],
        old_required_buffer_size_per_sc,
    )

There are two scenarios that currently cause the numbers to go down:

1. The numbers went up for one stack, but slightly down for another stack. Because we pass all the stacks to `update_preprocessing_parameters`, it updates the specs for all the stacks, which writes the lower values into the other stacks. We could filter out the stacks that went down, but that still doesn't cover the next case.

2. For one stack, `max_ids_per_partition` went up, but `max_unique_ids_per_partition` went down for a batch for whatever reason. The stats update then stores the higher value for `max_ids_per_partition` but a lower value for `max_unique_ids_per_partition`.

So instead of converging to their upper bounds, `max_ids_per_partition` and `max_unique_ids_per_partition` keep going up and down batch after batch. The fix is to take the maximum of the values currently in the `StackedTableSpec` and the incoming max values.
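To make the failure mode concrete, here is a minimal standalone sketch (plain dictionaries standing in for the real stats and `StackedTableSpec` objects, so the names and shapes are illustrative only) contrasting the old overwrite behavior with the running-max update:

```python
# Standalone illustration; the dicts are stand-ins, not the real spec/stats objects.
per_batch_stats = [
    {"max_ids_per_partition": 120, "max_unique_ids_per_partition": 80},
    {"max_ids_per_partition": 90, "max_unique_ids_per_partition": 95},  # scenario 2: one up, one down
    {"max_ids_per_partition": 130, "max_unique_ids_per_partition": 70},
]

# Old behavior: each batch overwrites the stored values, so they fluctuate.
overwrite_history = [dict(stats) for stats in per_batch_stats]

# New behavior: take the max of the stored and incoming values, so they only grow.
stored = {"max_ids_per_partition": 0, "max_unique_ids_per_partition": 0}
running_max_history = []
for stats in per_batch_stats:
    stored = {key: max(stored[key], stats[key]) for key in stored}
    running_max_history.append(dict(stored))

print(overwrite_history)    # values go up and down batch after batch
print(running_max_history)  # monotonically non-decreasing, converging to the upper bounds
```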