Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 4 additions & 124 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,125 +1,5 @@
.venv/
.venv311/
__pycache__/
/wandb/
**/*.egg-info/
# hydra logs
/outputs/
/data/lcb

# MkDocs build output (generated during build)
docs/public/api-ref/

# Documentation cache
.doctrees/
.cache/
.pytest_cache/

# NOTE (sumanthrh): Don't add .env to gitignore. .env file when passed to uv is used to set env vars for each ray worker process.
# If it's in .gitignore then it won't be a part of the working directory shipped by uv and your env vars will not be set.
# This will just appear as a warning (silent failure) and you're gonna have a bad time.
# .env

# .env files inside directories can be ignored
/skyrl-gym/.env

/skyrl-gym/.venv

# build
/skyrl-gym/build
/skyrl-gym/dist

*.log
nohup.out
tensorboard_log/

# SQLite database files
*.db

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
!docs/lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Jupyter Notebook
.ipynb_checkpoints

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# MkDocs build output
site/

# IDEs and editors
.idea/
.vscode/

# OS generated files
.DS_Store
Thumbs.db

# Hydra outputs
outputs/

# Local artifacts
tinker.db

# Alembic - don't track pycache
tx/tinker/alembic/__pycache__/

# SQLite databases (tracked in git by default, but ignore if created locally)
*.db
*.db-journal
*.db-wal
*.db-shm
*.pyc
*.egg-info/
Comment on lines +1 to +5
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The .gitignore file has been significantly truncated, removing over 100 lines of rules. This deletes critical ignores for environment variables (.env), build artifacts (outputs/, dist/), IDE settings (.vscode/, .idea/), and various cache directories. This appears to be an accidental change that should be reverted to prevent committing sensitive information or large binary artifacts to the repository.

35 changes: 25 additions & 10 deletions skyrl/train/dataset/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def compute_prompt_mini_batch_boundaries(
train_batch_size: int,
is_stepwise: bool,
n_samples_per_prompt: int,
is_training: bool = True,
) -> List[Tuple[int, int]]:
"""Compute mini-batch ``(start, end)`` slices from a flat ``uids`` list.

Expand All @@ -206,10 +207,12 @@ def compute_prompt_mini_batch_boundaries(
train_batch_size: Number of prompts in a training batch. For sanity check.
is_stepwise: Whether the training is step-wise. For sanity check.
n_samples_per_prompt: how many samples per prompt. For sanity check.
is_training: Whether this is a training batch (strict validation) or eval batch (allows partial batches).
Defaults to True for backward compatibility.
Returns:
List of (start, end) indices of the mini-batches. The length of the list is the number of
mini-batches, guaranteed to be `train_batch_size // mini_batch_size` regardless of whether
the training is step-wise or not.
mini-batches, guaranteed to be `train_batch_size // mini_batch_size` during training, but may differ
during evaluation if the final batch is partial.

Consecutive equal entries in ``uids`` belong to the same prompt. Each mini batch spans exactly
``mini_batch_size`` prompts (the last may be smaller if the total prompt count is not divisible
Expand Down Expand Up @@ -244,23 +247,35 @@ def compute_prompt_mini_batch_boundaries(
prompt_end_indices.append(i)
prompt_end_indices.append(len(uids))

# seen_uids should equal to the number of prompts and equal to `train_batch_size`
# Check that num_prompts matches expected batch size
num_prompts = len(prompt_end_indices)
assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size
assert train_batch_size % mini_batch_size == 0

# Compute boundaries.
if is_training:
assert (
num_prompts == train_batch_size and len(seen_uids) == train_batch_size
), f"Expected {train_batch_size} prompts in training batch, got {num_prompts}."
assert train_batch_size % mini_batch_size == 0
else:
if num_prompts != train_batch_size:
logger.info(
f"Partial batch detected during eval: got {num_prompts} prompts but "
f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries."
)

# Compute boundaries. Handle partial batches during eval.
boundaries: List[Tuple[int, int]] = []
start_seq = 0

for i in range(0, num_prompts, mini_batch_size):
end_prompt_idx = i + mini_batch_size - 1 # i + mini_batch_size is next mini-batch's first prompt's end index
end_prompt_idx = min(i + mini_batch_size - 1, num_prompts - 1)
end_seq = prompt_end_indices[end_prompt_idx]
boundaries.append((start_seq, end_seq))
start_seq = end_seq
assert len(boundaries) == train_batch_size // mini_batch_size

if is_training:
assert len(boundaries) == train_batch_size // mini_batch_size

# Assert that the mini-batch boundaries are uniform for non-step-wise training.
if not is_stepwise:
if not is_stepwise and is_training:
expected_num_seq_in_mini_batch = n_samples_per_prompt * mini_batch_size
for i, (start, end) in enumerate(boundaries):
assert start == i * expected_num_seq_in_mini_batch
Expand Down
37 changes: 15 additions & 22 deletions skyrl/train/fully_async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,26 +384,21 @@ async def train(self):

for step_idx in range(self.global_step, (1 + epoch) * self.num_steps_per_epoch + 1):
with Timer("step", self.all_timings):
# 1. Wait until we have enough groups buffered.
# 1. Non-blocking streaming training: process mini-batch when buffer has enough data.
cur_generation_group_mini_batch: List[GeneratedOutputGroup] = []
with Timer("wait_for_generation_buffer", self.all_timings):
buffer_pbar = tqdm(
total=self.mini_batch_size,
initial=0,
desc="Generation Buffer Progress",
position=1,
)
# NOTE(Charlie): we currently trim the train_dataloader to make it perfectly divisible by
# self.mini_batch_size, and assume that all trajectories succeed (just like sync training),
# so we always get a full mini-batch. Otherwise (e.g. want to drop stale trajectories), we
# should handle the case where the dataloader is exhausted and the buffer is empty, or
# else this loop will never exit.
while len(cur_generation_group_mini_batch) < self.mini_batch_size:
while generation_output_group_buffer.qsize() < self.mini_batch_size:
# Sleep briefly to avoid busy waiting while generation workers keep running.
await asyncio.sleep(0.01)
logger.info(f"Buffer size: {generation_output_group_buffer.qsize()}")
for _ in range(self.mini_batch_size):
# We do finish-time FIFO here (not schedule-time FIFO)
cur_generation_group_mini_batch.append(await generation_output_group_buffer.get())
buffer_pbar.update(1)
buffer_pbar.set_postfix({"buffer qsize": generation_output_group_buffer.qsize()})
buffer_pbar.close()
try:
cur_generation_group_mini_batch.append(generation_output_group_buffer.get_nowait())
except asyncio.QueueEmpty as e:
raise AssertionError(
"Generation buffer unexpectedly drained while collecting a mini-batch."
) from e
Comment on lines +390 to +401
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Replacing await queue.get() with a polling loop using qsize() and asyncio.sleep(0.01) is an anti-pattern in asynchronous programming. It introduces unnecessary latency (up to 10ms per check) and CPU overhead compared to the built-in synchronization of asyncio.Queue. The original implementation using await buffer.get() in a loop was already non-blocking for the event loop and more efficient, as it leverages the queue's internal notification system to wake up the task exactly when data is available. The motivation of 'removing the blocking wait' seems to be a misunderstanding of how await works in this context.

Suggested change
while generation_output_group_buffer.qsize() < self.mini_batch_size:
# Sleep briefly to avoid busy waiting while generation workers keep running.
await asyncio.sleep(0.01)
logger.info(f"Buffer size: {generation_output_group_buffer.qsize()}")
for _ in range(self.mini_batch_size):
# We do finish-time FIFO here (not schedule-time FIFO)
cur_generation_group_mini_batch.append(await generation_output_group_buffer.get())
buffer_pbar.update(1)
buffer_pbar.set_postfix({"buffer qsize": generation_output_group_buffer.qsize()})
buffer_pbar.close()
try:
cur_generation_group_mini_batch.append(generation_output_group_buffer.get_nowait())
except asyncio.QueueEmpty as e:
raise AssertionError(
"Generation buffer unexpectedly drained while collecting a mini-batch."
) from e
while len(cur_generation_group_mini_batch) < self.mini_batch_size:
cur_generation_group_mini_batch.append(await generation_output_group_buffer.get())
logger.info(f"Buffer size: {generation_output_group_buffer.qsize()}")


# 2. Post-process the generated groups, aggregating to a single GeneratorOutput, and convert to training format.
with Timer("convert_to_training_input", self.all_timings):
Expand Down Expand Up @@ -593,11 +588,9 @@ async def _run_generate_for_a_group_loop(self, generation_output_group_buffer: a
await self._staleness_manager.on_rollout_accepted()
except asyncio.CancelledError:
# If a slot was acquired but we exit early, release running count
try:
if "slot_acquired" in locals() and slot_acquired:
raise RuntimeError("Generation workers should only be cancelled when they finish running.")
finally:
return
if "slot_acquired" in locals() and slot_acquired:
raise RuntimeError("Generation workers should only be cancelled when they finish running.")
return
Comment on lines +591 to +593
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the try...finally block around the RuntimeError means that this exception will now propagate and crash the trainer at the end of every epoch. When generator_tasks are cancelled (line 471), any worker currently generating will raise this RuntimeError, which is no longer suppressed. Furthermore, if a worker is cancelled after acquiring a slot but before finishing, the running count in the staleness_manager is leaked, which will cause the validate_state_at_epoch_end assertion to fail. The worker should instead release the slot gracefully upon cancellation.

Suggested change
if "slot_acquired" in locals() and slot_acquired:
raise RuntimeError("Generation workers should only be cancelled when they finish running.")
return
if "slot_acquired" in locals() and slot_acquired:
await self._staleness_manager.on_rollout_rejected()
return

except Exception as e:
logger.error(f"Generator worker errored out with exception: {e}")
logger.error(f"Traceback: \n{traceback.format_exc()}")
Expand Down
20 changes: 17 additions & 3 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,13 +592,17 @@ def init_weight_sync_state(self):
self.dispatch.init_weight_sync_state(self.inference_engine_client)
logger.info("Initialized weight sync state for policy model and inference engines.")

def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
def convert_to_training_input(
self, generator_output: GeneratorOutput, uids: List[str], is_training: bool = True
) -> TrainingInputBatch:
"""Converts lists to a padded batch of tensors for training

Args:
generator_output (GeneratorOutput): Generated rollouts and associated data.
uids (List[str]): List of prompt-unique identifiers for each generator output in the same
order as `generator_output`. Used to identify which prompt each generated rollout belongs to.
is_training (bool): Whether this batch is for training (strict batch size) or evaluation
(allows partial batches). Defaults to True for backward compatibility.
Returns:
training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the
order of `generator_output` and hence `uids`.
Expand Down Expand Up @@ -680,11 +684,21 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
is_stepwise = self.cfg.generator.step_wise_trajectories
training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
uids,
self.cfg.trainer.policy_mini_batch_size,
train_batch_size,
is_stepwise,
n_samples_per_prompt,
is_training=is_training,
)
if self.cfg.trainer.critic.model.path is not None:
training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
uids,
self.cfg.trainer.critic_mini_batch_size,
train_batch_size,
is_stepwise,
n_samples_per_prompt,
is_training=is_training,
)

# 5. Record metadata and metrics.
Expand Down
80 changes: 79 additions & 1 deletion tests/train/test_prompt_mini_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,88 @@ def test_same_step_count_as_non_stepwise(self):
)

assert len(stepwise_bounds) == len(non_stepwise_bounds) == 2

# Non-step-wise boundaries should be uniform
assert non_stepwise_bounds == [(0, 640), (640, 1280)]

def test_eval_partial_batch_nonstepwise(self):
    """Eval mode tolerates a partial batch when trajectories are not step-wise.

    Covers the case where evaluation crashes because the validation set size
    is not a multiple of ``train_batch_size``: with ``is_training=False`` a
    batch holding fewer prompts than ``train_batch_size`` must be accepted.
    """
    is_stepwise = False
    spp = 2
    mini_batch_size = 2
    train_batch_size = 4

    # Three prompts with two samples each -- one prompt short of a full batch.
    uids = [f"p{prompt_idx}" for prompt_idx in range(3) for _ in range(2)]

    # Three prompts at mini_batch_size=2 yield two mini-batches:
    #   prompts 0-1 -> sequences 0-4
    #   prompt 2    -> sequences 4-6 (the short, trailing mini-batch)
    expected = [(0, 4), (4, 6)]

    # Must not raise because is_training=False relaxes the batch-size check.
    actual = compute_prompt_mini_batch_boundaries(
        uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False
    )
    assert actual == expected

def test_eval_partial_batch_single_minibatch(self):
    """Eval mode: a partial batch smaller than one mini-batch gives one slice."""
    is_stepwise = False
    spp = 2
    mini_batch_size = 2
    train_batch_size = 4

    # A single prompt (two samples) where four prompts would be a full batch.
    uids = ["p0"] * 2

    # One prompt at mini_batch_size=2 collapses into exactly one mini-batch.
    expected = [(0, 2)]

    actual = compute_prompt_mini_batch_boundaries(
        uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False
    )
    assert actual == expected

def test_eval_rejects_noncontiguous_uids(self):
    """Eval mode keeps the requirement that equal uids are contiguous."""
    is_stepwise = False
    spp = 2
    mini_batch_size = 2
    train_batch_size = 4

    # "p0" occupies indices 0-1 and again 4-5, with "p1" wedged in between.
    uids = ["p0", "p0"] + ["p1", "p1"] + ["p0", "p0"]

    expected_message = "uid 'p0' appears in non-contiguous positions"
    with pytest.raises(AssertionError, match=expected_message):
        compute_prompt_mini_batch_boundaries(
            uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False
        )

def test_eval_stepwise_partial_batch(self):
    """Eval mode accepts a partial batch of step-wise trajectories."""
    is_stepwise = True
    spp = 2
    train_batch_size = 4
    mini_batch_size = 2

    # Three prompts where four would make a full batch.
    prompt_specs = [
        ("p0", 2, [3, 2]),  # 5 seqs
        ("p1", 2, [1, 4]),  # 5 seqs
        ("p2", 2, [2, 1]),  # 3 seqs
    ]
    uids = _make_uids_stepwise(prompt_specs)

    # Three prompts at mini_batch_size=2 yield two mini-batches:
    #   prompts 0-1 -> sequences 0-10
    #   prompt 2    -> sequences 10-13
    expected = [(0, 10), (10, 13)]

    # Must not raise because is_training=False relaxes the batch-size check.
    actual = compute_prompt_mini_batch_boundaries(
        uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False
    )
    assert actual == expected


# ---------------------------------------------------------------------------
# Tests for MeshDispatch.stage_chunks
Expand Down
Loading