add straggler detection integration for the training #1179
Open
wanglei19991004 wants to merge 2 commits into flagos-ai:main from
PR Category
Train
PR Types
New Features
PR Description
Straggler Detection monitors the performance of each node and GPU during distributed training and detects "stragglers" (i.e., workers that are significantly slower than the others). A GPU or node that runs noticeably slower becomes a bottleneck and slows down the entire distributed training job.
Quick Start
Single-node example (using GPT-2)
When running run.py, enable the Straggler Detection feature by overriding system configuration parameters:
Multi-node example (2 nodes × 4 GPUs)
The multi-node setup is similar to the single-node case. You just need to additionally specify parameters such as hostfile and the master node address in the runner configuration:
Note
The program automatically identifies the physical machine where each rank is located (by retrieving the node name via os.environ.get('HOSTNAME') or socket.gethostname()).
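As a minimal sketch of the lookup described in this note, the snippet below reads HOSTNAME first and falls back to socket.gethostname(). Treating the two calls as a fallback chain is an assumption, and get_node_name and group_ranks_by_node are hypothetical helper names, not functions from this PR.

```python
import os
import socket
from typing import Dict, List


def get_node_name() -> str:
    """Return the node name of the current process.

    Assumption: HOSTNAME is preferred, and socket.gethostname() is used
    only when the variable is unset or empty.
    """
    return os.environ.get("HOSTNAME") or socket.gethostname()


def group_ranks_by_node(rank_to_node: Dict[int, str]) -> Dict[str, List[int]]:
    """Group rank ids by the node they run on (hypothetical helper)."""
    nodes: Dict[str, List[int]] = {}
    for rank in sorted(rank_to_node):
        nodes.setdefault(rank_to_node[rank], []).append(rank)
    return nodes
```

Grouping ranks by node name is what would let a report distinguish a slow GPU from a slow machine: if every rank on one node is flagged, the node itself is the likely culprit.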
Core Configuration Parameters
These parameters can be overridden via the command line (as shown above) or modified in the YAML configuration file:
enable_straggler_detection (bool): Whether to enable straggler detection (default: False).
straggler_profiling_interval (int): Profiling interval, i.e., the number of steps between recordings of runtime statistics (default: 10).
straggler_report_interval (int): Reporting interval, i.e., the number of steps between generating and saving a statistical analysis report (default: 100).
straggler_threshold (float): Relative latency threshold for identifying a straggler. For example, 1.5 means a node is considered a straggler if it is at least 50% slower than the fastest node (default: 1.5).
straggler_log_dir (str): Directory where straggler JSON report files are saved.
Interpreting Output Reports
When the configured straggler_report_interval is reached:
A text-based report is printed to the console on Rank 0.
A JSON report file is periodically generated in the specified straggler_log_dir.
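How the two intervals interact can be sketched as below. The field names and defaults come from the parameter list; the StragglerConfig class, the should_profile and should_report helpers, and the assumption that an interval is "reached" when the step number is a positive multiple of it are illustrative only and may differ from the PR's implementation.

```python
from dataclasses import dataclass


@dataclass
class StragglerConfig:
    # The description gives no default for the log directory, so it is required here.
    straggler_log_dir: str
    enable_straggler_detection: bool = False
    straggler_profiling_interval: int = 10
    straggler_report_interval: int = 100
    straggler_threshold: float = 1.5


def _interval_reached(step: int, interval: int) -> bool:
    """Assumed rule: an interval is reached on positive multiples of it."""
    return step > 0 and interval > 0 and step % interval == 0


def should_profile(cfg: StragglerConfig, step: int) -> bool:
    """True when runtime statistics would be recorded at this step."""
    return cfg.enable_straggler_detection and _interval_reached(
        step, cfg.straggler_profiling_interval
    )


def should_report(cfg: StragglerConfig, step: int) -> bool:
    """True when a report would be generated and saved at this step."""
    return cfg.enable_straggler_detection and _interval_reached(
        step, cfg.straggler_report_interval
    )
```

With the defaults, step 10 would record statistics without producing a report, while step 100 would do both.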
A file such as straggler_report_step_10.json contains detailed aggregated data:
JSON Field Descriptions
Higher values indicate better performance.
e.g., [2, 7] → these ranks are significantly slower
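The straggler_threshold rule from the configuration parameters can be illustrated with a short sketch. It assumes the comparison is made on per-rank elapsed time relative to the fastest rank; find_stragglers and its input format are hypothetical, and the PR may aggregate or compare differently.

```python
from typing import Dict, List


def find_stragglers(elapsed_by_rank: Dict[int, float], threshold: float = 1.5) -> List[int]:
    """Return ranks whose elapsed time is at least `threshold` times the fastest.

    elapsed_by_rank maps a rank id to its measured time for one section
    (any consistent unit). This is an illustrative rule, not the PR's code.
    """
    if not elapsed_by_rank:
        return []
    fastest = min(elapsed_by_rank.values())
    if fastest <= 0:
        # Relative slowness is undefined without a positive baseline.
        return []
    return sorted(
        rank for rank, elapsed in elapsed_by_rank.items()
        if elapsed / fastest >= threshold
    )


# Ranks 2 and 7 take at least 1.5x as long as the fastest rank.
times = {0: 1.0, 1: 1.1, 2: 1.6, 3: 1.0, 7: 2.0}
slow_ranks = find_stragglers(times, threshold=1.5)
```

Comparing against the fastest rank rather than the mean keeps the rule from being skewed when several ranks are slow at once.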
Custom Monitoring (for Advanced Use / Extension Development)
In addition to the default monitored sections (forward_backward and optimizer), users can instrument custom code regions for profiling.
Use case:
If you implement complex logic (e.g., custom loss functions or data preprocessing) and want to analyze its performance across GPUs or detect stragglers, you can manually instrument the code as follows:
Alternatively, you can use lower-level utilities such as SectionContext or decorators.
The recorded section data is collected and reported when the configured straggler_report_interval is reached.