Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion mason.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,12 @@ def make_internal_command(command: list[str], args: argparse.Namespace, whoami:
is_open_instruct_training = any(cmd in command for cmd in OPEN_INSTRUCT_COMMANDS)
if is_open_instruct_training:
from open_instruct.dataset_transformation import get_commit_hash
from open_instruct.utils import download_from_hf, gs_folder_exists, upload_to_gs_bucket
from open_instruct.utils import (
download_from_gs_bucket,
download_from_hf,
gs_folder_exists,
upload_to_gs_bucket,
)

# HACK: Cache dataset logic:
# Here we basically try to run the tokenization full_command locally before running it on beaker
Expand All @@ -467,6 +472,31 @@ def make_internal_command(command: list[str], args: argparse.Namespace, whoami:
continue

filtered_command = build_command_without_args(command[idx:], CACHE_EXCLUDED_ARGS)

# if model is only on gs, download tokenizer from gs for dataset preprocessing
try:
model_arg_idx = filtered_command.index("--model_name_or_path")
model_name_idx = model_arg_idx + 1
model_name_or_path = filtered_command[model_name_idx].rstrip("/")

if model_name_or_path.startswith("gs://"):
model_name_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
local_cache_folder = f"{args.auto_output_dir_path}/{whoami}/tokenizer_{model_name_hash}/"

if not os.path.exists(local_cache_folder):
download_from_gs_bucket(
[
f"{model_name_or_path}/tokenizer.json",
f"{model_name_or_path}/tokenizer_config.json",
f"{model_name_or_path}/config.json",
],
local_cache_folder,
)
Comment on lines +487 to +494
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The list of files to download is hardcoded inside the download_from_gs_bucket call. To improve readability and maintainability, consider extracting this list into a named variable before the call. This makes it clearer what files are being downloaded. For example:

tokenizer_files = [
    "tokenizer.json",
    "tokenizer_config.json",
    "config.json",
]
download_from_gs_bucket(
    [f"{model_name_or_path}/{f}" for f in tokenizer_files],
    local_cache_folder,
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ehh, I think mine is fine as is


filtered_command[model_name_idx] = local_cache_folder
except ValueError:
pass
Comment on lines +497 to +498
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The try...except block only catches ValueError, but an IndexError can occur on line 480 if --model_name_or_path is the last argument in filtered_command. This would cause the script to crash. It's safer to catch both exceptions to handle this edge case gracefully.

Suggested change
except ValueError:
pass
except (ValueError, IndexError):
pass

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if model_name_or_path is the last argument then we would want to crash as that's a weird input

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you rewrite this to make these changes:

  1. add a comment saying when we get a ValueError
  2. Can you change this to be a function, and return early when a condition is not true? then we can have less nesting, which will make it easier to follow.


caching_command = "python " + " ".join(filtered_command) + " --cache_dataset_only"
console.log("📦📦📦 Running the caching command with `--cache_dataset_only`")
import subprocess
Expand Down
9 changes: 6 additions & 3 deletions open_instruct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,8 @@ def download_from_hf(model_name_or_path: str, revision: str) -> None:
return output


def download_from_gs_bucket(src_path: str, dest_path: str) -> None:
def download_from_gs_bucket(src_paths: str | list[str], dest_path: str) -> None:
os.makedirs(dest_path, exist_ok=True)
cmd = [
"gsutil",
"-o",
Expand All @@ -1074,9 +1075,11 @@ def download_from_gs_bucket(src_path: str, dest_path: str) -> None:
"-m",
"cp",
"-r",
src_path,
dest_path,
]
if not isinstance(src_paths, list):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's change this to only take a list and force the callers to pass one in?

src_paths = [src_paths]
cmd.extend(src_paths)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some tests here? Let's mock live_subprocess_output so we capture the cmd it's called with and verify it against some known correct values. Should be a one prompt change with Codex

cmd.append(dest_path)
print(f"Downloading from GS bucket with command: {cmd}")
live_subprocess_output(cmd)

Expand Down