Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ optm_weight_decay = 1e-10
optm_betas = [0.9, 0.95]
optm_warmup_steps = 1000
optm_grad_norm_clip = 1.0
optm_decay_type = "cosine"
optm_decay_ratio = 1.0
epsilon = 1e-8
compile = true
param_dtype = "bfloat16"
Expand Down Expand Up @@ -73,4 +75,7 @@ model_type = "PI05"
prompt_from_task = true
skip_norm_stats = false
norm_stats="/workspace/official_pi05_base_norm_stats.json"
episodes_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
episodes_index = [0, 200]
add_dataset_name = "delinqu/comet-1.5k"
add_dataset_local_dir = "/workspace/comet-1.5k"
add_data_multiplier = 2
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ optm_weight_decay = 1e-10
optm_betas = [0.9, 0.95]
optm_warmup_steps = 1000
optm_grad_norm_clip = 1.0
optm_decay_type = "cosine"
optm_decay_ratio = 1.0
epsilon = 1e-8
compile = true
param_dtype = "bfloat16"
Expand Down Expand Up @@ -73,4 +75,7 @@ model_type = "PI05"
prompt_from_task = true
skip_norm_stats = false
norm_stats="/workspace/comet_weights_pytorch_2/pi05-b1kpt50-cs32/assets/behavior-1k/2025-challenge-demos/norm_stats.json" # default use json from resumed weights
episodes_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
episodes_index = [0, 200]
add_dataset_name = "delinqu/comet-1.5k"
add_dataset_local_dir = "/workspace/comet-1.5k"
add_data_multiplier = 2
15 changes: 14 additions & 1 deletion cosmos_rl/tools/dataset/b1k_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,19 @@ def __call__(self, data: dict) -> dict:
return inputs


def parse_episodes_index(episodes: Any) -> list[int]:
if not isinstance(episodes, list) or not all(isinstance(i, int) for i in episodes):
raise ValueError("`episodes_index` must be a list of integers.")

if len(episodes) == 2:
start, end = episodes
if end < start:
raise ValueError(f"`episodes_index` has invalid interval: end({end}) < start({start}).")
return list(range(start, end))

return episodes


class BehaviorSFTDataset(Dataset):
"""
Thin wrapper around `BehaviorLeRobotDataset` used for SFT.
Expand All @@ -523,7 +536,7 @@ def __init__(self, config: CosmosConfig):
key: [t / 30.0 for t in range(config.custom["action_horizon"])]
for key in config.custom["action_sequence_keys"]
},
episodes=config.custom["episodes_index"],
episodes=parse_episodes_index(config.custom["episodes_index"]),
chunk_streaming_using_keyframe=True,
shuffle=True,
)
Expand Down
Loading