diff --git a/.github/workflows/launch_small_fast.yaml b/.github/workflows/launch_small_fast.yaml index 15f423674..584795201 100644 --- a/.github/workflows/launch_small_fast.yaml +++ b/.github/workflows/launch_small_fast.yaml @@ -41,7 +41,7 @@ jobs: - name: Install locally run: | python -m pip install --upgrade pip - pip install -e .[test] "jax[cpu]==0.4.30" + pip install -e .[test] "jax[cpu]==0.4.38" - name: Launch Small Fast TPU Train LM job run: | diff --git a/.github/workflows/run_entry_tests.yaml b/.github/workflows/run_entry_tests.yaml index ab08013ee..d9de2d815 100644 --- a/.github/workflows/run_entry_tests.yaml +++ b/.github/workflows/run_entry_tests.yaml @@ -9,7 +9,7 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["0.4.26"] + jax-version: ["0.4.38"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/run_pre_commit.yaml b/.github/workflows/run_pre_commit.yaml index ee3f0f587..842354ae0 100644 --- a/.github/workflows/run_pre_commit.yaml +++ b/.github/workflows/run_pre_commit.yaml @@ -10,7 +10,7 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["0.4.14"] + jax-version: ["0.4.38"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/run_ray_tests.yaml b/.github/workflows/run_ray_tests.yaml index 42139e576..a1788f777 100644 --- a/.github/workflows/run_ray_tests.yaml +++ b/.github/workflows/run_ray_tests.yaml @@ -9,7 +9,7 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["0.4.26"] + jax-version: ["0.4.38"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 6e9ed7024..ac01bcf5e 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -10,7 +10,7 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["0.4.26"] + jax-version: ["0.4.38"] steps: - uses: actions/checkout@v3 diff --git a/.gitignore b/.gitignore index 9615f94ab..018ad5497 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,8 @@ /scratch # Configuration for TPU launches/secrets -.config +.levanter.yaml # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/config/gpt2_small_fast_eval.yaml b/config/gpt2_small_fast_eval.yaml index 14638db1b..11245c7b3 100644 --- a/config/gpt2_small_fast_eval.yaml +++ b/config/gpt2_small_fast_eval.yaml @@ -24,6 +24,11 @@ supervised_data: cache_dir: "gs://levanter-data/tokenized-gpt2/arc_easy/" tags: [ "arc", "e"] +eval_harness: + task_spec: ["piqa", "hellaswag"] + max_examples: 2048 +eval_harness_steps: 1000 + model: type: gpt2 hidden_dim: 768 diff --git a/config/llama2_7b.yaml b/config/llama2_7b.yaml index b4ebe705f..11555f577 100644 --- a/config/llama2_7b.yaml +++ b/config/llama2_7b.yaml @@ -6,24 +6,108 @@ data: cache_dir: "gs://levanter-data/tokenized/openwebtext_llama/" tokenizer: "meta-llama/Llama-2-70b-hf" model: + activation_function: silu + attn_backend: null + cross_entropy_block_size: null + flash_attention_block_size: null + gradient_checkpointing: true + gradient_checkpointing_block_size: 5 + hidden_dim: 4096 + initializer_range: 0.02 + intermediate_dim: 14336 + layer_norm_epsilon: 1.0e-05 + num_heads: 32 + num_kv_heads: 8 + num_layers: 32 + reference_checkpoint: meta-llama/Llama-2-7b-hf + rope: + factor: 1.0 + theta: 10000 + type: default + scan_layers: true + seq_len: 4096 + tie_word_embeddings: false type: llama -# TODO: uncomment this once we resolve the resource exhaustion issue -# initialize_from_hf: "meta-llama/Llama-2-7b-hf" -# use_hf_model_config: true + 
upcast_attn: false + use_bias: false + use_flash_attention: true + use_layer_norm_weight: true +optimizer: + beta1: 0.9 + beta2: 0.95 + cooldown: null + cycle_length: 10000 + cycles: null + decay: 0.1 + default_weight_decay_mask: null + epsilon: 1.0e-08 + haps: null + learning_rate: 0.001 + lr_schedule: inv + max_grad_norm: 1.0 + min_lr_ratio: 0.1 + rewarmup: 0.0 + type: adam + warmup: 1000 + weight_decay: 0.05 + weight_decay_modules: null trainer: + axis_resources: {} + batch_axis: batch + checkpointer: + append_run_id_to_base_path: false + base_path: gs://levanter-checkpoints/checkpoints/llama-8b-tootsie-0.001-19ad63/checkpoints + keep: + - every: 20000 + save_interval: 10m + fp8: null + fsdp_axis: embed + id: llama-8b-tootsie-0.001-19ad63 + initialize_from: null + jax_config: + jax_softmax_custom_jvp: true + jax_threefry_partitionable: true + load_checkpoint: null + load_checkpoint_path: null + log_dir: logs + max_eval_batches: null + model_axis_size: 1 + mp: compute=bfloat16,params=float32,output=bfloat16 + num_train_steps: 10000 + parameter_axis_resources: {} + per_device_eval_parallelism: 2 + per_device_parallelism: 2 + profiler: false + profiler_num_steps: 100 + profiler_perfetto_link: false + profiler_start_step: 5 + ray: + address: null + auto_start_cluster: false + start_workers: false +# replica_dcn_axis_size: 2 +# replica_ici_axis_size: 1 + require_accelerator: true + seed: 0 + shutdown_at_exit: false + steps_per_eval: 10000 + tensor_parallel_axes: null tracker: + entity: null + group: null + id: null + mode: null + name: null + project: levanter + resume: allow + save_code: true + save_xla_dumps: false + tags: + - llama-8b-test + - llama + - 8b + - wsd-s type: wandb - project: "levanter" - tags: ["openwebtext", "llama"] - - mp: p=f32,c=bfloat16 - train_batch_size: 256 # set for v4-64 TPU - num_train_steps: 1000 - steps_per_eval: 50 - tensor_parallel_axes: ["mlp", "heads"] - fsdp_axis: "embed" - batch_axis: "batch" -optimizer: - learning_rate: 1.2E-5 # set low for fine-tuning - weight_decay: 0.1 - min_lr_ratio: 0.1 + train_batch_size: 1024 + wandb: null +use_hf_model_config: false diff --git a/docker/tpu/Dockerfile.base b/docker/tpu/Dockerfile.base index e2e032e82..09914eb79 100644 --- a/docker/tpu/Dockerfile.base +++ b/docker/tpu/Dockerfile.base @@ -5,8 +5,7 @@ RUN pip install virtualenv # venv binaries encode their directory, so we need to setup the venv in the final location RUN virtualenv -p python3.10 /opt/levanter/.venv ENV PATH /opt/levanter/.venv/bin:$PATH -#RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.34" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # Install package dependencies to make incremental builds faster. WORKDIR /tmp/ diff --git a/docker/tpu/Dockerfile.incremental b/docker/tpu/Dockerfile.incremental index 64c14b4c9..45a5b1aa6 100644 --- a/docker/tpu/Dockerfile.incremental +++ b/docker/tpu/Dockerfile.incremental @@ -18,6 +18,7 @@ WORKDIR /opt/levanter ADD pyproject.toml README.md /opt/levanter/ RUN mkdir -p /opt/levanter/src/levanter RUN pip install -e '.[test]' +RUN pip install "lm-eval@git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch" ADD . /opt/levanter # Add $EXTRA_CTX to the same location as in local machine. 
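Note: the `eval_harness` block added to `config/gpt2_small_fast_eval.yaml` earlier in this diff is deserialized into `LmEvalHarnessConfig` from `src/levanter/eval_harness.py` (modified further below), and `eval_harness_steps` controls how often the callback fires. A minimal sketch of the equivalent in-code construction, assuming the dataclass accepts these fields as keyword arguments:

```python
from levanter.eval_harness import LmEvalHarnessConfig

# Mirrors the YAML above: evaluate PIQA and HellaSwag on at most 2048 examples each.
# bootstrap_iters defaults to 0 in this diff, which skips stderr bootstrapping.
harness_config = LmEvalHarnessConfig(
    task_spec=["piqa", "hellaswag"],
    max_examples=2048,
)
```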
diff --git a/docs/Getting-Started-TPU-VM.md b/docs/Getting-Started-TPU-VM.md index 20fdaa765..53aaed218 100644 --- a/docs/Getting-Started-TPU-VM.md +++ b/docs/Getting-Started-TPU-VM.md @@ -83,7 +83,7 @@ on your development machine to build and run images on TPUs. First create a configuration file for future launches in your Levanter directory: ```bash -cat > .config <<EOF +cat > .levanter.yaml <<EOF +# Optional: specific environment variables for TPUs based on the TPU type +accel_env: + v6e: + # If you're lucky enough to have a v6e, you can set the following, which is pretty important for performance + LIBTPU_INIT_ARGS: "--xla_tpu_scoped_vmem_limit_kib=98304" + docker_repository: levanter # default zone: us-west4-a # if not set, will use your default zone tpu_name: test-spin-up-32 tpu_type: "v5litepod-16" -vm_image: "tpu-ubuntu2204-base" # default capacity_type: "preemptible" -autodelete: false subnetwork: "default" # default - EOF ``` @@ -155,6 +158,8 @@ a new file: If you're using `launch.py`, the config will be automatically uploaded as part of your Docker image, so you can just reference the local config path in your command line: +```bash +python infra/launch.py -- python src/levanter/main/train_lm.py --config_path config/my_config.yaml --trainer.checkpointer.base_path gs://' ``` Afterward, you can use the config directly from the TPU VM instance, e.g.: diff --git a/infra/helpers/setup-tpu-vm-tests.sh b/infra/helpers/setup-tpu-vm-tests.sh index 33c1c4add..a5cae20e1 100755 --- a/infra/helpers/setup-tpu-vm-tests.sh +++ b/infra/helpers/setup-tpu-vm-tests.sh @@ -105,7 +105,7 @@ pip install -U wheel # jax and jaxlib # libtpu sometimes has issues installing for clinical (probably firewall?) -retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +retry pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # clone levanter git clone $REPO levanter diff --git a/infra/helpers/setup-tpu-vm.sh b/infra/helpers/setup-tpu-vm.sh index 3ca81d76b..5bca127e9 100755 --- a/infra/helpers/setup-tpu-vm.sh +++ b/infra/helpers/setup-tpu-vm.sh @@ -105,7 +105,7 @@ pip install -U wheel # jax and jaxlib # libtpu sometimes has issues installing for clinical (probably firewall?)
-retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +retry pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # clone levanter git clone $REPO levanter diff --git a/infra/launch.py b/infra/launch.py index 05d4fffac..612b77c8a 100755 --- a/infra/launch.py +++ b/infra/launch.py @@ -1,8 +1,8 @@ #!/usr/bin/python - import argparse import getpass import subprocess +import sys import time from pathlib import Path @@ -12,6 +12,14 @@ from levanter.infra.tpus import launch_job +# default: tpu-ubuntu2204-base +TPU_TYPE_TO_VM_IMAGE = { + "v5litepod": "v2-alpha-tpuv5-lite", + "v5p": "v2-alpha-tpuv5", + "v6e": "v2-alpha-tpuv6e", +} + + def main(): parser = argparse.ArgumentParser() config = cli.load_config() @@ -28,7 +36,7 @@ def main(): cli.add_arg(parser, config, ["--tpu_name"], required=True) cli.add_arg(parser, config, ["--tpu_type"], required=True) cli.add_arg(parser, config, ["--node_count"], default=1, type=int) - cli.add_arg(parser, config, ["--version"], default="tpu-ubuntu2204-base") + cli.add_arg(parser, config, ["--version"], default=None) cli.add_arg(parser, config, ["--zone"], default=None, type=str, required=False) cli.add_arg(parser, config, ["--retries"], default=10, type=int) cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str) @@ -37,9 +45,7 @@ def main(): cli.add_arg(parser, config, ["--github_token"], type=str) cli.add_arg(parser, config, ["--extra_context"], type=Path, required=False, default=None) - parser.add_argument( - "-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"), default=list(config.get("env", {}).items()) - ) + parser.add_argument("-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE")) parser.add_argument("command", nargs=argparse.REMAINDER) args = parser.parse_args() @@ -57,8 +63,14 @@ def main(): retries = args.retries tpu_name = args.tpu_name tpu_type = args.tpu_type + + tpu_gen = tpu_type.split("-")[0] + version = args.version or TPU_TYPE_TO_VM_IMAGE.get(tpu_gen, "tpu-ubuntu2204-base") + + if not args.version: + print(f"Using default version: {version}", file=sys.stderr) + node_count = args.node_count - version = args.version zone = args.zone run_id = args.run_id registry = args.docker_registry @@ -73,7 +85,10 @@ def main(): raise ValueError("Zone must be specified or set in gcloud config.") region = "-".join(zone.split("-")[:-1]) + + env = config.env_for_accel(tpu_type) + for key, value in args.env or []: + env[key] = value if "WANDB_PROJECT" not in env: env["WANDB_PROJECT"] = "levanter" diff --git a/infra/launch_on_ray.py b/infra/launch_on_ray.py index 90f2c586a..2e7551f8b 100755 --- a/infra/launch_on_ray.py +++ b/infra/launch_on_ray.py @@ -37,9 +37,8 @@ def main(): cli.add_arg(parser, config, ["--extra_context"], type=Path, required=False, default=None) cli.add_arg(parser, config, ["--zone"], default=None, type=str, required=False) - parser.add_argument( - "-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"), default=list(config.get("env", {}).items()) - ) + parser.add_argument("-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE")) + parser.add_argument("command", nargs=argparse.REMAINDER) args = parser.parse_args() @@ -62,6 +61,10 @@ def main(): github_token = args.github_token extra_context = args.extra_context + env = config.env_for_accel(tpu_type) + for key, value in args.env or []: + env[key] = value
+ if zone is None: zone = cli.gcloud_config()["zone"] diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index cb15125e0..4821f4d82 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -220,6 +220,8 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: # skip padding result = result[:initial_length] + logger.info(f"Finished running {len(requests)} loglikelihoods.") + return result def _pad_dataset_to_batch_size(self, requests): @@ -295,6 +297,7 @@ class LmEvalHarnessConfig: max_examples: int | None = None max_eval_length: int | None = None log_samples: bool = False + bootstrap_iters: int = 0 # set to 0 see if this makes it not hang randomly def to_task_spec(self) -> list[str | dict]: return [task.to_dict() if isinstance(task, TaskConfig) else task for task in self.task_spec] @@ -307,12 +310,13 @@ def to_task_dict(self) -> dict: run, and LM Eval Harness doesn't seem to want to do that by default. So we need to do some hacky stuff to make it work. """ + logger.info("Loading tasks...") import lm_eval.tasks as tasks manager = tasks.TaskManager() # we need to do it this way b/c i can't figure out how to run e.g. hellaswag 0 shot and 10 shot in a single run this_tasks = {} - for task in self.to_task_spec(): + for task in tqdm(self.to_task_spec()): try: if isinstance(task, str): this_tasks.update(tasks.get_task_dict(task, manager)) @@ -324,6 +328,8 @@ def to_task_dict(self) -> dict: except Exception: logger.exception(f"Failed to load task {task}") raise ValueError(f"Failed to load task {task}") + + logger.info(f"Loaded {len(this_tasks)} tasks") return this_tasks def _get_task_and_rename(self, manager, our_name, task: dict | str): @@ -397,16 +403,25 @@ def _actually_run_eval_harness( max_eval_length = config.max_eval_length EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length) + num_parameters = levanter.utils.jax_utils.parameter_count(model) + logger.info( + f"Evaluating with max eval length {EvalPos.size} and batch size {EvalBatch.size}. There are" + f" {num_parameters} parameters in the model." + ) harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp) - # we always set log_samples here and filter out the samples later if we don't want them - outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=True) + logger.info("Running eval harness...") + outputs = evaluator.evaluate( + harness, + tasks_to_run, + limit=max_examples, + log_samples=config.log_samples, + bootstrap_iters=config.bootstrap_iters, + ) + logger.info("Finished running eval harness.") averages = _compute_averages(outputs) outputs["averages"] = averages - if not config.log_samples: - del outputs["samples"] - return outputs @@ -417,7 +432,9 @@ def _compute_averages(outputs): Args: outputs: Dictionary with results and samples: - "results": Dictionary of task-level results. - - "samples": Dictionary of task-level sample counts. + - "n-samples" : Dictionary of task-level sample counts. + + Returns: Averages dictionary with macro and micro averages for all metrics. 
@@ -429,7 +446,7 @@ def _compute_averages(outputs): for task_results in outputs["results"].values(): metric_keys.update(k for k in task_results.keys() if "stderr" not in k and k != "alias") - examples_per_task = [len(task_samples) for task_samples in outputs["samples"].values()] + examples_per_task = [task_samples["effective"] for task_samples in outputs["n-samples"].values()] # Compute macro and micro averages for metric in metric_keys: @@ -448,7 +465,8 @@ def _compute_averages(outputs): # Compute macro and micro averages averages["macro_avg_" + metric] = np.mean(metric_values) - averages["micro_avg_" + metric] = np.average(metric_values, weights=this_examples_per_task) + if sum(this_examples_per_task) > 0: + averages["micro_avg_" + metric] = np.average(metric_values, weights=this_examples_per_task) return averages @@ -591,18 +609,24 @@ def run_eval_harness_main(config: EvalHarnessMainConfig): ) logger.info("Finished running LM eval harness") + + # log the results + logger.info("Logging results to tracker") + log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) + logger.info("Finished logging results to tracker") + # log the results as json + logger.info("uploading artifacts...") with open("lm_eval_harness_results.json", "w") as f: json.dump(outputs, f, indent=2) + f.flush() + f_path = f.name + levanter.tracker.current_tracker().log_artifact(f_path, name="lm_eval_harness_results") # also write to stdout if jax.process_index() == 0: print(json.dumps(outputs, indent=2), flush=True) - # also log the results - levanter.tracker.current_tracker().log_artifact("lm_eval_harness_results.json", name="lm_eval_harness_results") - log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) - return outputs @@ -639,6 +663,7 @@ def lm_eval_harness(step: StepInfo, force=False): return # don't run eval on the first step model = inference_mode(step.model, True) + logger.info("Running eval harness...") outputs = _actually_run_eval_harness( config, model, @@ -648,18 +673,22 @@ def lm_eval_harness(step: StepInfo, force=False): axis_resources, mp, ) + logger.info("Finished running eval harness.") - if jax.process_index() == 0: - log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) + log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) + logger.info("Logged report to tracker") + if jax.process_index() == 0: # don't delete b/c wandb will sometimes defer upload with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f: import json json.dump(outputs, f) + f.flush() levanter.tracker.current_tracker().log_artifact( f.name, name=f"lm_eval_harness_results.{step.step}.json", type="lm_eval_output" ) + logger.info("Uploaded results to tracker") return lm_eval_harness diff --git a/src/levanter/infra/cli_helpers.py b/src/levanter/infra/cli_helpers.py index 58413ef2b..6c4229224 100644 --- a/src/levanter/infra/cli_helpers.py +++ b/src/levanter/infra/cli_helpers.py @@ -1,14 +1,62 @@ import argparse import base64 +import dataclasses import os -import shlex import subprocess +import warnings +from dataclasses import dataclass +from functools import cached_property from typing import Optional +import draccus import yaml from google.cloud import storage +@dataclass(frozen=True) +class CliConfig: + project: str | None = None + zone: str | None = None + tpu: str | None = None + repository: str | None = None + image: str | None = None + tag: str | None = None + github_user: str | None = None + github_token: str | None = None + 
docker_file: str | None = None + extra_context: str | None = None + docker_target: str | None = None + docker_repository: str | None = None + subnetwork: str | None = None + + env: dict[str, str] = dataclasses.field(default_factory=dict) + + accel_env: dict[str, dict[str, str]] = dataclasses.field(default_factory=dict) + """ + Environment variables specific to a type of accelerator. The keys are the accelerator type (e.g. v5litepod-256) or + generation (e.g. v5litepod), with priority given to the more specific key. The values are dictionaries of environment + variables to set. These take priority over the general `env` field. + """ + + def env_for_accel(self, accel_type: str) -> dict[str, str]: + + base_env = self.env.copy() + + if "-" in accel_type: + base_env.update(self.accel_env.get(accel_type.split("-")[0], {})) + + if accel_type in self.accel_env: + base_env.update(self.accel_env[accel_type]) + + return base_env + + @cached_property + def as_dict(self): + dict = dataclasses.asdict(self) + # remove Nones + return {k: v for k, v in dict.items() if v is not None} + + # Oddly enough, there's no API to simply fetch the current gcloud configuration... def gcloud_config(): client = storage.Client() @@ -31,11 +79,11 @@ def get_default_zone() -> Optional[str]: return None -def add_arg(parser: argparse.ArgumentParser, config: dict, flags: list[str], required=False, default=None, **kw): +def add_arg(parser: argparse.ArgumentParser, config: CliConfig, flags: list[str], required=False, default=None, **kw): """Add an argument to the parser, using `config` or the environment to resolve default values.""" key = flags[0].lstrip("-").replace("-", "_") - if key in config: - default = config[key] + if key in config.as_dict: + default = config.as_dict[key] if key.upper() in os.environ: default = os.environ[key.upper()] @@ -48,11 +96,16 @@ def add_arg(parser: argparse.ArgumentParser, config: dict, flags: list[str], req parser.add_argument(*flags, **kw) -def load_config(): - if os.path.exists(".config"): - return yaml.load(open(".config", "r"), Loader=yaml.SafeLoader) +def load_config() -> CliConfig: + if os.path.exists(".levanter.yaml"): + d = yaml.load(open(".levanter.yaml", "r"), Loader=yaml.SafeLoader) + elif os.path.exists(".config"): + warnings.warn("Using deprecated .config file. Please rename to .levanter.yaml") + d = yaml.load(open(".config", "r"), Loader=yaml.SafeLoader) else: - return {} + d = {} + + return draccus.decode(CliConfig, d) def get_git_commit(): @@ -60,36 +113,6 @@ def get_git_commit(): return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip() -def make_docker_run_command(image_id, command, *, foreground, env, name="levanter"): - docker_command = [ - "docker", - "run", - "-t" if foreground else "-d", - f"--name={shlex.quote(name)}", - "--privileged", - "--shm-size=32gb", - "--net=host", - "--init", - "--mount", - "type=volume,source=levanter,target=/home/levanter", - "-v", - "/tmp:/tmp", - ] - - # optionally add multislice env vars (if set by ray runtime env vars) - for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]: - v = shlex.quote(str(v)) - docker_command.extend(["-e", v]) - - for k, v in env.items(): - v = shlex.quote(str(v)) - k = shlex.quote(str(k)) - docker_command.extend(["-e", f"{k}={v}"]) - - docker_command.extend([image_id, *command]) - return docker_command - - def default_run_id(): """Generate a run ID for wandb and continuation. 
diff --git a/src/levanter/infra/docker.py b/src/levanter/infra/docker.py index d48b558a5..aabce6b1a 100644 --- a/src/levanter/infra/docker.py +++ b/src/levanter/infra/docker.py @@ -1,6 +1,7 @@ import json import os import pty +import shlex import shutil import subprocess import sys @@ -236,3 +237,33 @@ def split_image_and_tag(docker_base_image): base_image = docker_base_image base_tag = "latest" return base_image, base_tag + + +def make_docker_run_command(image_id, command, *, foreground, env, name="levanter"): + docker_command = [ + "docker", + "run", + "-t" if foreground else "-d", + f"--name={shlex.quote(name)}", + "--privileged", + "--shm-size=32gb", + "--net=host", + "--init", + "--mount", + "type=volume,source=levanter,target=/home/levanter", + "-v", + "/tmp:/tmp", + ] + + # optionally add multislice env vars (if set by ray runtime env vars) + for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]: + v = shlex.quote(str(v)) + docker_command.extend(["-e", v]) + + for k, v in env.items(): + v = shlex.quote(str(v)) + k = shlex.quote(str(k)) + docker_command.extend(["-e", f"{k}={v}"]) + + docker_command.extend([image_id, *command]) + return docker_command diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index 1a9342c54..86ce4223a 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -18,7 +18,7 @@ from ray.exceptions import NodeDiedError, RayError, RaySystemError, RayTaskError, WorkerCrashedError from ray.remote_function import RemoteFunction -from levanter.infra.cli_helpers import make_docker_run_command +from levanter.infra.docker import make_docker_run_command from levanter.utils.ray_utils import ser_exc_info diff --git a/src/levanter/infra/tpus.py b/src/levanter/infra/tpus.py index bbb1cc5f5..fa0f1a23c 100644 --- a/src/levanter/infra/tpus.py +++ b/src/levanter/infra/tpus.py @@ -10,7 +10,7 @@ import requests # type: ignore -from levanter.infra.cli_helpers import make_docker_run_command +from levanter.infra.docker import make_docker_run_command logger = logging.getLogger(__name__) diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index 7044200ba..b054f1972 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -806,10 +806,10 @@ def _tpu_splash_attention( if bias is not None: raise NotImplementedError("Splash attention does not support bias") - if attention_dtype is not None and attention_dtype != jnp.float32: - warnings.warn("Splash attention only supports float32. Switching to float32.") + # if attention_dtype is not None and attention_dtype != jnp.float32: + # warnings.warn("Splash attention only supports float32. Switching to float32.") - attention_dtype = jnp.float32 + # attention_dtype = jnp.float32 q_class, k_class, v_class = _bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key) diff --git a/src/levanter/utils/flop_utils.py b/src/levanter/utils/flop_utils.py index 7d5c4fc0a..eef91f110 100644 --- a/src/levanter/utils/flop_utils.py +++ b/src/levanter/utils/flop_utils.py @@ -146,6 +146,11 @@ def lm_flops_per_token( "tpu v5p": { "bf16": 459e12, }, + # Source: https://cloud.google.com/tpu/docs/v6e + "tpu v6 lite": { + "bf16": 918e12, + "int8": 1836e12, + }, }
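Note: a quick sketch of the merge order implemented by the new `CliConfig.env_for_accel` (base `env`, then the generation-level `accel_env` entry, then the exact accelerator type). The values below are hypothetical except for the v6e `LIBTPU_INIT_ARGS` setting, which comes from the docs change above:

```python
from levanter.infra.cli_helpers import CliConfig

# Hypothetical launch config with a base env plus generation-level and exact-type accel_env entries.
config = CliConfig(
    env={"WANDB_PROJECT": "levanter"},
    accel_env={
        "v6e": {"LIBTPU_INIT_ARGS": "--xla_tpu_scoped_vmem_limit_kib=98304"},
        "v6e-256": {"EXAMPLE_FLAG": "1"},  # hypothetical entry, only to show precedence
    },
)

# The generation-level entry applies to any v6e slice size.
assert config.env_for_accel("v6e-8") == {
    "WANDB_PROJECT": "levanter",
    "LIBTPU_INIT_ARGS": "--xla_tpu_scoped_vmem_limit_kib=98304",
}

# The exact-type entry is layered on top of the generation-level entry and the base env.
assert config.env_for_accel("v6e-256") == {
    "WANDB_PROJECT": "levanter",
    "LIBTPU_INIT_ARGS": "--xla_tpu_scoped_vmem_limit_kib=98304",
    "EXAMPLE_FLAG": "1",
}
```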