[train] Support streaming mini-batch (non-blocking async training) #1607
base: main
Changes from all commits
7185ef1
4adaade
6d7f74c
dbb781d
bde201e
1e703db
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,125 +1,5 @@ | ||
| .venv/ | ||
| .venv311/ | ||
| __pycache__/ | ||
| /wandb/ | ||
| **/*.egg-info/ | ||
| # hydra logs | ||
| /outputs/ | ||
| /data/lcb | ||
|
|
||
| # MkDocs build output (generated during build) | ||
| docs/public/api-ref/ | ||
|
|
||
| # Documentation cache | ||
| .doctrees/ | ||
| .cache/ | ||
| .pytest_cache/ | ||
|
|
||
| # NOTE (sumanthrh): Don't add .env to gitignore. .env file when passed to uv is used to set env vars for each ray worker process. | ||
| # If it's in .gitignore then it won't be a part of the working directory shipped by uv and your env vars will not be set. | ||
| # This will just appear as a warning (silent failure) and you're gonna have a bad time. | ||
| # .env | ||
|
|
||
| # .env files inside directories can be ignored | ||
| /skyrl-gym/.env | ||
|
|
||
| /skyrl-gym/.venv | ||
|
|
||
| # build | ||
| /skyrl-gym/build | ||
| /skyrl-gym/dist | ||
|
|
||
| *.log | ||
| nohup.out | ||
| tensorboard_log/ | ||
|
|
||
| # SQLite database files | ||
| *.db | ||
|
|
||
| # Byte-compiled / optimized / DLL files | ||
| __pycache__/ | ||
| *.py[cod] | ||
| *$py.class | ||
|
|
||
| # C extensions | ||
| *.so | ||
|
|
||
| # Distribution / packaging | ||
| .Python | ||
| build/ | ||
| develop-eggs/ | ||
| dist/ | ||
| downloads/ | ||
| eggs/ | ||
| .eggs/ | ||
| lib/ | ||
| !docs/lib/ | ||
| lib64/ | ||
| parts/ | ||
| sdist/ | ||
| var/ | ||
| wheels/ | ||
| pip-wheel-metadata/ | ||
| share/python-wheels/ | ||
| *.egg-info/ | ||
| .installed.cfg | ||
| *.egg | ||
| MANIFEST | ||
|
|
||
| # PyInstaller | ||
| *.manifest | ||
| *.spec | ||
|
|
||
| # Installer logs | ||
| pip-log.txt | ||
| pip-delete-this-directory.txt | ||
|
|
||
| # Unit test / coverage reports | ||
| htmlcov/ | ||
| .tox/ | ||
| .nox/ | ||
| .coverage | ||
| .coverage.* | ||
| .cache | ||
| nosetests.xml | ||
| coverage.xml | ||
| *.cover | ||
| *.py,cover | ||
| .hypothesis/ | ||
| .pytest_cache/ | ||
|
|
||
| # Jupyter Notebook | ||
| .ipynb_checkpoints | ||
|
|
||
| # Environments | ||
| .env | ||
| .venv | ||
| env/ | ||
| venv/ | ||
| ENV/ | ||
| env.bak/ | ||
| venv.bak/ | ||
|
|
||
| # MkDocs build output | ||
| site/ | ||
|
|
||
| # IDEs and editors | ||
| .idea/ | ||
| .vscode/ | ||
|
|
||
| # OS generated files | ||
| .DS_Store | ||
| Thumbs.db | ||
|
|
||
| # Hydra outputs | ||
| outputs/ | ||
|
|
||
| # Local artifacts | ||
| tinker.db | ||
|
|
||
| # Alembic - don't track pycache | ||
| tx/tinker/alembic/__pycache__/ | ||
|
|
||
| # SQLite databases (tracked in git by default, but ignore if created locally) | ||
| *.db | ||
| *.db-journal | ||
| *.db-wal | ||
| *.db-shm | ||
| *.pyc | ||
| *.egg-info/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -384,26 +384,21 @@ async def train(self): | ||
| | ||
| for step_idx in range(self.global_step, (1 + epoch) * self.num_steps_per_epoch + 1): | ||
| with Timer("step", self.all_timings): | ||
| - # 1. Wait until we have enough groups buffered. | ||
| + # 1. Non-blocking streaming training: process mini-batch when buffer has enough data. | ||
| cur_generation_group_mini_batch: List[GeneratedOutputGroup] = [] | ||
| with Timer("wait_for_generation_buffer", self.all_timings): | ||
| - buffer_pbar = tqdm( | ||
| - total=self.mini_batch_size, | ||
| - initial=0, | ||
| - desc="Generation Buffer Progress", | ||
| - position=1, | ||
| - ) | ||
| - # NOTE(Charlie): we currently trim the train_dataloader to make it perfectly divisible by | ||
| - # self.mini_batch_size, and assume that all trajectories succeed (just like sync training), | ||
| - # so we always get a full mini-batch. Otherwise (e.g. want to drop stale trajectories), we | ||
| - # should handle the case where the dataloader is exhausted and the buffer is empty, or | ||
| - # else this loop will never exit. | ||
| - while len(cur_generation_group_mini_batch) < self.mini_batch_size: | ||
| + while generation_output_group_buffer.qsize() < self.mini_batch_size: | ||
| + # Sleep briefly to avoid busy waiting while generation workers keep running. | ||
| + await asyncio.sleep(0.01) | ||
| + logger.info(f"Buffer size: {generation_output_group_buffer.qsize()}") | ||
| + for _ in range(self.mini_batch_size): | ||
| # We do finish-time FIFO here (not schedule-time FIFO) | ||
| - cur_generation_group_mini_batch.append(await generation_output_group_buffer.get()) | ||
| - buffer_pbar.update(1) | ||
| - buffer_pbar.set_postfix({"buffer qsize": generation_output_group_buffer.qsize()}) | ||
| - buffer_pbar.close() | ||
| + try: | ||
| + cur_generation_group_mini_batch.append(generation_output_group_buffer.get_nowait()) | ||
| + except asyncio.QueueEmpty as e: | ||
| + raise AssertionError( | ||
| + "Generation buffer unexpectedly drained while collecting a mini-batch." | ||
| + ) from e | ||
Comment on lines +390 to +401
Contributor
Replacing
Suggested change
| | ||
| # 2. Post-process the generated groups, aggregating to a single GeneratorOutput, and convert to training format. | ||
| with Timer("convert_to_training_input", self.all_timings): | ||
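The buffer-draining step in the `train` hunk above can be sketched in isolation. This is a minimal illustration, not the PR's implementation: `collect_mini_batch`, `fake_generation_worker`, `poll_interval`, and the integer payloads are hypothetical stand-ins. It assumes a single consumer, so that no other coroutine can remove items between the `qsize()` check and the `get_nowait()` calls.

```python
import asyncio
from typing import List


async def collect_mini_batch(
    buffer: asyncio.Queue, mini_batch_size: int, poll_interval: float = 0.01
) -> List[int]:
    """Wait until the buffer holds a full mini-batch, then drain it without awaiting."""
    while buffer.qsize() < mini_batch_size:
        # Yield to the event loop so producer tasks can keep filling the buffer.
        await asyncio.sleep(poll_interval)
    batch: List[int] = []
    for _ in range(mini_batch_size):
        # There is no await between the size check and this call, so with a
        # single consumer get_nowait() is not expected to raise QueueEmpty.
        batch.append(buffer.get_nowait())
    return batch


async def fake_generation_worker(buffer: asyncio.Queue, item: int, delay: float) -> None:
    """Hypothetical producer: finishes after `delay` seconds and enqueues its result."""
    await asyncio.sleep(delay)
    await buffer.put(item)


async def demo() -> List[int]:
    buffer: asyncio.Queue = asyncio.Queue()
    # Workers scheduled earlier are given longer delays, so they finish later.
    workers = [
        asyncio.create_task(fake_generation_worker(buffer, i, 0.01 * (4 - i)))
        for i in range(4)
    ]
    batch = await collect_mini_batch(buffer, mini_batch_size=4)
    await asyncio.gather(*workers)
    return batch


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Because items are taken in the order they were enqueued, the batch follows completion order rather than scheduling order, which is the "finish-time FIFO" behavior the code comment refers to.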
| @@ -593,11 +588,9 @@ async def _run_generate_for_a_group_loop(self, generation_output_group_buffer: a | ||
| await self._staleness_manager.on_rollout_accepted() | ||
| except asyncio.CancelledError: | ||
| # If a slot was acquired but we exit early, release running count | ||
| - try: | ||
| - if "slot_acquired" in locals() and slot_acquired: | ||
| - raise RuntimeError("Generation workers should only be cancelled when they finish running.") | ||
| - finally: | ||
| - return | ||
| + if "slot_acquired" in locals() and slot_acquired: | ||
| + raise RuntimeError("Generation workers should only be cancelled when they finish running.") | ||
| + return | ||
Comment on lines +591 to +593
Contributor
The removal of the
Suggested change
| except Exception as e: | ||
| logger.error(f"Generator worker errored out with exception: {e}") | ||
| logger.error(f"Traceback: \n{traceback.format_exc()}") | ||
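One behavioral difference between the two versions of the `except asyncio.CancelledError` block can be reproduced in isolation: in Python, a `return` inside `finally` discards any exception raised in the matching `try`, so the old form could never surface its `RuntimeError`. The sketch below is a hedged illustration only. The worker functions, the `slot_acquired` parameter, and `cancel_and_collect` are hypothetical stand-ins, not code from the PR.

```python
import asyncio


async def worker_old_style(slot_acquired: bool) -> None:
    try:
        await asyncio.sleep(3600)  # stand-in for a long-running rollout
    except asyncio.CancelledError:
        try:
            if slot_acquired:
                raise RuntimeError("cancelled while holding a slot")
        finally:
            return  # a return in finally silently discards the RuntimeError


async def worker_new_style(slot_acquired: bool) -> None:
    try:
        await asyncio.sleep(3600)
    except asyncio.CancelledError:
        if slot_acquired:
            raise RuntimeError("cancelled while holding a slot")
        return


async def cancel_and_collect(worker) -> object:
    task = asyncio.create_task(worker(True))
    await asyncio.sleep(0)  # let the task start and block on its sleep
    task.cancel()
    # return_exceptions=True hands back the task's exception instead of raising it.
    results = await asyncio.gather(task, return_exceptions=True)
    return results[0]


if __name__ == "__main__":
    print(repr(asyncio.run(cancel_and_collect(worker_old_style))))
    print(repr(asyncio.run(cancel_and_collect(worker_new_style))))
```

With the old structure the cancelled task completes normally, while the new structure makes the task fail with the `RuntimeError`, so callers that await the worker need to be prepared for it.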
The `.gitignore` file has been significantly truncated, removing over 100 lines of rules. This deletes critical ignores for environment variables (`.env`), build artifacts (`outputs/`, `dist/`), IDE settings (`.vscode/`, `.idea/`), and various cache directories. This appears to be an accidental change that should be reverted to prevent committing sensitive information or large binary artifacts to the repository.
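A regression like this could be caught with a small check that compares a `.gitignore` against a list of rules the project considers mandatory. The following is a rough sketch under stated assumptions: `missing_ignore_rules`, the `truncated` sample text, and the `required` list are hypothetical, and the check only does exact line matching rather than evaluating gitignore pattern semantics.

```python
from typing import Iterable, List


def missing_ignore_rules(gitignore_text: str, required: Iterable[str]) -> List[str]:
    """Return the required rules that do not appear verbatim as a line in the file."""
    present = {
        line.strip()
        for line in gitignore_text.splitlines()
        if line.strip() and not line.lstrip().startswith("#")  # skip blanks and comments
    }
    return [rule for rule in required if rule not in present]


# Hypothetical truncated file contents, for illustration only.
truncated = "__pycache__/\n*.db\n*.pyc\n"
# Rules named in the review comment above.
required = [".env", "outputs/", "dist/", ".vscode/", ".idea/"]

if __name__ == "__main__":
    print(missing_ignore_rules(truncated, required))
```

Run in CI or as a pre-commit step, a non-empty result would flag the accidental deletion before it is merged.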