diff --git a/benchmarks/api_server/README.md b/benchmarks/api_server/README.md new file mode 100644 index 0000000000..4fd0f31c69 --- /dev/null +++ b/benchmarks/api_server/README.md @@ -0,0 +1,317 @@ +# MaxText API Server
+
+This directory contains an OpenAI-compatible API server for serving MaxText models, enabling benchmarking with evaluation frameworks such as lm-eval-harness and evalchemy. It uses [FastAPI](https://fastapi.tiangolo.com/) as the web framework and can be deployed on a single machine or on a multi-host GKE cluster.
+
+## Table of Contents
+- [Installation](#installation)
+- [Environment Variables](#environment-variables)
+- [Launching the Server (Single-Host)](#launching-the-server-single-host)
+- [Deploying on a GKE Cluster (Multi-Host)](#deploying-on-a-gke-cluster-multi-host)
+- [Interacting with the Server](#interacting-with-the-server)
+- [Benchmarking with Evaluation Frameworks](#benchmarking-with-evaluation-frameworks)
+
+
+## Installation
+
+The server has a few additional dependencies beyond the core MaxText requirements. Install them using the provided `requirements.txt` file:
+
+```bash
+pip install -r benchmarks/api_server/requirements.txt
+```
+
+## Environment Variables
+
+Before launching the server, you may need to set the following environment variable:
+
+- `HF_TOKEN`: Your Hugging Face access token. This is required if the model's tokenizer is hosted on the Hugging Face Hub and is not public.
+
+```bash
+export HF_TOKEN=
+```
+
+## Launching the Server (Single-Host)
+
+The primary way to launch the API server is with the `start_server.sh` script. The script ensures that the server runs from the project's root directory, which is necessary for the Python interpreter to find all the required modules.
+
+The script takes the path to a base configuration file (e.g., `MaxText/configs/base.yml`) followed by any number of model-specific configuration overrides.
+
+### Benchmarking Configuration
+
+To use this server for benchmarking with frameworks like `lm-eval-harness` or `evalchemy`, you **must** include the following two arguments in your launch command:
+
+- `tokenizer_type="huggingface"`: Ensures the tokenizer is compatible with the evaluation harness.
+- `return_log_prob=True`: Enables the log probability calculations required for many standard evaluation metrics.
+
+### Command Structure
+
+```bash
+bash benchmarks/api_server/start_server.sh /path/to/base.yml [arg1=value1] [arg2=value2] ...
+```
+
+### Example
+
+Here is an example of launching a `qwen3-30b-a3b` model configured for benchmarking on a TPU v5p-8, which has 4 chips.
+
+```bash
+# Make sure you are in the root directory of the maxtext project.
+
+bash benchmarks/api_server/start_server.sh \
+  MaxText/configs/base.yml \
+  model_name="qwen3-30b-a3b" \
+  tokenizer_path="Qwen/Qwen3-30B-A3B-Thinking-2507" \
+  load_parameters_path="" \
+  per_device_batch_size=4 \
+  ici_tensor_parallelism=4 \
+  max_prefill_predict_length=1024 \
+  max_target_length=2048 \
+  async_checkpointing=false \
+  scan_layers=false \
+  attention="dot_product" \
+  tokenizer_type="huggingface" \
+  return_log_prob=True
+```
+
+Once the server starts successfully, you will see a confirmation message from Uvicorn:
+
+![Single-Host Server Startup](images/single-host-server-startup.png)
+
+```
+INFO: RANK 0: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
+```
+
+The server is now ready to accept requests on port 8000.
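+
+To quickly verify that the server is reachable, query the health-check endpoint exposed at the server root. This is a minimal sketch; it assumes the server is running locally on the default port 8000:
+
+```bash
+# Query the health-check endpoint; the server responds with a small JSON status object.
+curl http://localhost:8000/
+# Example response: {"status": "ok", "message": "MaxText API server is running."}
+```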
+
+## Deploying on a GKE Cluster (Multi-Host)
+
+For large models that require a multi-host TPU setup, you can deploy the server using the [xpk (Accelerated Processing Kit) tool](https://github.com/AI-Hypercomputer/xpk). The recommended approach is to create a single submission script that configures and launches the workload.
+
+
+### 1. Create a Job Submission Script
+
+Create a new bash script (e.g., `launch_gke_server.sh`) to hold your configuration and `xpk` command. This makes launching jobs repeatable and easy to modify.
+
+For your convenience, the script below is also available as a template file at `benchmarks/api_server/launch_gke_server.sh.template`.
+
+Inside this script, you define the server's startup command and your cluster configuration. Before running the script, fill in the placeholder values defined at the top of the file.
+
+```bash
+#!/bin/bash
+set -e
+
+# ==============================================================================
+# 1. User-Configurable Variables
+# ==============================================================================
+
+# -- GKE Cluster Configuration --
+# (, , )
+export CLUSTER=""
+export DEVICE_TYPE="v5p-16"
+export PROJECT=""
+export ZONE=""
+
+# -- XPK Workload Configuration --
+# (, )
+export RUNNAME="my-server-$(date +%Y-%m-%d-%H-%M-%S)"
+export DOCKER_IMAGE="gcr.io/tpu-prod-env-multipod/maxtext_jax_nightly:"
+export HF_TOKEN="" # Optional: if your tokenizer is private
+
+# -- Model Configuration --
+# IMPORTANT: Replace these with your model's details.
+# (, , )
+export MODEL_NAME="qwen3-30b-a3b"
+export TOKENIZER_PATH="Qwen/Qwen3-30B-A3B-Thinking-2507"
+export LOAD_PARAMETERS_PATH=""
+export PER_DEVICE_BATCH_SIZE=4
+# Parallelism settings should match the number of chips on your device.
+# For a v5p-16 (8 chips), the product of parallelism values should be 8.
+export ICI_TENSOR_PARALLELISM=4
+export ICI_EXPERT_PARALLELISM=2
+
+# ==============================================================================
+# 2. Define the Command to Run on the Cluster
+# ==============================================================================
+# This command installs dependencies and then starts the server.
+CMD="export HF_TOKEN=${HF_TOKEN} && \
+    pip install --upgrade pip && \
+    pip install -r benchmarks/api_server/requirements.txt && \
+    bash benchmarks/api_server/start_server.sh \
+    MaxText/configs/base.yml \
+    model_name=\"${MODEL_NAME}\" \
+    tokenizer_path=\"${TOKENIZER_PATH}\" \
+    load_parameters_path=\"${LOAD_PARAMETERS_PATH}\" \
+    per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
+    ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
+    ici_expert_parallelism=${ICI_EXPERT_PARALLELISM} \
+    tokenizer_type=\"huggingface\" \
+    return_log_prob=True"
+
+
+# ==============================================================================
+# 3. Launch the Workload
+# ==============================================================================
+echo "Launching workload ${RUNNAME}..."
+xpk workload create --workload "${RUNNAME}" \
+  --base-docker-image "${DOCKER_IMAGE}" \
+  --command "${CMD}" \
+  --num-slices=1 \
+  --cluster "${CLUSTER}" --device-type "${DEVICE_TYPE}" --project "${PROJECT}" --zone "${ZONE}"
+
+echo "Workload ${RUNNAME} created."
+echo "Use the following command to connect:"
+echo "bash benchmarks/api_server/port_forward_xpk.sh job_name=${RUNNAME} project=${PROJECT} zone=${ZONE} cluster=${CLUSTER}"
+```
+
+### 2. Launch the Workload
+
+Make the script executable and run it:
+
+```bash
+chmod +x launch_gke_server.sh
+./launch_gke_server.sh
+```
+
+### 3. Connect to the Server
+
+The API server runs only on the first host/worker (rank 0) of the workload. To connect to it, use the `port_forward_xpk.sh` script as instructed in the output of your launch script.
+
+```bash
+bash benchmarks/api_server/port_forward_xpk.sh \
+  job_name= \
+  project= \
+  zone= \
+  cluster=
+```
+
+The script will automatically find the correct pod and establish the port-forward connection. Your server is then accessible at `http://localhost:8000`.
+
+## Interacting with the Server
+
+Once the server is running (either locally or connected via port-forwarding), you can interact with it using any standard HTTP client. The `model` field in the request body can be set to any string; it is used for identification purposes but does not change which model is being served.
+
+### Using `curl`
+
+#### Completions API
+
+The `/v1/completions` endpoint is suitable for simple prompt-response interactions.
+
+```bash
+curl -X POST http://localhost:8000/v1/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "",
+    "prompt": "The capital of France is",
+    "max_tokens": 50,
+    "temperature": 0.7
+  }'
+```
+
+#### Chat Completions API
+
+The `/v1/chat/completions` endpoint is designed for multi-turn conversations.
+
+```bash
+curl -X POST http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "",
+    "messages": [
+      {"role": "system", "content": "You are a helpful assistant."},
+      {"role": "user", "content": "What is the largest planet in our solar system?"}
+    ],
+    "max_tokens": 50,
+    "temperature": 0.7
+  }'
+```
+
+Server logs will display the following information:
+
+![Server Request Logs](images/server-request-logs.png)
+
+### Using the OpenAI Python Client
+
+You can also use the official `openai` Python library to interact with the server.
+
+**Installation:**
+```bash
+pip install openai
+```
+
+**Example Python Script:**
+```python
+from openai import OpenAI
+
+# Point the client to the local server
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
+
+completion = client.chat.completions.create(
+    model="",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is the largest planet in our solar system?"}
+    ]
+)
+
+print(completion.choices[0].message.content)
+```
+
+## Benchmarking with Evaluation Frameworks
+
+You can evaluate models served by this API using standard frameworks like [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) and [evalchemy](https://github.com/mlfoundations/evalchemy).
+
+### Setup
+
+It is highly recommended to set up a new, separate Python virtual environment for the evaluation framework. This prevents any dependency conflicts with the MaxText environment.
+
+```bash
+# In a new terminal
+python3 -m venv eval_env
+source eval_env/bin/activate
+```
+
+Install the evaluation frameworks by following their official guides:
+- [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
+- [evalchemy](https://github.com/mlfoundations/evalchemy)
+
+
+### Log-Likelihood / Multiple Choice Tasks (e.g., MMLU)
+
+Tasks that compare the log-probabilities of different choices (`output_type: multiple_choice` or `loglikelihood`) use the `/v1/completions` endpoint.
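+
+Under the hood, the harness scores each answer choice by requesting per-token log probabilities for the prompt-plus-choice text. The sketch below illustrates that request pattern directly against the server; it is for illustration only (the harness constructs the real payloads itself), and the model name and prompt shown are arbitrary placeholders:
+
+```bash
+# Ask the completions endpoint to echo the prompt back with per-token logprobs,
+# which is the information log-likelihood tasks use for scoring.
+curl -X POST http://localhost:8000/v1/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "my-model",
+    "prompt": "Question: What is 2+2? Answer: 4",
+    "max_tokens": 1,
+    "echo": true,
+    "logprobs": 1,
+    "temperature": 0.0
+  }'
+```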
+
+To maximize throughput, set the `batch_size` in your evaluation command to match the total batch size of your running server (`per_device_batch_size` * `number of devices`).
+
+**Example: Running MMLU**
+```bash
+python -m eval.eval \
+  --model local-completions \
+  --model_args "pretrained=,base_url=http://localhost:8000/v1/completions,tokenizer_backend=huggingface,tokenizer=,model=,max_length=" \
+  --tasks mmlu \
+  --batch_size \
+  --output_path logs
+```
+
+An example benchmark output looks like this:
+
+![MMLU Example](images/mmlu_example.png)
+
+### Generative Tasks (e.g., AIME)
+
+Tasks that require generating text until a stop sequence is met (`output_type: generate_until`) use the `/v1/chat/completions` endpoint.
+
+The chat API does not support batched requests directly. Instead, the evaluation harness sends concurrent requests to simulate batching. To enable this, set `num_concurrent` to match your server's total batch size and set the evaluation `batch_size` to 1. You must also include the `--apply_chat_template` flag. All sampling parameters (like `temperature`, `top_p`, etc.) should be passed via the `--gen_kwargs` argument. For example, if you are using a v5p-8 (4 chips) with `per_device_batch_size=4`, set `num_concurrent=16`.
+
+**Example: Running AIME25**
+```bash
+python -m eval.eval \
+  --model local-chat-completions \
+  --model_args "num_concurrent=16,pretrained=,base_url=http://localhost:8000/v1/chat/completions,tokenizer_backend=huggingface,tokenizer=,model=,max_length=" \
+  --tasks AIME25 \
+  --batch_size 1 \
+  --output_path logs \
+  --apply_chat_template \
+  --gen_kwargs "temperature=0.6,top_p=0.95,top_k=20,max_tokens=,max_gen_toks="
+```
+The valid arguments for `--gen_kwargs` are `temperature`, `top_p`, `top_k`, `stop`, `seed`, `max_tokens`, and `max_gen_toks`. The `max_gen_toks` argument is used by some tasks in the evaluation harness to control the maximum number of tokens to generate. We suggest passing `max_tokens` and `max_gen_toks` with the same value.
+
+The evaluation results will be saved to the directory specified by the `--output_path` argument (in the examples above, a directory named `logs`).
\ No newline at end of file
diff --git a/benchmarks/api_server/images/mmlu_example.png b/benchmarks/api_server/images/mmlu_example.png new file mode 100644 index 0000000000..593e7694ac Binary files /dev/null and b/benchmarks/api_server/images/mmlu_example.png differ diff --git a/benchmarks/api_server/images/server-request-logs.png b/benchmarks/api_server/images/server-request-logs.png new file mode 100644 index 0000000000..01ba7dc064 Binary files /dev/null and b/benchmarks/api_server/images/server-request-logs.png differ diff --git a/benchmarks/api_server/images/single-host-server-startup.png b/benchmarks/api_server/images/single-host-server-startup.png new file mode 100644 index 0000000000..6d07361399 Binary files /dev/null and b/benchmarks/api_server/images/single-host-server-startup.png differ diff --git a/benchmarks/api_server/launch_gke_server.sh.template b/benchmarks/api_server/launch_gke_server.sh.template new file mode 100644 index 0000000000..fedb5d8dd8 --- /dev/null +++ b/benchmarks/api_server/launch_gke_server.sh.template @@ -0,0 +1,78 @@ +# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +#!/bin/bash +set -e + +# ============================================================================== +# 1. User-Configurable Variables +# ============================================================================== + +# -- GKE Cluster Configuration -- +# (, , ) +export CLUSTER="" +export DEVICE_TYPE="v5p-16" +export PROJECT="" +export ZONE="" + +# -- XPK Workload Configuration -- +# (, ) +export RUNNAME="my-server-$(date +%Y-%m-%d-%H-%M-%S)" +export DOCKER_IMAGE="gcr.io/tpu-prod-env-multipod/maxtext_jax_nightly:" +export HF_TOKEN="" # Optional: if your tokenizer is private + +# -- Model Configuration -- +# IMPORTANT: Replace these with your model's details. +# (, , ) +export MODEL_NAME="qwen3-30b-a3b" +export TOKENIZER_PATH="Qwen/Qwen3-30B-A3B-Thinking-2507" +export LOAD_PARAMETERS_PATH="" +export PER_DEVICE_BATCH_SIZE=4 +# Parallelism settings should match the number of chips on your device. +# For a v5p-16 (8 chips), the product of parallelism values should be 8. +export ICI_TENSOR_PARALLELISM=4 +export ICI_EXPERT_PARALLELISM=2 + +# ============================================================================== +# 2. Define the Command to Run on the Cluster +# ============================================================================== +# This command installs dependencies and then starts the server. +CMD="export HF_TOKEN=${HF_TOKEN} && \ + pip install --upgrade pip && \ + pip install -r benchmarks/api_server/requirements.txt && \ + bash benchmarks/api_server/start_server.sh \ + MaxText/configs/base.yml \ + model_name=\"${MODEL_NAME}\" \ + tokenizer_path=\"${TOKENIZER_PATH}\" \ + load_parameters_path=\"${LOAD_PARAMETERS_PATH}\" \ + per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ + ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ + ici_expert_parallelism=${ICI_EXPERT_PARALLELISM} \ + tokenizer_type=\"huggingface\" \ + return_log_prob=True" + +# ============================================================================== +# 3. Launch the Workload +# ============================================================================== +echo "Launching workload ${RUNNAME}..." +xpk workload create --workload "${RUNNAME}" \ + --base-docker-image "${DOCKER_IMAGE}" \ + --command "${CMD}" \ + --num-slices=1 \ + --cluster "${CLUSTER}" --device-type "${DEVICE_TYPE}" --project "${PROJECT}" --zone "${ZONE}" + +echo "Workload ${RUNNAME} created." +echo "Use the following command to connect:" +echo "bash benchmarks/api_server/port_forward_xpk.sh job_name=${RUNNAME} project=${PROJECT} zone=${ZONE} cluster=${CLUSTER}" diff --git a/benchmarks/api_server/maxtext_generator.py b/benchmarks/api_server/maxtext_generator.py new file mode 100644 index 0000000000..c81d425fbc --- /dev/null +++ b/benchmarks/api_server/maxtext_generator.py @@ -0,0 +1,634 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may not obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import time +import uuid +import json +import datetime +from typing import Sequence, Optional, List, Union +import logging + + +import jax +import jax.numpy as jnp +from absl import app +import numpy as np + +from MaxText import max_utils, maxengine, pyconfig, multimodal_utils, max_logging + +from dataclasses import dataclass, field + +# Set TF log level to avoid verbose startup messages. +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + + +@dataclass +class LogProbs: + """ + A dataclass to store detailed log probability information for a sequence of tokens. + + Attributes: + tokens: A list of token IDs. + token_logprobs: A list of log probabilities corresponding to each token. + top_logprobs: An optional list of dictionaries, where each dictionary maps + surrounding tokens to their log probabilities at each position. + Currently unused and set to None. + text_offset: A list of character offsets for each token in the generated text. + """ + tokens: List[int] + token_logprobs: List[float] + top_logprobs: Optional[List[None]] = None + text_offset: List[int] = field(default_factory=list) + + +@dataclass +class Completion: + """ + Represents a single completed generation from the model. + + Attributes: + index: The index of this completion in the batch. + text: The generated text. + tokens: The list of token IDs that make up the generated text. + logprobs: An optional `LogProbs` object containing detailed log probability info. + finish_reason: The reason the generation finished (e.g., 'stop', 'length'). + prompt_token_count: The number of tokens in the input prompt. + completion_token_count: The number of tokens in the generated completion. + """ + index: int + text: str + tokens: List[int] + logprobs: Optional[LogProbs] + finish_reason: str = "stop" + prompt_token_count: int = 0 + completion_token_count: int = 0 + + +@dataclass +class GenerationStream: + """Holds the state for a single generation stream within a batch.""" + # Input state + tokens: np.ndarray + true_length: int + image: Optional[np.ndarray] + + # Output accumulators + generated_ids: List[int] = field(default_factory=list) + generated_logprobs: List[float] = field(default_factory=list) + + # For echo=True + prompt_ids: List[int] = field(default_factory=list) + prompt_logprobs: List[float] = field(default_factory=list) + + # Status + finished: bool = False + finish_reason: str = "length" + + +class MaxTextGenerator: + """A reusable class for parallel text generation using MaxText.""" + + def __init__(self, argv: Sequence[str]): + """ + Initializes the MaxText model, tokenizer, and engine. + + Args: + argv: Command-line arguments for MaxText configuration. + """ + start_time = time.time() + + argv_list = list(argv) + + # Check for HF_TOKEN env var and inject as a pyconfig argument if not already present. + hf_token = os.environ.get("HF_TOKEN") + if hf_token and not any("hf_access_token" in arg for arg in argv_list): + max_logging.log("Found HF_TOKEN environment variable. Adding to config.") + argv_list.append(f"hf_access_token={hf_token}") + + # CRITICAL: Initialize the distributed system and config FIRST. 
+ # This call to pyconfig.initialize() contains jax.distributed.initialize() + # and must happen before any other JAX calls. + self.config = pyconfig.initialize(argv_list) + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + + # Now that JAX is initialized, we can set up logging and use JAX functions. + self.rank = jax.process_index() + self.logger = logging.getLogger(__name__) + self.logger.info(f"Initializing MaxTextGenerator with argv: {argv_list}") + + self._validate_config(self.config) + self.logger.info("System information:") + # Temporarily redirect stdout to capture print output for the log + from io import StringIO + import sys + old_stdout = sys.stdout + sys.stdout = captured_stdout = StringIO() + max_utils.print_system_information() + sys.stdout = old_stdout + self.logger.info(captured_stdout.getvalue()) + + self.engine = maxengine.MaxEngine(self.config) + self.rng = jax.random.PRNGKey(1234) + + self.logger.info("Loading model parameters...") + self.rng, rng_load_params = jax.random.split(self.rng) + self.params = self.engine.load_params(rng=rng_load_params) + self.logger.info("Model parameters loaded.") + + self.metadata = self.engine.get_tokenizer() + self.tokenizer = self.engine.build_tokenizer(self.metadata) + eos_id = self.tokenizer.eos_id + if not isinstance(eos_id, list): + eos_id = [eos_id] + self.eos_ids = eos_id + try: + self.has_chat_template = getattr(self.tokenizer.tokenizer, "chat_template", False) + except AttributeError: + self.has_chat_template = False + + self.logger.info(f"Chat Template available: {self.has_chat_template}") + + self.batch_size = int(self.config.per_device_batch_size * jax.device_count()) + + self.rng, rng_init_decode = jax.random.split(self.rng) + self.decode_state = self.engine.init_decode_state(rng=rng_init_decode) + self._jitted_reset_state = jax.jit( + lambda state: jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), state) + ) + + end_time = time.time() + self.logger.info(f"Initialization complete in {end_time - start_time:.2f} seconds. Max batch size: {self.batch_size}") + + + def generate_batch( + self, + prompts: List[str], + image_paths: Optional[List[Optional[str]]] = None, + max_tokens: int = None, + logprobs: int = None, + echo: bool = False, + stop: Optional[Union[str, List[str]]] = None, + temperature: Optional[float] = None, + seed: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + ) -> List[Completion]: + """ + Generates text for a batch of prompts, handling chunking automatically. + + Args: + prompts: A list of prompt strings. + image_paths: An optional list of image paths, one for each prompt. + max_tokens: The maximum number of tokens to generate. + logprobs: The number of top log probabilities to return for each token. + echo: Whether to include the prompt in the generated text. + stop: An optional list of stop sequences. + temperature: An optional temperature for sampling. + seed: An optional seed for deterministic sampling. + top_k: An optional integer for top-k sampling. + top_p: An optional float for nucleus sampling. + + Returns: + A list of generated Completion, corresponding to the input prompts. 
+ """ + if image_paths is None: + image_paths = [None] * len(prompts) + if len(prompts) != len(image_paths): + raise ValueError("The number of prompts must equal the number of image paths.") + + all_results = [] + num_prompts = len(prompts) + for i in range(0, num_prompts, self.batch_size): + prompt_chunk = prompts[i : i + self.batch_size] + image_chunk = image_paths[i : i + self.batch_size] + + chunk_count = (i // self.batch_size) + 1 + total_chunks = (num_prompts + self.batch_size - 1) // self.batch_size + + chunk_results = self._process_chunk( + prompt_chunk, image_chunk, max_tokens, logprobs, echo, stop, temperature, seed, top_k, top_p + ) + all_results.extend(chunk_results) + + return all_results + + + def _process_chunk( + self, + prompts: List[str], + image_paths: List[Optional[str]], + max_tokens: int, + logprobs: int = None, + echo: bool = False, + stop: Optional[Union[str, List[str]]] = None, + temperature: Optional[float] = None, + seed: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + ) -> List[Completion]: + """Orchestrates the generation process for a single chunk of prompts.""" + start_time = time.time() + + # for prompt in prompts: + # self.logger.debug(f"Processing prompt: {prompt}") + + initialize_start_time = time.time() + # Reset the state to handle the new batch while reusing memory. + self.decode_state = self._jitted_reset_state(self.decode_state) + streams, rng = self._initialize_streams_and_state(prompts, image_paths, seed) + initialize_end_time = time.time() + self.logger.info(f"Initialize step took {initialize_end_time - initialize_start_time:.2f}s.") + + if max_tokens is not None and max_tokens <= 0: + self.logger.warning("max_tokens <= 0, returning empty completions.") + return [Completion(index=i, text="", tokens=[], logprobs=None) for i in range(len(streams))] + + prefill_start_time = time.time() + self.decode_state, rng = self._run_prefill_step(streams, self.decode_state, rng, logprobs, echo, temperature, top_k, top_p) + prefill_end_time = time.time() + self.logger.info(f"Prefill step took {prefill_end_time - prefill_start_time:.2f}s.") + + generation_start_time = time.time() + self.decode_state = self._run_generation_loop(streams, self.decode_state, rng, max_tokens, stop, temperature, top_k, top_p) + generation_end_time = time.time() + self.logger.info(f"Generation loop took {generation_end_time - generation_start_time:.2f}s.") + + completions_start_time = time.time() + completions = self._build_completions(streams, logprobs, echo) + completions_end_time = time.time() + self.logger.info(f"Completions loop took {completions_end_time - completions_start_time:.2f}s.") + + end_time = time.time() + self.logger.info(f"Processed {len(prompts)} prompts in {end_time - start_time:.2f}s.") + return completions + + + def _initialize_streams_and_state(self, prompts, image_paths, seed): + """Tokenizes inputs, sets up stream objects, and initializes the decode state.""" + prefill_length = getattr(self.config, "max_prefill_predict_length", 1024) + streams = [] + for prompt, image_path in zip(prompts, image_paths): + toks, tlen, imgs = self._preprocess_inputs(prompt, prefill_length, image_path) + assert tlen <= prefill_length, f"Input token length {tlen} is > {prefill_length}" + streams.append(GenerationStream(tokens=toks, true_length=tlen, image=imgs)) + + if seed is not None: + rng = jax.random.PRNGKey(seed) + else: + self.rng, rng = jax.random.split(self.rng) + + return streams, rng + + + def _determine_sampling_algorithm(self, 
temperature, top_k, top_p): + """Determines the sampling algorithm based on user-provided parameters.""" + if temperature == 0.0: + return "greedy" + if top_k is not None and top_p is not None: + return "composite" + if top_k is not None: + return "topk" + if top_p is not None: + return "nucleus" + if temperature is not None: + return "weighted" + # If no specific parameters are provided, return None to let the + # engine use its default configured `decode_sampling_strategy`. + return None + + + def _run_prefill_step(self, streams, decode_state, rng, logprobs, echo, temperature, top_k, top_p): + """Runs the prefill step for each stream and inserts results into the decode state.""" + sampling_algorithm = self._determine_sampling_algorithm(temperature, top_k, top_p) + prefill_results_to_insert = {} + + for i, stream in enumerate(streams): + rng, rng_prefill = jax.random.split(rng) + want_prompt_logp = logprobs is not None and echo + + prefill_result, _ = self.engine.prefill( + params=self.params, + padded_tokens=stream.tokens, + true_length=stream.true_length, + images=stream.image, + rng=rng_prefill, + slot=i, + return_prompt_logp=want_prompt_logp, + temperature=temperature, + algorithm=sampling_algorithm, + topk=top_k, + nucleus_topp=top_p, + ) + prefill_results_to_insert[i] = prefill_result + + p_ids = list(map(int, np.array(stream.tokens[:stream.true_length], dtype=np.int32).tolist())) + stream.prompt_ids.extend(p_ids) + if prefill_result.get("prompt_logp") is not None: + p_logp_arr = np.array(prefill_result["prompt_logp"])[0, :stream.true_length] + stream.prompt_logprobs.extend([float(x) for x in p_logp_arr.tolist()]) + + first_token_id = int(np.array(prefill_result["tokens"])[0, 0]) + stream.generated_ids.append(first_token_id) + if prefill_result.get("token_logp") is not None: + first_logp = float(np.array(prefill_result["token_logp"])[0, 0]) + stream.generated_logprobs.append(first_logp) + + for slot_idx, result in prefill_results_to_insert.items(): + decode_state = self.engine.insert(prefix=result, decode_state=decode_state, slot=slot_idx) + + return decode_state, rng + + + def _run_generation_loop(self, streams, decode_state, rng, max_tokens, stop, temperature, top_k, top_p): + """Runs the autoregressive generation loop.""" + target_length = getattr(self.config, "max_target_length", 2048) + prefill_length = getattr(self.config, "max_prefill_predict_length", 1024) + sampling_algorithm = self._determine_sampling_algorithm(temperature, top_k, top_p) + + stop_sequences = [] + max_stop_seq_len_tokens = 0 + if stop: + # Ensure stop is a list of non-empty strings + stop_sequences = [s for s in ([stop] if isinstance(stop, str) else stop) if s] + if stop_sequences: + # Calculate the max token length for any stop sequence to define a lookback window. + for seq in stop_sequences: + # Use the underlying tokenizer here to avoid potential errors with the wrapper + # on single-token sequences, as this is a safe, internal calculation. + token_ids = self.tokenizer.tokenizer.encode(seq, add_special_tokens=False) + max_stop_seq_len_tokens = max(max_stop_seq_len_tokens, len(token_ids)) + + total_steps = target_length - prefill_length + if max_tokens is not None: + total_steps = min(total_steps, max_tokens - 1) # -1 for the token from prefill + + for step in range(total_steps): + self.logger.debug(f"Generation step {step + 1}/{total_steps}") + active_streams = [(i, s) for i, s in enumerate(streams) if not s.finished] + if not active_streams: + self.logger.info("All streams finished. 
Breaking generation loop.") + break + + rng, rng_generate = jax.random.split(rng) + decode_state, _ = self.engine.generate( + self.params, + decode_state, + rng=rng_generate, + temperature=temperature, + algorithm=sampling_algorithm, + topk=top_k, + nucleus_topp=top_p, + ) + + state_tokens = np.array(decode_state["tokens"]) + state_logp_np = None + if (logp := decode_state.get("token_logp")) is not None: + state_logp_np = np.array(logp) + + for slot_idx, stream in active_streams: + tok_id = int(state_tokens[slot_idx, 0]) + stream.generated_ids.append(tok_id) + if state_logp_np is not None: + stream.generated_logprobs.append(float(state_logp_np[slot_idx, 0])) + + # Check for finish conditions + current_len = stream.true_length + 1 + step + is_max_len = current_len >= target_length + is_eos = tok_id in self.eos_ids + stop_sequence_found = False + + if stop_sequences: + # Define a lookback window for decoding that is slightly larger + # than the longest stop sequence in tokens. + lookback_window = max_stop_seq_len_tokens + 2 + start_index = max(0, len(stream.generated_ids) - lookback_window) + trailing_ids = stream.generated_ids[start_index:] + + if trailing_ids: + # Use the standard jetstream wrapper for decoding as requested. + trailing_text = self.tokenizer.decode([int(tid) for tid in trailing_ids]) + for stop_seq in stop_sequences: + # Use 'in' for a more robust check. + if stop_seq in trailing_text: + stop_sequence_found = True + break + + if is_max_len or is_eos or stop_sequence_found: + stream.finished = True + if is_eos or stop_sequence_found: + stream.finish_reason = "stop" + if getattr(self.config, "attention", "") == "paged": + self.engine.release_pages(slot=slot_idx) + + return decode_state + + + def _build_completions(self, streams, logprobs, echo): + """Builds the final Completion objects from the generated stream states.""" + completions = [] + for i, stream in enumerate(streams): + gen_ids_for_text = stream.generated_ids[:] + gen_logps_for_text = stream.generated_logprobs[:] + + if gen_ids_for_text and gen_ids_for_text[-1] in self.eos_ids: + gen_ids_for_text = gen_ids_for_text[:-1] + if len(gen_logps_for_text) >= len(stream.generated_ids): + gen_logps_for_text = gen_logps_for_text[:-1] + + tokens_for_text = stream.prompt_ids + gen_ids_for_text if echo else gen_ids_for_text + logps_for_text = stream.prompt_logprobs + gen_logps_for_text if echo else gen_logps_for_text + + text = self.tokenizer.decode(tokens_for_text) + offsets = self._token_offsets(tokens_for_text, 0) + + lp_payload = None + if logprobs is not None: + if len(tokens_for_text) != len(logps_for_text): + self.logger.warning(f"[warn] Mismatched token/logprob lengths for stream {i}. 
No logprobs returned.") + else: + lp_payload = LogProbs( + tokens=tokens_for_text, + token_logprobs=logps_for_text, + top_logprobs=None, + text_offset=offsets, + ) + + completions.append( + Completion( + index=i, + text=text, + tokens=tokens_for_text, + logprobs=lp_payload, + finish_reason=stream.finish_reason, + prompt_token_count=len(stream.prompt_ids), + completion_token_count=len(gen_ids_for_text), + ) + ) + return completions + + + def _preprocess_inputs(self, text, prefill_length, image_path): + """Helper to preprocess a single text and optional image input.""" + processor_output = multimodal_utils.PreprocessorOutput() + images = None + if self.config.use_multimodal and image_path: + text = multimodal_utils.reformat_prompt( + text, image_placeholder=self.config.image_placeholder, model_name=self.config.model_name + ) + loaded_images = multimodal_utils.load_image_from_path(image_path) + processor_output = multimodal_utils.pre_process_image(loaded_images, model_name=self.config.model_name) + prefill_length -= multimodal_utils.get_image_offsets( + self.config.model_name, processor_output=processor_output + ) + images = processor_output.pixel_values + + tokens, true_length = self.tokenizer.encode(text, is_bos=not self.has_chat_template, prefill_lengths=[prefill_length]) + if self.config.use_multimodal and image_path: + tokens = multimodal_utils.prepare_text_for_image_fusion(tokens, model_name=self.config.model_name, processor_output=processor_output) + true_length += multimodal_utils.get_image_offsets(self.config.model_name, processor_output=processor_output) + + return tokens, true_length, images + + + def _validate_config(self, config): + """Validates configuration.""" + assert config.load_full_state_path == "", "Decode doesn't operate on full states! Convert to parameter checkpoint first. Using generate_param_only_checkpoint." + assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet" + assert config.quantization != "nanoo_fp8", "NANOO fp8 on AMD MI300/MI325 GPUs is not supported in decode.py yet" + assert config.per_device_batch_size * jax.device_count() >= 1, "Total batch size must be at least 1." + + + def _token_offsets(self, token_ids: List[int], start: int = 0) -> List[int]: + """ + Compute char offsets by decoding cumulatively so context-dependent + whitespace/normalization is handled correctly (SentencePiece/LLaMA quirk). + """ + offsets: List[int] = [] + pos = start + decoded_so_far = "" + prefix_ids: List[int] = [] + for tid in token_ids: + offsets.append(pos) + prefix_ids.append(int(tid)) + new_text = self.tokenizer.decode(prefix_ids) + piece_len = len(new_text) - len(decoded_so_far) + # Guard for weird edge cases; shouldn't happen but better safe: + if piece_len < 0: + piece_len = len(self.tokenizer.decode([int(tid)])) + pos += piece_len + decoded_so_far = new_text + return offsets + + +if __name__ == "__main__": + import sys + import time + + + def dump_completion(i, comp): + """ + Dumps the content of a Completion object to the log for debugging. + + Args: + i: The index of the completion. + comp: The Completion object to dump. 
+ """ + max_logging.log(f"\n=== Completion {i} ===") + max_logging.log(f"index: {comp.index}") + max_logging.log(f"text: {repr(comp.text)}") + + + if comp.logprobs is None: + max_logging.log("logprobs: None") + return + + lp = comp.logprobs + # lengths should match: one logprob/offset per token + if not (len(lp.tokens) == len(lp.token_logprobs) == len(lp.text_offset)): + max_logging.log(f"[warn] mismatched lengths: tokens={len(lp.tokens)}, " + f"logps={len(lp.token_logprobs)}, offsets={len(lp.text_offset)}") + + max_logging.log("logprobs:") + max_logging.log(f" tokens (ids): {lp.tokens}") + max_logging.log(f" token_logprobs: {[round(x, 6) for x in lp.token_logprobs]}") + max_logging.log(f" text_offset: {lp.text_offset}") + max_logging.log(f" top_logprobs: {lp.top_logprobs}") + + max_logging.log(" tokens (decoded, id, logprob, offset):") + for tid, logp, off in zip(lp.tokens, lp.token_logprobs, lp.text_offset): + piece = llm.tokenizer.decode([int(tid)]) + max_logging.log(f" {repr(piece):>12s} id={int(tid):>6d} logp={logp:>10.6f} offset={off}") + + # When running standalone, basic logging is automatically configured. + # For server use, the server configures logging. + logging.basicConfig(level=logging.INFO) + + # Instantiate first to initialize JAX + llm = MaxTextGenerator(sys.argv) + + prompts_to_run = [ + "The capital of France is ", + ] + + max_tokens = 32 + echo = True + want_logprobs = 5 + temperature = 0.6 + seed = 72 + top_p = 0.95 + top_k = 20 + + max_logging.log( + f"\n--- Starting Batch Generation for {len(prompts_to_run)} Prompts " + f"(max_tokens={max_tokens}, echo={echo}) ---" + ) + + completions = llm.generate_batch( + prompts=prompts_to_run, + image_paths=None, + max_tokens=max_tokens, + logprobs=want_logprobs, + echo=echo, + seed=seed, + temperature=temperature, + top_p=top_p, + top_k=top_k + ) + + for i, comp in enumerate(completions): + dump_completion(i, comp) + + start = time.time() + + completions = llm.generate_batch( + prompts=prompts_to_run, + image_paths=None, + max_tokens=max_tokens, + logprobs=want_logprobs, + echo=echo, + seed=seed, + temperature=temperature, + top_p=top_p, + top_k=top_k + ) + + max_logging.log("--- Batch Generation Complete ---") + + for i, comp in enumerate(completions): + dump_completion(i, comp) + + end = time.time() + max_logging.log(f"total time: {end - start}") diff --git a/benchmarks/api_server/maxtext_server.py b/benchmarks/api_server/maxtext_server.py new file mode 100644 index 0000000000..46707bf1db --- /dev/null +++ b/benchmarks/api_server/maxtext_server.py @@ -0,0 +1,431 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import os +import sys +import time +import uuid +import json +import signal +import asyncio +import threading +import queue +import logging +from typing import Union + +import uvicorn +from fastapi import FastAPI, HTTPException +import jax +import jax.numpy as jnp +from jax.experimental import multihost_utils + + +from benchmarks.api_server.maxtext_generator import MaxTextGenerator +from benchmarks.api_server.server_models import ( + CompletionRequest, + CompletionResponse, + CompletionChoice, + Usage, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + ChatMessage, +) +from benchmarks.api_server import server_utils +from openai_harmony import ( + load_harmony_encoding, + HarmonyEncodingName, + Role, +) + +# ---------------------------- +# Init +# ---------------------------- + +# JAX distributed initialization must happen before any other JAX calls. +# We suppress the normal logger until after JAX is initialized. +logging.basicConfig(level=logging.WARNING) +print("Initializing MaxTextGenerator and JAX distributed system...") +llm = MaxTextGenerator(sys.argv) +rank = jax.process_index() + +# Now that JAX is initialized, we can get our rank-specific logger. +# The actual handler/formatter configuration will be done by Uvicorn. +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) # Ensure our logger passes INFO messages. +logger.info("MaxTextGenerator initialization complete.") + +harmony_enc = None +try: + harmony_enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + logger.info("Harmony encoding for gpt-oss loaded successfully.") +except ImportError: + logger.warning("openai_harmony not installed. GPT-OSS Harmony format will not be available.") +except Exception as e: + logger.error(f"Failed to load Harmony encoding: {e}") + + +app = FastAPI() + +# Global state for communication between threads. +request_queue = queue.Queue() +# A thread-safe dict to hold responses, keyed by request_id. +response_dict = {} +response_lock = threading.Lock() + +# Batching configuration +BATCH_TIMEOUT_S = 0.1 # 100ms +# Timeout for a client waiting for a response. +REQUEST_TIMEOUT_S = int(os.environ.get("MAXTEXT_REQUEST_TIMEOUT_S", "36000")) + + +async def _queue_and_wait_for_response(request: Union[CompletionRequest, ChatCompletionRequest]): + """ + Puts a request on the processing queue and waits for a response. + + This asynchronous function is the core of handling client requests. It generates + a unique ID for the request, places it in a global queue to be processed by + the batching loop, and then waits until the response is available in a + shared dictionary or until a timeout occurs. + + Args: + request: The incoming request object, either for a completion or a chat completion. + + Returns: + The response data once it's available. + + Raises: + HTTPException: If the request times out (504) or if an error occurs + during processing (500). + """ + request_id = f"req_{uuid.uuid4().hex}" + request_queue.put((request_id, request)) + + start_time = time.time() + while time.time() - start_time < REQUEST_TIMEOUT_S: + with response_lock: + if request_id in response_dict: + response_data = response_dict.pop(request_id) + if "error" in response_data: + raise HTTPException(status_code=500, detail=response_data["error"]) + return response_data + # Yield control to the event loop to allow other tasks to run. 
+ await asyncio.sleep(0.05) + + raise HTTPException(status_code=504, detail="Request timed out.") + + +@app.post("/v1/completions", response_model=CompletionResponse) +async def create_completion(request: CompletionRequest): + """Handles completion requests with dynamic batching.""" + return await _queue_and_wait_for_response(request) + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def create_chat_completion(request: ChatCompletionRequest): + """Handles chat completion requests with dynamic batching.""" + return await _queue_and_wait_for_response(request) + + +@app.get("/") +def health_check(): + """ + Provides a simple health check endpoint. + + Returns: + A dictionary indicating the server status. + """ + return {"status": "ok", "message": "MaxText API server is running."} + + +def run_server(): + """Runs the Uvicorn server in a separate thread.""" + uvicorn.run(app, host="0.0.0.0", port=8000) + + +# Define a maximum size for the request payload to be broadcasted. +# This avoids broadcasting variable-sized arrays, which can be complex. +MAX_REQUEST_SIZE = 65536 * 10 + + +def _build_chat_completion_response(request, completion_result, llm): + """Builds a ChatCompletionResponse from a single completion result.""" + text_out = completion_result.text + if "gpt-oss" in request.model and harmony_enc: + try: + parsed_messages = harmony_enc.parse_messages_from_completion_tokens( + completion_result.tokens, role=Role.ASSISTANT + ) + user_visible = "".join( + part.text for m in parsed_messages + if m.channel == "final" + for part in m.content + ) + if user_visible: + text_out = user_visible + else: + logger.warning("Harmony parsing for gpt-oss did not yield content in the 'final' channel. Falling back to raw text.") + except Exception as e: + logger.error(f"Harmony parsing failed for gpt-oss: {e}. 
Falling back to raw text.") + + want_top_logprobs = (request.top_logprobs or 0) > 0 if isinstance(request, ChatCompletionRequest) else (request.logprobs or 0) > 0 + lp_payload = server_utils.to_openai_logprobs( + getattr(completion_result, "logprobs", None), llm, want_top=want_top_logprobs + ) + text_out, lp_payload, finish_reason = server_utils.apply_stops_to_text_and_logprobs( + text_out, lp_payload, request.stop + ) + if finish_reason is None: + finish_reason = completion_result.finish_reason + + usage = Usage( + prompt_tokens=completion_result.prompt_token_count, + completion_tokens=completion_result.completion_token_count, + total_tokens=completion_result.prompt_token_count + completion_result.completion_token_count, + ) + return ChatCompletionResponse( + model=request.model, + choices=[ + ChatCompletionChoice( + index=0, + message=ChatMessage(role="assistant", content=text_out), + finish_reason=finish_reason, + logprobs=lp_payload, + ) + ], + usage=usage, + ) + + +def _build_completion_response(request, completions, prompts, llm): + """Builds a CompletionResponse from a list of completion results.""" + choices = [] + prompt_tokens_total = 0 + completion_tokens_total = 0 + + for idx, _ in enumerate(prompts): + item = completions[idx] + text_out = item.text + lp_payload = server_utils.to_openai_logprobs( + getattr(item, "logprobs", None), llm, want_top=(request.logprobs or 0) > 0 + ) + finish_reason = getattr(item, "finish_reason", "stop") + text_out, lp_payload, stop_reason = server_utils.apply_stops_to_text_and_logprobs( + text_out, lp_payload, request.stop + ) + if stop_reason is not None: + finish_reason = stop_reason + + prompt_tokens_total += item.prompt_token_count + completion_tokens_total += item.completion_token_count + + choices.append(CompletionChoice( + text=text_out, + index=idx, + logprobs=lp_payload, + finish_reason=finish_reason, + )) + + usage = Usage( + prompt_tokens=prompt_tokens_total, + completion_tokens=completion_tokens_total, + total_tokens=prompt_tokens_total + completion_tokens_total, + ) + return CompletionResponse( + model=request.model, + choices=choices, + usage=usage, + ) + + +def _create_response(request, completions, prompts, is_chat, llm): + """Creates either a CompletionResponse or ChatCompletionResponse.""" + if is_chat: + # Chat API only ever processes one prompt at a time from the server's perspective. 
+ return _build_chat_completion_response(request, completions[0], llm) + else: + return _build_completion_response(request, completions, prompts, llm) + + +def _collect_batched_requests(): + """Waits for and collects a batch of requests from the queue.""" + batched_items = [] + start_time = time.time() + while len(batched_items) < llm.batch_size and (time.time() - start_time) < BATCH_TIMEOUT_S: + try: + item = request_queue.get(timeout=0.01) + batched_items.append(item) + except queue.Empty: + if batched_items: + break # Process what we have if timeout is reached + return batched_items + + +def _prepare_batch_for_broadcast(batched_items): + """Prepares the batch payload and request map for broadcasting.""" + first_request_id, first_request = batched_items[0] + + logprobs_param = None + if isinstance(first_request, ChatCompletionRequest): + if first_request.logprobs: + logprobs_param = first_request.top_logprobs if first_request.top_logprobs is not None else 1 + else: # CompletionRequest + logprobs_param = first_request.logprobs + + params = { + "max_tokens": first_request.max_tokens, "logprobs": logprobs_param, + "echo": getattr(first_request, "echo", False), "stop": first_request.stop, + "temperature": first_request.temperature, "seed": first_request.seed, + "top_k": first_request.top_k, "top_p": first_request.top_p, + } + + all_prompts = [] + request_info_map = [] + for req_id, req in batched_items: + is_chat = isinstance(req, ChatCompletionRequest) + prompts_for_req = server_utils.get_prompts_for_request(req, llm) + all_prompts.extend(prompts_for_req) + request_info_map.append((req_id, req, is_chat, len(prompts_for_req))) + + broadcast_payload = {"prompts": all_prompts, "params": params} + payload_bytes = json.dumps(broadcast_payload).encode("utf-8") + payload_len = len(payload_bytes) + + if payload_len > MAX_REQUEST_SIZE: + logger.error(f"Batched request is too large ({payload_len} bytes > {MAX_REQUEST_SIZE})") + for req_id, _, _, _ in request_info_map: + with response_lock: + response_dict[req_id] = {"error": "Batched request payload is too large."} + return 0, b'', [] # Signal other ranks to skip + + return payload_len, payload_bytes, request_info_map + + +def _process_results(completions, request_info_map, payload): + """Processes completions and sends responses back to the waiting threads.""" + logger.info(f"Batched generation finished. 
Processing {len(completions)} completions.") + completion_idx = 0 + for req_id, req, is_chat, num_prompts in request_info_map: + completions_for_req = completions[completion_idx : completion_idx + num_prompts] + prompts_for_req = payload["prompts"][completion_idx : completion_idx + num_prompts] + completion_idx += num_prompts + + response = _create_response(req, completions_for_req, prompts_for_req, is_chat, llm) + with response_lock: + response_dict[req_id] = response + + +def main_loop(): + """The main processing loop with dynamic batching for all JAX processes.""" + while True: + payload_len, payload_bytes, request_info_map = 0, b'', [] + if jax.process_index() == 0: + batched_items = _collect_batched_requests() + if batched_items: + payload_len, payload_bytes, request_info_map = _prepare_batch_for_broadcast(batched_items) + + # Broadcast the payload to all ranks + data_to_broadcast = ( + jnp.array([payload_len], dtype=jnp.int32), + jnp.pad(jnp.frombuffer(payload_bytes, dtype=jnp.uint8), (0, MAX_REQUEST_SIZE - payload_len)), + ) + received_len_array, received_data_array = multihost_utils.broadcast_one_to_all(data_to_broadcast) + + received_len = int(received_len_array[0]) + if received_len == 0: + time.sleep(0.01) + continue + + broadcasted_data = received_data_array[:received_len].tobytes().decode("utf-8") + payload = json.loads(broadcasted_data) + + try: + if jax.process_index() == 0: + logger.info(f"Starting batched generation for {len(payload['prompts'])} prompts with params: {payload['params']}") + + completions = llm.generate_batch(prompts=payload["prompts"], **payload["params"]) + + if jax.process_index() == 0: + _process_results(completions, request_info_map, payload) + + except Exception as e: + logger.error(f"Inference failed for batch: {e}", exc_info=True) + if jax.process_index() == 0: + for req_id, _, _, _ in request_info_map: + with response_lock: + response_dict[req_id] = {"error": f"Inference failed: {e}"} + + +if __name__ == "__main__": + server_thread = None + server = None + + # The coordinator process (rank 0) runs the FastAPI server in a separate thread. + if jax.process_index() == 0: + # Define a Uvicorn-compatible logging config. + LOGGING_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "()": "uvicorn.logging.DefaultFormatter", + "fmt": f"%(levelprefix)s RANK {rank}: %(message)s", + "use_colors": None, + }, + "access": { + "()": "uvicorn.logging.AccessFormatter", + "fmt": f'%(levelprefix)s RANK {rank}: %(client_addr)s - "%(request_line)s" %(status_code)s', + }, + }, + "handlers": { + "default": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "access": { + "formatter": "access", + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + "": {"handlers": ["default"], "level": "INFO"}, + "uvicorn.error": {"level": "INFO"}, + "uvicorn.access": {"handlers": ["access"], "level": "INFO", "propagate": False}, + }, + } + + config = uvicorn.Config(app, host="0.0.0.0", port=8000, log_config=LOGGING_CONFIG) + server = uvicorn.Server(config) + server_thread = threading.Thread(target=server.run) + + logger.info(f"Starting Uvicorn server in a background thread on coordinator process {jax.process_index()}...") + server_thread.start() + + try: + # All processes (coordinator and workers) enter the main processing loop. 
+ logger.info(f"Process {jax.process_index()} is entering the main processing loop.") + main_loop() + except KeyboardInterrupt: + logger.info(f"Process {jax.process_index()} received KeyboardInterrupt. Shutting down.") + finally: + if jax.process_index() == 0 and server is not None and server_thread is not None: + logger.info("Stopping Uvicorn server...") + server.should_exit = True + server_thread.join() + logger.info("Uvicorn server stopped.") + + logger.info(f"Process {jax.process_index()} has exited.") diff --git a/benchmarks/api_server/port_forward_xpk.sh b/benchmarks/api_server/port_forward_xpk.sh new file mode 100644 index 0000000000..625440d5ed --- /dev/null +++ b/benchmarks/api_server/port_forward_xpk.sh @@ -0,0 +1,104 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +#!/bin/bash +# +# This script automates finding the correct pod in a MaxText server workload +# and establishes a port-forward connection to it. +# +# Usage: +# bash port_forward_xpk.sh job_name= project= zone= cluster= [namespace=] + +set -eu # Exit immediately if a command exits with a non-zero status or if an unset variable is used. + +# --- Argument Parsing --- +NAMESPACE="default" # Default namespace + +for arg in "$@" +do + case $arg in + job_name=*) + JOB_NAME="${arg#*=}" + # Shift removes the current argument from the list of positional parameters ($@). + shift + ;; + project=*) + PROJECT="${arg#*=}" + shift + ;; + zone=*) + ZONE="${arg#*=}" + shift + ;; + cluster=*) + CLUSTER="${arg#*=}" + shift + ;; + namespace=*) + NAMESPACE="${arg#*=}" + shift + ;; + esac +done + +# --- Validate Arguments --- +if [ -z "$JOB_NAME" ] || [ -z "$PROJECT" ] || [ -z "$ZONE" ] || [ -z "$CLUSTER" ]; then + echo "Usage: $0 job_name= project= zone= cluster= [namespace=]" >&2 + exit 1 +fi + +echo "--- Configuration ---" +echo "Project: $PROJECT" +echo "Zone: $ZONE" +echo "Cluster: $CLUSTER" +echo "Job Name: $JOB_NAME" +echo "Namespace: $NAMESPACE" +echo "---------------------" + +# --- Get GKE Credentials --- +echo "Fetching cluster credentials..." +gcloud container clusters get-credentials "$CLUSTER" --zone "$ZONE" --project "$PROJECT" > /dev/null + +# --- Find the Server Pod --- +echo "Searching for pods in namespace '$NAMESPACE' with label 'job-name=$JOB_NAME'..." +# Use a label selector for an efficient server-side lookup. +# Read the space-separated pod names safely into a bash array. +read -r -a PODS <<< "$(kubectl get pods -n "$NAMESPACE" -l "job-name=$JOB_NAME" -o jsonpath='{.items[*].metadata.name}')" + +if [ -z "$PODS" ]; then + echo "Error: No pods found for job name '$JOB_NAME' in namespace '$NAMESPACE'." + exit 1 +fi + +SERVER_POD="" +for pod in "${PODS[@]}"; do + echo "Checking logs for pod: $pod..." + # Use grep -q for a silent check. The command succeeds if the pattern is found. 
diff --git a/benchmarks/api_server/requirements.txt b/benchmarks/api_server/requirements.txt
new file mode 100644
index 0000000000..ac5d4046b4
--- /dev/null
+++ b/benchmarks/api_server/requirements.txt
@@ -0,0 +1,4 @@
+uvicorn
+fastapi
+openai-harmony
+pyyaml
\ No newline at end of file
diff --git a/benchmarks/api_server/server_models.py b/benchmarks/api_server/server_models.py
new file mode 100644
index 0000000000..cff25db031
--- /dev/null
+++ b/benchmarks/api_server/server_models.py
@@ -0,0 +1,206 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import time
+import uuid
+from typing import List, Optional, Union, Dict, TypeVar, Generic
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class SamplingParams(BaseModel):
+  """
+  Defines the common sampling parameters that are shared across different types of
+  generation requests, such as standard completions and chat-based completions.
+
+  Attributes:
+    max_tokens: The maximum number of tokens to generate.
+    temperature: The sampling temperature.
+    top_p: The nucleus sampling probability.
+    top_k: The top-k sampling integer.
+    stream: Whether to stream the response.
+    stop: A string or list of strings that will stop the generation.
+    seed: A seed for deterministic sampling.
+  """
+  max_tokens: Optional[int] = None
+  temperature: Optional[float] = None
+  top_p: Optional[float] = None
+  top_k: Optional[int] = None
+  stream: Optional[bool] = False
+  stop: Optional[Union[str, List[str]]] = None
+  seed: Optional[int] = None
+
+
+class CompletionRequest(SamplingParams):
+  """
+  Represents a request for a standard text completion, inheriting sampling
+  parameters from `SamplingParams`.
+
+  Attributes:
+    model: The ID of the model to use for the completion.
+    prompt: The prompt(s) to generate completions for, which can be a string,
+      a list of strings, a list of token IDs, or a list of lists of token IDs.
+    echo: Whether to echo the prompt back in the response.
+    logprobs: The number of top log probabilities to return for each token.
+  """
+  model: str
+  prompt: Union[str, List[str], List[int], List[List[int]]]
+  echo: Optional[bool] = False
+  logprobs: Optional[int] = None
+
+  @field_validator("logprobs")
+  def validate_logprobs(cls, v):
+    if v is not None and v < 0:
+      raise ValueError("logprobs must be a non-negative integer if provided.")
+    return v
+
+
+class LogProbsPayload(BaseModel):
+  """
+  A data structure to hold the log probability information for a sequence of tokens,
+  formatted to be compatible with OpenAI's API.
+
+  Attributes:
+    tokens: The string representation of each token.
+    token_logprobs: The log probability of each token.
+    top_logprobs: A list of dictionaries mapping other tokens to their log
+      probabilities at each position.
+    text_offset: The character offset of each token in the text.
+  """
+  tokens: List[str]
+  token_logprobs: List[Optional[float]]
+  top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
+  text_offset: List[int]
+
+
+class CompletionChoice(BaseModel):
+  """
+  Represents a single choice (a possible completion) in a `CompletionResponse`.
+
+  Attributes:
+    text: The generated text for this choice.
+    index: The index of this choice in the list of choices.
+    logprobs: An optional payload containing log probability information.
+    finish_reason: The reason the model stopped generating tokens (e.g., 'stop', 'length').
+  """
+  text: str
+  index: int
+  logprobs: Optional[LogProbsPayload] = None
+  finish_reason: str = "stop"
+
+
+class Usage(BaseModel):
+  """
+  Provides information about the number of tokens used in a request.
+
+  Attributes:
+    prompt_tokens: The number of tokens in the input prompt.
+    completion_tokens: The number of tokens in the generated completion.
+    total_tokens: The total number of tokens used.
+  """
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
+# Define a TypeVar for the choice models
+ChoiceType = TypeVar('ChoiceType')
+
+
+class BaseCompletionResponse(BaseModel, Generic[ChoiceType]):
+  """
+  A generic base response model using Python's Generic type. It shares all
+  common fields for API responses and uses a TypeVar for the 'choices' list
+  to accommodate different types of choices (e.g., for standard vs. chat completions).
+
+  Attributes:
+    id: A unique identifier for the response.
+    object: The type of the object (e.g., 'text_completion').
+    created: The timestamp when the response was created.
+    model: The model that generated the response.
+    choices: A list of choices, the type of which is determined by `ChoiceType`.
+    usage: Token usage statistics.
+  """
+  id: str
+  object: str
+  created: int = Field(default_factory=lambda: int(time.time()))
+  model: str
+  choices: List[ChoiceType]
+  usage: Usage
+
+
+class CompletionResponse(BaseCompletionResponse[CompletionChoice]):
+  """
+  The specific response object for a standard completion request. It inherits
+  from the generic base and specifies `CompletionChoice` as its choice type.
+  It also provides default values for the 'id' and 'object' fields.
+  """
+  id: str = Field(default_factory=lambda: f"cmpl-{uuid.uuid4().hex}")
+  object: str = "text_completion"
+
+
+class ChatMessage(BaseModel):
+  """
+  Represents a single message within a chat conversation.
+
+  Attributes:
+    role: The role of the message's author (e.g., 'user', 'assistant').
+    content: The text content of the message.
+  """
+  role: str
+  content: str
+
+
+class ChatCompletionRequest(SamplingParams):
+  """
+  Represents a request for a chat-based completion, where the input is a
+  sequence of messages. Inherits sampling parameters from `SamplingParams`.
+
+  Attributes:
+    model: The ID of the model to use.
+    messages: A list of `ChatMessage` objects representing the conversation history.
+    logprobs: Whether to return log probabilities.
+    top_logprobs: The number of top log probabilities to return if `logprobs` is true.
+  """
+  model: str
+  messages: List[ChatMessage]
+  logprobs: Optional[bool] = False
+  top_logprobs: Optional[int] = None
+
+
+class ChatCompletionChoice(BaseModel):
+  """
+  Represents a single choice in a `ChatCompletionResponse`.
+
+  Attributes:
+    index: The index of this choice.
+    message: The `ChatMessage` generated by the model.
+    finish_reason: The reason the model stopped generating.
+    logprobs: An optional payload with log probability information.
+  """
+  index: int
+  message: ChatMessage
+  finish_reason: str = "stop"
+  logprobs: Optional[LogProbsPayload] = None
+
+
+class ChatCompletionResponse(BaseCompletionResponse[ChatCompletionChoice]):
+  """
+  The specific response object for a chat completion request. It inherits from
+  the generic base, specifies `ChatCompletionChoice` as its choice type, and
+  provides chat-specific default values for 'id' and 'object'.
+  """
+  id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
+  object: str = "chat.completion"
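To make the request/response contract concrete, here is a small, self-contained sketch that exercises these Pydantic models directly, without starting the server. The field values are illustrative only; run it from the project root (or anywhere the `benchmarks.api_server` package is importable) with `pydantic` installed.

```python
# A minimal sketch of the request/response models in server_models.py.
# The payload values below are made up for illustration.
from benchmarks.api_server.server_models import (
    CompletionChoice,
    CompletionRequest,
    CompletionResponse,
    Usage,
)

# Parse an OpenAI-style completion request body; the logprobs validator
# rejects negative values.
request = CompletionRequest.model_validate({
    "model": "qwen3-30b-a3b",
    "prompt": "The capital of France is",
    "max_tokens": 16,
    "temperature": 0.0,
    "logprobs": 1,
})
print(request.prompt, request.max_tokens)

# Build the corresponding response; 'id', 'object', and 'created' are filled
# in by the default factories on CompletionResponse/BaseCompletionResponse.
response = CompletionResponse(
    model=request.model,
    choices=[CompletionChoice(text=" Paris.", index=0, finish_reason="stop")],
    usage=Usage(prompt_tokens=6, completion_tokens=2, total_tokens=8),
)
print(response.model_dump_json(indent=2))
```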
+ """ + if not DEBUG_MODE: + return + try: + log_entry = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "request_id": request_id, + "event": event_type, + "content": content, + } + with open(DEBUG_LOG_FILE, "a") as f: + f.write(json.dumps(log_entry) + "\n") + except Exception as e: + # Use logger for errors + logger.error(f"Error writing to debug log file '{DEBUG_LOG_FILE}': {e}") + + +# ---------------------------- +# Request/Response Helpers +# ---------------------------- + + +def decode_one_prompt(p: Union[str, List[int]], llm: MaxTextGenerator) -> str: + """ + Decodes a single prompt element, which can be a string or a list of token IDs. + + Args: + p: The prompt element to decode. + llm: The MaxTextGenerator instance, used for its tokenizer. + + Returns: + The decoded prompt as a string. + + Raises: + ValueError: If the prompt item has an unsupported type. + """ + if isinstance(p, str): + return p + if isinstance(p, list) and (len(p) == 0 or isinstance(p[0], int)): + try: + return llm.tokenizer.decode(p) + except Exception: + print("Return empty string on decoding error") + return "" + raise ValueError("Unsupported prompt item type") + + +def get_prompts_for_request(req: any, llm: MaxTextGenerator) -> List[str]: + """ + Extracts and formats a list of prompts from a request object. + + This function handles both standard `CompletionRequest` and `ChatCompletionRequest` + types, converting them into a unified list of string prompts that the model + can process. + + Args: + req: The request object. + llm: The MaxTextGenerator instance. + + Returns: + A list of string prompts. + """ + if hasattr(req, 'messages'): # ChatCompletionRequest + messages = [m.model_dump() for m in req.messages] + formatted_prompt = llm.tokenizer.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + return [formatted_prompt] + else: # CompletionRequest + return normalize_prompts(req.prompt, llm) + + +def normalize_prompts(prompt: Union[str, List[str], List[int], List[List[int]]], llm: MaxTextGenerator) -> List[str]: + """ + Normalizes the highly flexible 'prompt' field from an OpenAI-style request + into a simple list of strings. + + The 'prompt' field can be a single string, a list of strings, a list of + token IDs, or a list of lists of token IDs. This function handles all + these cases and returns a flat list of string prompts. + + Args: + prompt: The prompt data from the request. + llm: The MaxTextGenerator instance for decoding token IDs. + + Returns: + A list of normalized string prompts. + + Raises: + HTTPException: If the prompt type is not supported. + """ + if isinstance(prompt, str): + return [prompt] + if isinstance(prompt, list): + if len(prompt) == 0: + return [] + # Prompts can be a list of strings, a single list of ints, or a list of lists of ints. + first = prompt[0] + if isinstance(first, str): + return [str(x) for x in prompt] + if isinstance(first, int): + return [decode_one_prompt(prompt, llm)] + if isinstance(first, list): + return [decode_one_prompt(x, llm) for x in prompt] + raise HTTPException(status_code=400, detail="Unsupported prompt type for this API.") + + +def decode_token_id(token_id: int, llm: MaxTextGenerator) -> str: + """ + Decodes a single token ID into its string representation. + + Args: + token_id: The integer token ID to decode. + llm: The MaxTextGenerator instance. + + Returns: + The decoded string. 
+ """ + return llm.tokenizer.decode([int(token_id)]) + + +def finite_or_none(v: Optional[float]) -> Optional[float]: + """ + Returns the float if it's finite (i.e., not NaN or infinity), otherwise None. + + Args: + v: The float value to check. + + Returns: + The original float if it is finite, otherwise None. + """ + if v is None: + return None + f = float(v) + return f if math.isfinite(f) else None + + +def to_openai_logprobs(lp_obj: Any, llm: MaxTextGenerator, want_top: bool = True) -> Optional[LogProbsPayload]: + """ + Converts the internal logprobs object to the OpenAI-compatible format. + + Args: + lp_obj: The internal logprobs object from the generation result. + llm: The MaxTextGenerator instance for decoding tokens. + want_top: Whether to populate the `top_logprobs` field. + + Returns: + A `LogProbsPayload` object compatible with the OpenAI API, or None. + """ + if lp_obj is None: + return None + + token_strings = [decode_token_id(tid, llm) for tid in lp_obj.tokens] + token_logprobs = [finite_or_none(v) for v in lp_obj.token_logprobs] + text_offset = list(lp_obj.text_offset) + + # Ensure all lists are of the same length to avoid errors. + min_len = min(len(token_strings), len(token_logprobs), len(text_offset)) + token_strings = token_strings[:min_len] + token_logprobs = token_logprobs[:min_len] + text_offset = text_offset[:min_len] + + top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None + if want_top: + # The current implementation only returns the logprob of the single sampled token. + # This structure is a placeholder for a future feature where the model might + # return the logprobs of multiple top tokens at each step. + top_logprobs = [ + ({tok: lp} if lp is not None else None) + for tok, lp in zip(token_strings, token_logprobs) + ] + + return LogProbsPayload( + tokens=token_strings, + token_logprobs=token_logprobs, + top_logprobs=top_logprobs, + text_offset=text_offset, + ) + + +def count_tokens(s: str, llm: MaxTextGenerator) -> int: + """ + Counts the number of tokens in a string. + + Args: + s: The string to tokenize. + llm: The MaxTextGenerator instance. + + Returns: + The number of tokens in the string. + """ + try: + # Use the underlying tokenizer to avoid the jetstream wrapper's + # padding issues with single-token sequences. + ids = llm.tokenizer.tokenizer.encode(s, add_special_tokens=False) + return len(ids) + except Exception as e: + logger.warning(f"Could not count tokens for string '{s[:50]}...': {e}") + return 0 + + +def apply_stops_to_text_and_logprobs( + text: str, + logprobs_payload: Optional[LogProbsPayload], + stop: Optional[Union[str, List[str]]], +) -> tuple[str, Optional[LogProbsPayload], Optional[str]]: + """ + Truncates the generated text and corresponding logprobs at the first occurrence + of any of the specified stop sequences. + + Args: + text: The generated text. + logprobs_payload: The corresponding logprobs payload. + stop: The stop sequence(s) to search for. + + Returns: + A tuple containing the truncated text, the truncated logprobs payload, + and the reason for stopping ('stop' if a sequence was found, otherwise None). 
+ """ + if not stop: + return text, logprobs_payload, None + + stops = [stop] if isinstance(stop, str) else list(stop) + + # Find the earliest stop sequence + first_stop_index = -1 + for s in stops: + if not s: + continue + i = text.find(s) + if i != -1: + first_stop_index = i if first_stop_index == -1 else min(first_stop_index, i) + + if first_stop_index == -1: + return text, logprobs_payload, None + + # Truncate text + new_text = text[:first_stop_index] + + # Truncate logprobs payload if it exists + if logprobs_payload is not None: + truncate_at_index = bisect.bisect_left( + logprobs_payload.text_offset, first_stop_index + ) + + new_logprobs = LogProbsPayload( + tokens=logprobs_payload.tokens[:truncate_at_index], + token_logprobs=logprobs_payload.token_logprobs[:truncate_at_index], + top_logprobs=logprobs_payload.top_logprobs[:truncate_at_index] if logprobs_payload.top_logprobs is not None else None, + text_offset=logprobs_payload.text_offset[:truncate_at_index], + ) + return new_text, new_logprobs, "stop" + + return new_text, logprobs_payload, "stop" diff --git a/benchmarks/api_server/start_server.sh b/benchmarks/api_server/start_server.sh new file mode 100644 index 0000000000..7d9b9aa431 --- /dev/null +++ b/benchmarks/api_server/start_server.sh @@ -0,0 +1,60 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +#!/bin/bash +# This script starts the MaxText server from the project root, +# ensuring that Python can find the necessary modules regardless of +# where the script is invoked from. +# +# Example: +# bash benchmarks/api_server/start_server.sh \ +# MaxText/configs/base.yml \ +# model_name="qwen3-30b-a3b" \ +# tokenizer_path="Qwen/Qwen3-30B-A3B-Thinking-2507" \ +# load_parameters_path="" \ +# per_device_batch_size=4 \ +# ici_tensor_parallelism=4 \ +# max_prefill_predict_length=1024 \ +# max_target_length=2048 \ +# async_checkpointing=false \ +# scan_layers=false \ +# attention="dot_product" \ +# tokenizer_type="huggingface" \ +# return_log_prob=True +set -e + + +# Check if arguments were provided. +if [ -z "$1" ]; then + echo "Usage: $0 [arg1=value1 arg2=value2 ...]" >&2 + echo "Or: $0 " >&2 + exit 1 +fi + +# Get the absolute path of the directory where the script is located. +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +# The project root is two levels up from the script's directory. +PROJECT_ROOT=$(dirname $(dirname "$SCRIPT_DIR")) + +# Change to the project root directory. +cd "$PROJECT_ROOT" + +echo "Starting MaxText server on http://0.0.0.0:8000" +echo "Executing from project root: $(pwd)" +echo "Using arguments: $@" + +# Pass all script arguments directly to the python module. +python -u -m benchmarks.api_server.maxtext_server "$@"