diff --git a/pyproject.toml b/pyproject.toml index 960d2c1c..239cfe0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ dependencies = [ "tiktoken>=0.9.0", "openai>=1.65.4", "opencv-python>=4.11.0.86", + "tritonclient[all]>=2.51.0", + "ultralytics>=8.3.120", ] [tool.hatch.metadata] diff --git a/src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_r1.py b/src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_r1.py index 12923ce9..56e8c599 100644 --- a/src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_r1.py +++ b/src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_r1.py @@ -28,7 +28,7 @@ def generate_r1_messages(example): system_prompt = "REPLACED WITH TOOLS SYSTEM PROMPT" - choices_str = "These are the possible answers, you must choose one: " + choices_str = "Possible answers: " for i, choice in enumerate(choices): if i == len(choices) - 1: choices_str += f"or {choice}." @@ -36,11 +36,11 @@ def generate_r1_messages(example): choices_str += f"{choice}, " instruction = f""" - {question} + Question: {question} - {choices_str} + {choices_str} You must choose one to answer the question and place in tags. - You must inspect the input image and gather visual evidence. The image size is {image_size}. + You must inspect the input image to gather visual evidence. After you've collected evidence, combine that with your knowledge of the world to answer the question. You must consider all 4 possible answers when thinking through your reasoning. The image size is {image_size}. """ r1_messages = [ @@ -66,7 +66,7 @@ def generate_r1_messages(example): "content": [ { "type": "text", - "text": "\n I'll collect as much visual evidence as possible from the image. First, I'll consider what region of the image to zoom in on to get the most information. Then, I'll review and consider the four possible answers. Then, I'll select the most likely answer based on the evidence and my knowledge of the world.", + "text": "\n I'll collect as much visual evidence as possible from the image. Then, I'll consider the four possible answers. Finally, I'll select the most likely answer based on the evidence and my knowledge of the world.", } ], }, diff --git a/src/r1_vlm/environments/multistep_vision_env.py b/src/r1_vlm/environments/multistep_vision_env.py index 24dc3ec5..cb93a967 100644 --- a/src/r1_vlm/environments/multistep_vision_env.py +++ b/src/r1_vlm/environments/multistep_vision_env.py @@ -57,7 +57,7 @@ def is_completed(self, messages: list[dict[str, str]], **kwargs: Any) -> bool: @abstractmethod def env_response( self, messages: list[dict[str, str]], **kwargs: Any - ) -> list[dict[str, Any]]: + ) -> list[dict[str, Any]]: pass def prepare_data(self, *, inputs, processing_class): @@ -96,12 +96,12 @@ def update_state(j, vlm_response): state["prompt_ids"] = vlm_response.prompt_token_ids # update the conversation with the model's response - state["messages"].append({ - "role": "assistant", - "content": [ - {"text": vlm_response.outputs[0].text, "type": "text"} - ] - }) + state["messages"].append( + { + "role": "assistant", + "content": [{"text": vlm_response.outputs[0].text, "type": "text"}], + } + ) # get token lengths of env response and new completion total_prev_len = len(state["prompt_ids"]) + len(state["completion_ids"]) @@ -120,7 +120,7 @@ def update_state(j, vlm_response): ] # if we are done, we mark the state as completed - # we do not want to truncate the completion ids here, + # we do not want to truncate the completion ids here, # because the number of image tokens returned from the tools is variable if ( self.is_completed(state["messages"]) @@ -153,7 +153,7 @@ def update_state(j, vlm_response): for j, state in results: states[j] = state - + return states def generate( @@ -176,7 +176,7 @@ def generate( } for conversation in conversations ] - + # main loop while not all_completed: states = self.step(states, vlm, custom_sp) @@ -190,8 +190,7 @@ def generate( "messages": completion_messages, "mask": completion_mask, } - - + def clean_messages_for_logging(messages): cleaned = [] images = [] @@ -201,7 +200,10 @@ def clean_messages_for_logging(messages): cleaned_content = [] for item in cleaned_message["content"]: cleaned_item = item.copy() - if "image" in cleaned_item and cleaned_item["image"] is not None: + if ( + "image" in cleaned_item + and cleaned_item["image"] is not None + ): images.append(cleaned_item["image"]) cleaned_item["image"] = "" cleaned_content.append(cleaned_item) @@ -212,28 +214,30 @@ def clean_messages_for_logging(messages): cleaned_messages, images = clean_messages_for_logging(states[0]["messages"]) self.logger.info( - "Full conversation 0:\n" - + json.dumps(cleaned_messages, indent=4) + "Full conversation 0:\n" + json.dumps(cleaned_messages, indent=4) ) for image in images: imgcat.imgcat(image) - + return output - + @staticmethod - def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completions_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]: - ''' + def preprocess_messages( + prompts_messages: list[list[dict[str, Any]]], + completions_messages: list[list[dict[str, Any]]], + ) -> list[list[dict[str, Any]]]: + """ 1. Combines prompts and completion messages into full conversations 2. Removes all messages before the first assistant message, leaving only the completion 3. Merges elements of the completion that come from the same source and are text only - + Args: prompts: list of prompt conversations completions_messages: list of completion conversations - + Returns: list of preprocessed completion conversations - ''' + """ # Combine prompts and completions into full conversations combined_messages = [] for prompt_msgs, completion_msgs in zip(prompts_messages, completions_messages): @@ -241,22 +245,29 @@ def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completion conversation.extend(prompt_msgs) conversation.extend(completion_msgs) combined_messages.append(conversation) - + filtered_messages = [] for completion in combined_messages: # find the index of the first assistant message - assistant_message_index = next((i for i, message in enumerate(completion) if message["role"] == "assistant"), None) - + assistant_message_index = next( + ( + i + for i, message in enumerate(completion) + if message["role"] == "assistant" + ), + None, + ) + if assistant_message_index is not None: # keep only messages from the first assistant message onwards filtered_messages.append(completion[assistant_message_index:]) - + merged_completions = [] - + for completion in filtered_messages: merged_completion = [] current_message = None - + for message in completion: # If message has non-text content, add it as is if any(item["type"] != "text" for item in message["content"]): @@ -265,7 +276,7 @@ def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completion current_message = None merged_completion.append(message) continue - + # For text messages if current_message and current_message["role"] == message["role"]: # Merge text content @@ -277,11 +288,21 @@ def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completion merged_completion.append(current_message) current_message = { "role": message["role"], - "content": [{"type": "text", "text": message["content"][0]["text"]}] + "content": [ + {"type": "text", "text": message["content"][0]["text"]} + ], } - + if current_message: merged_completion.append(current_message) merged_completions.append(merged_completion) - - return merged_completions \ No newline at end of file + + return merged_completions + + def log_metrics(self, data): + """ + Callback for logging metrics. Can be implemented by subclasses. + + Should return a dictionary of metrics (key = metric name, value = metric value) + """ + return {} diff --git a/src/r1_vlm/environments/simple_vision_env.py b/src/r1_vlm/environments/simple_vision_env.py index 7e269bd6..5f1465a9 100644 --- a/src/r1_vlm/environments/simple_vision_env.py +++ b/src/r1_vlm/environments/simple_vision_env.py @@ -4,12 +4,12 @@ import imgcat from qwen_vl_utils import process_vision_info +from verifiers import SimpleEnv from vllm import LLM, SamplingParams # type: ignore from r1_vlm.budget_forcing.budget_forcing import ( generate_completions_with_budget_forcing, ) -from verifiers import SimpleEnv class SimpleVisionEnv(SimpleEnv): @@ -93,7 +93,7 @@ def generate( completions = vlm.generate( vlm_inputs, sampling_params=custom_sp, use_tqdm=False ) # type: ignore - + stop_reasons = [c.outputs[0].stop_reason for c in completions] print(f"Stop reasons: {stop_reasons}") @@ -166,6 +166,14 @@ def prepare_data(self, *, inputs, processing_class): return conversations, texts, batch, vllm_inputs + def log_metrics(self, data): + """ + Callback for logging metrics. Can be implemented by subclasses. + + Should return a dictionary of metrics (key = metric name, value = metric value) + """ + return {} + def prepare_inputs_for_env(*, inputs, processing_class): """ diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py b/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py index 635930e1..76ca7032 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py @@ -1,7 +1,9 @@ import json import os import re +from copy import deepcopy +from datasets import Dataset from imgcat import imgcat from tqdm import tqdm from transformers import AutoProcessor @@ -19,78 +21,78 @@ def extract_answer(generation: str): return None -def main(): - checkpoint = "/millcreek/home/sunil/r1_vlm/vlm-r1-zoom-only-reward-refactor-oversampling/checkpoint-700" - processor = AutoProcessor.from_pretrained(checkpoint, padding_side="left") - vf_env = AOKVQAToolEnv(processing_class=processor) - train_dataset, val_dataset, test_dataset = vf_env.get_dataset() - - if not os.path.exists("generations.json"): - vlm = LLM( - model=checkpoint, - gpu_memory_utilization=1.0, - dtype="bfloat16", - tensor_parallel_size=2, - enable_prefix_caching=True, - limit_mm_per_prompt={"image": 2, "video": 0}, - ) +def generate_completions( + checkpoint_path: str, file_path: str, dataset: Dataset, env, processor +): + """ + Generate completions given a checkpoint and a file path to save the generations + """ + if os.path.exists(file_path): + raise ValueError(f"File {file_path} already exists") + + vlm = LLM( + model=checkpoint_path, + gpu_memory_utilization=1.0, + dtype="bfloat16", + tensor_parallel_size=4, + enable_prefix_caching=True, + limit_mm_per_prompt={"image": 2, "video": 0}, + ) + + sampling_params = SamplingParams( + temperature=0.1, + max_tokens=2048, + ) + + batch_size = 24 + batches = [] + + for example in dataset: + if len(batches) == 0: + batches.append([example]) + elif len(batches[-1]) < batch_size: + batches[-1].append(example) + else: + batches.append([example]) - sampling_params = SamplingParams( - temperature=0.1, - max_tokens=2048, + generations = [] + for batch in tqdm(batches, desc="Generating completions"): + conversations, texts, processed_batch, vllm_inputs = env.prepare_data( + inputs=batch, processing_class=processor ) - batch_size = 6 - batches = [] - - for example in val_dataset: - if len(batches) == 0: - batches.append([example]) - elif len(batches[-1]) < batch_size: - batches[-1].append(example) - else: - batches.append([example]) - - generations = [] - for batch in tqdm(batches, desc="Generating completions"): - conversations, texts, processed_batch, vllm_inputs = vf_env.prepare_data( - inputs=batch, processing_class=processor - ) - - completion_ids = vf_env.generate( - conversations=conversations, - vlm_inputs=vllm_inputs, - vlm=vlm, - sampling_params=sampling_params, - ) + completion_ids = env.generate( + conversations=conversations, + vlm_inputs=vllm_inputs, + vlm=vlm, + sampling_params=sampling_params, + ) - generated_texts = processor.batch_decode( - completion_ids["ids"], - skip_special_tokens=False, - clean_up_tokenization_spaces=False, - ) + generated_texts = processor.batch_decode( + completion_ids["ids"], + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) - print(generated_texts) + for example, generation in zip(batch, generated_texts): + data = { + "question_id": example["question_id"], + "question": example["question"], + "options": example["choices"], + "rationales": example["rationales"], + "gt_answer": example["multiple_choice_answer"], + "generation": generation, + "model_answer": extract_answer(generation), + } + generations.append(data) - for example, generation in zip(batch, generated_texts): - data = { - "question_id": example["question_id"], - "question": example["question"], - "options": example["choices"], - "rationales": example["rationales"], - "gt_answer": example["multiple_choice_answer"], - "generation": generation, - "model_answer": extract_answer(generation), - } - generations.append(data) + with open(file_path, "w") as f: + json.dump(generations, f, indent=2) - # Save the generations list as a JSON array to a file - with open("generations.json", "w") as f: - json.dump(generations, f, indent=2) # Use indent for readability (optional) - else: - with open("generations.json", "r") as f: - generations = json.load(f) +def evaluate(generations_dict: dict, dataset: Dataset): + with open(generations_dict, "r") as f: + generations = json.load(f) generations_dict = {} for generation in generations: @@ -101,7 +103,8 @@ def main(): total = 0 correct = 0 in_option_set = 0 - for example in val_dataset: + + for example in dataset: question_id = example["question_id"] if question_id not in generations_dict: @@ -131,9 +134,53 @@ def main(): imgcat(example["image"]) print("--------------------------------") - print(f"Accuracy: {correct / total}") - print(f"In option set: {in_option_set / total}") + results = { + "accuracy": correct / total, + "in_option_set": in_option_set / total, + } + + print(f"Accuracy: {results['accuracy']}") + print(f"In option set: {results['in_option_set']}") + + return results if __name__ == "__main__": - main() + checkpoints_folder = "/millcreek/home/sunil/r1_vlm/vlm-r1-od-tool-fixed-reward-schedule-for-tools-apr-30" + + checkpoint_paths = [ + os.path.join(checkpoints_folder, f) + for f in os.listdir(checkpoints_folder) + if os.path.isdir(os.path.join(checkpoints_folder, f)) + ] + + checkpoints_to_eval = ["150", "350", "600"] + + checkpoint_paths = [ + path + for path in checkpoint_paths + if any(num in path for num in checkpoints_to_eval) + ] + + processor = AutoProcessor.from_pretrained(checkpoint_paths[0], padding_side="left") + env = AOKVQAToolEnv(processing_class=processor) + train_dataset, val_dataset, test_dataset = env.get_dataset() + + results_dict = {} + + # we'll save evaluations to the same folder as the checkpoints + for checkpoint_path in checkpoint_paths: + file_path = os.path.join( + checkpoints_folder, f"{checkpoint_path}_generations.json" + ) + if not os.path.exists(file_path): + generate_completions( + checkpoint_path, file_path, deepcopy(val_dataset), env, processor + ) + else: + print(f"Skipping {checkpoint_path} because it already exists") + + results = evaluate(file_path, deepcopy(val_dataset)) + results_dict[checkpoint_path] = results + + print(results_dict) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py index dee5a949..4cafc653 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py @@ -4,11 +4,11 @@ from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl from peft import LoraConfig, TaskType from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration - -from r1_vlm.environments.tool_use_aokvqa_env.tool_use_aokvqa_env import AOKVQAToolEnv from trl import GRPOConfig, ModelConfig from trl.trainer.qwen_grpo_trainer import QwenGRPOTrainer +from r1_vlm.environments.tool_use_aokvqa_env.tool_use_aokvqa_env import AOKVQAToolEnv + os.environ["WANDB_ENTITY"] = "groundlightai" os.environ["WANDB_PROJECT"] = "tool-use-aokvqa-env" @@ -92,8 +92,14 @@ def train(): load_model_and_processor(gradient_checkpointing=True, use_peft=False) ) print("loaded model") + num_generations = 6 - vf_env = AOKVQAToolEnv(processing_class=processor, max_steps=3) + vf_env = AOKVQAToolEnv( + processing_class=processor, + max_steps=3, + num_generations=num_generations, + use_combined_tool_correctness_reward=True, + ) print("loaded env") @@ -106,7 +112,7 @@ def train(): training_args = GRPOConfig( model_init_kwargs=model_config, # save path on the runpod instance - output_dir="vlm-r1-new-zoom-tool-reward-independent-oversampling", + output_dir="vlm-r1-new-fancy-tool-aligned-reward-may1", # increase learning rate for PEFT - 1e-4 learning_rate=1e-4 if peft_config is not None else 1e-6, max_grad_norm=1.0, @@ -115,8 +121,8 @@ def train(): warmup_steps=10, logging_steps=1, save_steps=50, - save_total_limit=5, - num_train_epochs=10, + save_total_limit=10, + num_train_epochs=1, per_device_train_batch_size=2, num_generations=6, gradient_accumulation_steps=4, diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index 67f46c95..44085222 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -1,18 +1,30 @@ +import re from typing import Any, Callable from datasets import Dataset from transformers import AutoProcessor +from trl.trainer.grpo_trainer import RewardFunc +from verifiers.parsers import XMLParser from r1_vlm.datasets.aok_vqa.aok_vqa_mc_tool_use_r1 import ( create_r1_aok_vqa_tool_use_dataset, ) from r1_vlm.datasets.utils import preprocess_r1_dataset from r1_vlm.environments.multistep_vision_env import MultistepVisionEnv +from r1_vlm.environments.reward_schedules import create_linear_decay_schedule from r1_vlm.environments.tool_vision_env import ToolArgParser, ToolVisionEnv +from r1_vlm.tools.object_detection import ( + ObjectDetectionTool, + detect_objects, + parse_detect_objects_args, + set_object_detection_tool, +) from r1_vlm.tools.tool_prompts import SINGLE_TOOL_PROMPT_TEMPLATE from r1_vlm.tools.zoom import parse_zoom_args, zoom -from trl.trainer.grpo_trainer import RewardFunc -from verifiers.parsers import XMLParser + +# This is a global variable that is used to store the object detection tool. It is accessed by the detect_objects function. +od_tool = ObjectDetectionTool() +set_object_detection_tool(od_tool) class AOKVQAToolEnv(ToolVisionEnv): @@ -21,11 +33,13 @@ def __init__( processing_class: AutoProcessor, dataset_name: str = "Groundlight/real-iad-toy-brick-tool-use-r1", tools_with_parsers: list[tuple[Callable, ToolArgParser]] = [ - # (detect_objects, parse_detect_objects_args), + (detect_objects, parse_detect_objects_args), (zoom, parse_zoom_args), ], max_steps: int = 3, tool_prompt_template: str = SINGLE_TOOL_PROMPT_TEMPLATE, + num_generations: int = 6, + use_combined_tool_correctness_reward: bool = False, ): super().__init__( processing_class=processing_class, @@ -41,6 +55,8 @@ def __init__( ("answer", ["answer"]), ("tool", ["tool"]), ] + self.num_generations = num_generations + self.use_combined_tool_correctness_reward = use_combined_tool_correctness_reward def parse(self, text: str, strip: bool = True): return self.parser.parse(text, strip=strip) @@ -103,12 +119,22 @@ def get_reward_weights(self) -> list[float]: schedule = 0.1 reward_weights.append(schedule) elif reward_function.__name__ == "tool_execution_reward_func": - schedule = 0.1 + # linearly decay from 1.0 to 0.0 over 200 global steps (200 gradient updates) + schedule = ( + create_linear_decay_schedule(1.0, 0.0, 200) + if not self.use_combined_tool_correctness_reward + # quick burst of reward for tool use at the beginning to teach the model to use tools + else create_linear_decay_schedule(1.0, 0.0, 50) + ) reward_weights.append(schedule) elif reward_function.__name__ == "correct_answer_reward_func": # consistent high reward for getting the answer right - schedule = 1.0 + schedule = 1.0 if not self.use_combined_tool_correctness_reward else 0.0 + reward_weights.append(schedule) + elif reward_function.__name__ == "combined_tool_correctness_reward_func": + # consistent high reward for getting the answer right + schedule = 1.0 if self.use_combined_tool_correctness_reward else 0.0 reward_weights.append(schedule) else: raise ValueError( @@ -220,6 +246,17 @@ def check_format(trajectory): return [check_format(m) for m in merged_completion_conversations] + def check_tool_use_attempt(conversation) -> bool: + """ + Returns True if the model attempts to use any tool. + """ + for i, message in enumerate(conversation): + if message["role"] == "assistant": + parsed = self.parser.parse(message["content"][0]["text"]) + if hasattr(parsed, "tool") and parsed.tool is not None: + return True + return False + def check_execution(conversation): """ Returns the ratio of successful tool executions to total attempts. @@ -291,13 +328,120 @@ def correct_answer_reward_func( return [1.0 if result else 0.0 for result in correctness_results] + def combined_tool_correctness_reward_func( + prompts, completions, completions_messages, **kwargs + ) -> list[float]: + """ + Reward function that checks if tools were executed successfully only if tool use is necessary to answer the question. + """ + if self.num_generations != len(prompts) or self.num_generations != len( + completions_messages + ): + raise ValueError( + f"Expected num_generations to be equal to the number of prompts and completions, but got num_generations={self.num_generations}, len(prompts)={len(prompts)}, len(completions_messages)={len(completions_messages)}" + ) + merged_completion_conversations = MultistepVisionEnv.preprocess_messages( + prompts_messages=prompts, completions_messages=completions_messages + ) + + # For each response sampled, check if the completion has the correct answer + correct_answers = kwargs["multiple_choice_answer"] + correctness_results: list[bool] = [ + check_correctness(conv, correct_answer) + for conv, correct_answer in zip( + merged_completion_conversations, correct_answers + ) + ] + # For each response sampled, check if the any tool use is correct + tool_use_correctness: list[bool] = [ + check_execution(conv) > 0.0 for conv in merged_completion_conversations + ] + # For each response sampled, check if the model attempts to use any tool + tool_use_attempts: list[bool] = [ + check_tool_use_attempt(conv) for conv in merged_completion_conversations + ] + + # For all responses sampled, check if there is a completion that has the correct answer successfully using a tool + correct_with_tool = any( + correctness_results[i] and tool_use_correctness[i] + for i in range(len(correctness_results)) + ) + # For all responses sampled, check if there is a completion that has the correct answer without successfully using a tool + correct_without_tool = any( + correctness_results[i] and not tool_use_correctness[i] + for i in range(len(correctness_results)) + ) + + if correct_without_tool: + # If the question is answerable without using a tool, the model will be penalized for using a tool + rewards = [ + 0.0 + if not correctness_results[i] + else (0.5 if tool_use_attempts[i] else 1.0) + for i in range(len(correctness_results)) + ] + elif correct_with_tool: + # The model is only rewarded if the tool use used correctly, AND the answer is correct + rewards = [ + 1.0 if correctness_results[i] and tool_use_correctness[i] else 0.0 + for i in range(len(correctness_results)) + ] + else: + # The model is not rewarded for any incorrect responses + rewards = [0.0 for _ in range(len(correctness_results))] + return rewards + return [ format_reward_func, tool_execution_reward_func, correct_answer_reward_func, + combined_tool_correctness_reward_func, ] + def log_metrics(self, conversations, completions_text, completion_messages): + # 1. compute how many completions attempt to use any tool + # 2. for each tool, compute how many completions attempt to use it + + completions_with_tool_use = 0 + completions_with_zoom_use = 0 + completions_with_detect_objects_use = 0 + + for completion in completions_text: + tool_use_regex = r"(.*?)" + zoom_use_string = "name: zoom" + detect_objects_use_string = "name: detect_objects" + + tool_matches = re.findall(tool_use_regex, completion, re.DOTALL) + if tool_matches: + completions_with_tool_use += 1 + for tool_content in tool_matches: + if zoom_use_string in tool_content: + completions_with_zoom_use += 1 + if detect_objects_use_string in tool_content: + completions_with_detect_objects_use += 1 + + print( + f"There are {len(completions_text)} completions, {completions_with_tool_use} of which attempt to use a tool, {completions_with_zoom_use} of which attempt to use zoom, and {completions_with_detect_objects_use} of which attempt to use detect_objects" + ) + + num_completions = len(completions_text) + tool_use_proportion = completions_with_tool_use / num_completions + zoom_use_proportion = completions_with_zoom_use / num_completions + detect_objects_use_proportion = ( + completions_with_detect_objects_use / num_completions + ) + + return { + "tool_use_proportion": tool_use_proportion, + "zoom_use_proportion": zoom_use_proportion, + "detect_objects_use_proportion": detect_objects_use_proportion, + } + if __name__ == "__main__": env = AOKVQAToolEnv(processing_class=None) train_dataset, val_dataset, test_dataset = env.get_dataset() + import ipdb + + ipdb.set_trace() + print("hi") diff --git a/src/r1_vlm/environments/tool_vision_env.py b/src/r1_vlm/environments/tool_vision_env.py index f309a2e4..34d3472f 100644 --- a/src/r1_vlm/environments/tool_vision_env.py +++ b/src/r1_vlm/environments/tool_vision_env.py @@ -1,4 +1,5 @@ import inspect +import random import traceback from typing import Any, Callable, Dict, List @@ -147,12 +148,8 @@ def __init__( # Schema inference still uses the tool function's signature/docstring self.tool_schemas.append(infer_schema_from_function(tool_func)) - # Format the system prompt with tool descriptions - tool_descriptions = format_tool_descriptions(self.tool_schemas) - formatted_prompt = tool_prompt_template.format( - tool_descriptions=tool_descriptions - ) - self.formatted_prompt = formatted_prompt + # Store the template for dynamic formatting later + self.tool_prompt_template = tool_prompt_template # Set the general parser (use internal default if none provided) self.general_parser = general_parser or self._general_parse_key_value @@ -188,11 +185,23 @@ def _inject_prompt(examples): if not messages or messages[0]["role"] != "system": raise ValueError("Expected first message to be a system message") + # Create a shuffled copy of tool schemas for this sample + shuffled_schemas = self.tool_schemas[:] # Create a copy + random.shuffle(shuffled_schemas) + + # Format tool descriptions with the shuffled order + tool_descriptions = format_tool_descriptions(shuffled_schemas) + + # Format the prompt template with the randomized descriptions + formatted_prompt = self.tool_prompt_template.format( + tool_descriptions=tool_descriptions + ) + # Replace the content of the system message with the formatted prompt messages[0]["content"] = [ { "type": "text", - "text": self.formatted_prompt, + "text": formatted_prompt, } ] diff --git a/src/r1_vlm/tools/object_detection.py b/src/r1_vlm/tools/object_detection.py index af521e8e..28c173fc 100644 --- a/src/r1_vlm/tools/object_detection.py +++ b/src/r1_vlm/tools/object_detection.py @@ -1,35 +1,142 @@ -import base64 -import io +import base64 # For encoding/decoding images +import io # For handling image bytes import json import os -import time # Import the time module +import time -# Add imports for numpy and cv2 -import cv2 -import numpy as np -import pytest -import requests +import requests # To make HTTP requests to the API server + +# Remove multiprocessing imports +# from multiprocessing import Pipe, Process +# Remove YOLO import from here +# from ultralytics import YOLO +# Add imports for numpy and cv2 (if still needed for other parts, unlikely now) from dotenv import load_dotenv -from imgcat import imgcat from PIL import Image from r1_vlm.environments.tool_vision_env import RawToolArgs, TypedToolArgs load_dotenv() -API_IP = str(os.getenv("API_IP")) -API_PORT = int(os.getenv("API_PORT")) +# --- Configuration for the Detection API Server --- +# Get the API server's URL from environment variables, default to localhost:8001 +DETECTION_API_HOST = os.getenv("DETECTION_API_HOST", "localhost") +DETECTION_API_PORT = int(os.getenv("DETECTION_API_PORT", 8001)) +DETECTION_API_URL = f"http://{DETECTION_API_HOST}:{DETECTION_API_PORT}/detect" +# --- End Configuration --- + +_object_detection_tool = None + + +class ObjectDetectionTool: + def __init__(self): + # Store the URL for the detection API server + self.api_url = DETECTION_API_URL + + def detect_objects(self, image: Image.Image) -> dict: + """Sends image to detection API server and returns results.""" + t_client_start = time.time() + annotated_image = None # Default + dets_string = "Error: Detection failed." # Default error message + + try: + # 1. Prepare Image for Sending + buffer = io.BytesIO() + # Save image to buffer in a common format like PNG + image.save(buffer, format="PNG") + img_bytes = buffer.getvalue() + img_base64 = base64.b64encode(img_bytes).decode("utf-8") + t_encoded = time.time() + + # 2. Prepare Request Payload + payload = {"image_base64": img_base64} + + # 3. Call the API Server + response = requests.post( + self.api_url, + json=payload, + headers={"Content-Type": "application/json"}, + timeout=60.0, # Set a reasonable timeout (e.g., 60 seconds) + ) + t_responded = time.time() + + # 4. Process Response + if response.status_code == 200: + try: + response_data = response.json() + dets_string = response_data.get( + "text_data", "Error: Missing text_data in response." + ) + image_data_base64 = response_data.get("image_data_base64") + + if image_data_base64: + try: + annotated_bytes = base64.b64decode(image_data_base64) + annotated_image = Image.open(io.BytesIO(annotated_bytes)) + except Exception as img_err: + raise ValueError( + f"Failed to decode/load annotated image from response: {img_err}" + ) + # Keep annotated_image as None + + except json.JSONDecodeError as json_err: + raise ValueError( + f"Failed to decode JSON response from API: {json_err}" + ) + + except Exception as proc_err: # Catch other errors processing response + raise ValueError( + f"Error processing successful API response: {proc_err}" + ) + + else: + # Handle HTTP errors + error_msg = f"Error from detection API: {response.status_code}" + try: + error_detail = response.json().get("detail", response.text) + error_msg += f" - {error_detail}" + except json.JSONDecodeError: + error_msg += f" - {response.text}" + raise ValueError(error_msg) + + except requests.exceptions.Timeout: + raise ValueError("Request to detection API timed out after 60s.") + except requests.exceptions.RequestException as req_err: + raise ValueError(f"Request to detection API failed: {req_err}") + + except Exception as e: + # Catch-all for other unexpected errors in the client logic + raise ValueError( + f"Unexpected error in detect_objects client: {e}", exc_info=True + ) + + t_client_end = time.time() + print( + f"detect_objects client timings (s): " + f"Encode: {t_encoded - t_client_start:.3f}, " + f"API Call: {t_responded - t_encoded:.3f}, " + f"Decode/Process: {t_client_end - t_responded:.3f}, " + f"Total: {t_client_end - t_client_start:.3f}" + ) + + return {"text_data": dets_string, "image_data": annotated_image} + + def __del__(self): + """Cleanup method - nothing persistent to clean up""" + pass + + +def set_object_detection_tool(tool: ObjectDetectionTool): + global _object_detection_tool + _object_detection_tool = tool -def detect_objects( - image_name: str, classes: list[str], **kwargs -) -> tuple[list[dict], Image.Image]: +def detect_objects(image_name: str, **kwargs) -> tuple[list[dict], Image.Image]: """ - Calls an open vocabulary object detection model on the image. Useful for localizing objects in an image or determining if an object is present. + Calls an object detection model on the image. Useful for localizing objects in an image or determining if an object is present. Args: image_name: str, the name of the image to detect objects in. Can only be called on the "input_image" image. - classes: list[str], the classes to detect. As the model is open vocabulary, your classes can be any object you want to detect in the image. Each class should contain an noun for best results. Returns: 1. A list of dictionaries, each containing the following keys: @@ -41,12 +148,6 @@ def detect_objects( name: detect_objects image_name: input_image - classes: ["car", "person", "train", "bus"] - - - name: detect_objects - image_name: input_image - classes: ["elephant", "white jeep", "tree", "water"] """ @@ -64,108 +165,33 @@ def detect_objects( f"Error: Image {image_name} is not the input_image. This tool can only be called on the input_image." ) - # construct the API request - # I decided to fix the confidence threshold at 0.10, as the model tends to set this value very high, which leads to a lot of false negatives - url = f"http://{API_IP}:{API_PORT}/detect?confidence={0.10}" - - # Convert PIL Image to bytes - img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format="JPEG") - img_byte_arr = img_byte_arr.getvalue() - - files = {"image": img_byte_arr} - data = {} - for c in classes: - data.setdefault("classes", []).append(c) - - # send the request - start_time = time.time() # Record start time - response = requests.post(url, files=files, data=data) - end_time = time.time() # Record end time - print(f"API call took {end_time - start_time:.2f} seconds") # Print duration - - if response.status_code == 200: - result = response.json() - else: - raise Exception( - f"Error: API request failed with status code {response.status_code}" + if _object_detection_tool is None: + raise RuntimeError( + "ObjectDetectionTool not initialized. Call set_object_detection_tool first." ) - detections = result["results"]["detections"] - - dets = [] - for detection in detections: - dets.append( - { - "bbox_2d": detection["bbox_2d"], - "label": detection["label"], - } - ) - - if len(dets) == 0: - dets_string = "No objects detected." - annotated_image = None - else: - dets_string = "" - for index, det in enumerate(dets): - dets_string += f"{index + 1}. {det}" - - if index < len(dets) - 1: - dets_string += "\n" - - # convert the annotated image(base64 encoded) to a PIL Image only if detections exist - annotated_image_data = base64.b64decode(result["annotated_image"]) - annotated_image_pil = Image.open(io.BytesIO(annotated_image_data)) - - # Convert PIL Image to NumPy array (OpenCV format) - # PIL images with mode 'RGB' are loaded as NumPy arrays with shape (H, W, 3) in RGB order. - # PIL images with mode 'RGBA' are loaded as NumPy arrays with shape (H, W, 4) in RGBA order. - annotated_image_np = np.array(annotated_image_pil) - - # Convert BGR(A) to RGB(A) using OpenCV if it's a color image - # Assuming the source API sent BGR/BGRA data, which np.array converted retaining channel order relative to PIL's interpretation. - # If PIL interpreted as RGB, the np array is RGB. If RGBA, the np array is RGBA. - # Since the *source* was BGR/BGRA, we convert the numpy array from BGR/BGRA to RGB/RGBA. - if annotated_image_np.ndim == 3 and annotated_image_np.shape[2] == 3: # RGB/BGR - annotated_image_np_rgb = cv2.cvtColor(annotated_image_np, cv2.COLOR_BGR2RGB) - elif ( - annotated_image_np.ndim == 3 and annotated_image_np.shape[2] == 4 - ): # RGBA/BGRA - annotated_image_np_rgb = cv2.cvtColor( - annotated_image_np, cv2.COLOR_BGRA2RGBA - ) - else: - # Grayscale or other formats, no conversion needed - annotated_image_np_rgb = annotated_image_np - - # Convert NumPy array back to PIL Image - annotated_image = Image.fromarray(annotated_image_np_rgb) - - # Return None for image_data if no detections were found - return {"text_data": dets_string, "image_data": annotated_image} + # Call the method which now calls the API + return _object_detection_tool.detect_objects(image) # Return type is now dict def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: """ - Parses raw string arguments for the detect_objects tool, focusing on type conversion. + Parses raw string arguments for the detect_objects tool. - Expects keys: 'name', 'image_name', 'classes'. - Converts 'classes' from a JSON string representing a list of strings. - Detailed validation of values (e.g., 'image_name' validity, 'classes' content) + Expects keys: 'name', 'image_name' + Detailed validation of values (e.g., 'image_name' validity) is deferred to the detect_objects function itself. Args: raw_args: Dictionary with string keys and string values from the general parser. Returns: - A dictionary containing the arguments with basic type conversions applied, - ready for the detect_objects function. Keys: 'image_name', 'classes'. + A dictionary containing the arguments. Keys: 'image_name'. Raises: - ValueError: If required keys are missing or basic type conversion fails - (e.g., 'classes' is not valid JSON). + ValueError: If required keys are missing or extra keys are present. """ - required_keys = {"name", "image_name", "classes"} + required_keys = {"name", "image_name"} actual_keys = set(raw_args.keys()) # 1. Check for Missing Keys @@ -182,27 +208,12 @@ def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: f"Error: Unexpected arguments for detect_objects tool: {', '.join(sorted(extra_keys))}" ) - # 3. Perform Basic Type Conversions + # 3. Prepare typed args (only image_name needed) typed_args: TypedToolArgs = {} try: # Keep image_name as string typed_args["image_name"] = raw_args["image_name"] - # Convert classes string using json.loads - classes_list = json.loads(raw_args["classes"]) - - # Basic type check - ensure it's a list, defer content check (list of strings) to tool - if not isinstance(classes_list, list): - raise ValueError( - f"Error: Invalid format for 'classes': Expected a JSON list, got type {type(classes_list).__name__}" - ) - - typed_args["classes"] = classes_list - - except json.JSONDecodeError: - raise ValueError( - f"Error: Invalid JSON format for 'classes': '{raw_args['classes']}'" - ) except ValueError as e: # Catch the list type error from above raise ValueError(f"Error: processing 'classes': {e}") @@ -211,46 +222,3 @@ def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: raise ValueError(f"Error: Missing key '{e}' during conversion phase.") return typed_args - - -@pytest.fixture -def sample_image_fixture(): - """Provides a simple dummy image for testing.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - img = Image.open(os.path.join(current_dir, "cars.jpeg")) - return {"input_image": img} - - -def test_basic_detection_integration(sample_image_fixture): - """Tests basic object detection call against the running API.""" - # Call the function under test - this will make a real HTTP request - # Using classes unlikely to be in a plain red image might be safer - # depending on the actual model behavior. Let's use "object". - try: - result = detect_objects( - image_name="input_image", - # there should be cars, but no dogs - classes=["car", "dog"], - images=sample_image_fixture, - ) - - assert isinstance(result, dict) - assert "text_data" in result - assert "image_data" in result - assert isinstance(result["text_data"], str) - assert isinstance(result["image_data"], Image.Image) - - # visualize the annotated image - annotated_image = result["image_data"] - imgcat(annotated_image) - - # visualize the text data - print(result["text_data"]) - - except requests.exceptions.ConnectionError as e: - pytest.fail( - f"API connection failed. Is the server running at http://{API_IP}:{API_PORT}? Error: {e}" - ) - except Exception as e: - # Catch other potential errors during the API call or processing - pytest.fail(f"An unexpected error occurred: {e}") diff --git a/src/r1_vlm/tools/tool_prompts.py b/src/r1_vlm/tools/tool_prompts.py index ea300bd1..c8178258 100644 --- a/src/r1_vlm/tools/tool_prompts.py +++ b/src/r1_vlm/tools/tool_prompts.py @@ -37,7 +37,7 @@ {tool_descriptions} For each step: -1. Start by thinking through your reasoning inside tags. Then either return your answer inside tags, or use a tool inside tags. You are not required to use a tool if you can answer the question without one. +1. Start by thinking through your reasoning inside tags. Then either return your answer inside tags, or use a tool inside tags. 2. If needed, use a tool by writing its arguments inside tags. Use one line for each argument in the format 'key: value'. The first line must be 'name: '. 3. You will see the tool's output inside tags. 4. Continue until you can give the final answer inside tags. @@ -45,4 +45,6 @@ Tools expect specific arguments. Follow the examples carefully for the required keys and expected value formats. Do not make up tools or arguments that aren't listed. If the tool includes the argument "image_name", you must provide it the name of an image from this conversation. + +As a reminder, you are not required to use a tool if you can answer the user's question without one. """ diff --git a/tool_server/infer.py b/tool_server/infer.py index 45aff7a2..aa3c7264 100644 --- a/tool_server/infer.py +++ b/tool_server/infer.py @@ -1,74 +1,47 @@ import contextlib +import os import time import numpy as np +from dotenv import load_dotenv from imgcat import imgcat from PIL import Image from tritonclient.http import InferenceServerClient from ultralytics import YOLO +load_dotenv() + +API_IP = str(os.getenv("API_IP")) +API_PORT = int(os.getenv("API_PORT")) +url = f"{API_IP}:{API_PORT}/yolo" +print(url) + # Wait for the Triton server to start -triton_client = InferenceServerClient(url="localhost:8000", verbose=False, ssl=False) +triton_client = InferenceServerClient(url=url, verbose=False, ssl=False) # Wait until model is ready for _ in range(10): with contextlib.suppress(Exception): + print("checking if model is ready") assert triton_client.is_model_ready("yolo") + print("model is ready") break time.sleep(1) - +print("loading model") # Load the Triton Server model -model = YOLO("http://localhost:8000/yolo", task="detect") +model = YOLO(f"http://{url}", task="detect") # load the image via PIL img = Image.open( "/millcreek/home/sunil/r1_vlm_bumbershoot0/r1_vlm/tool_server/cars.jpeg" ) -# create 10 noisy copies and their crops -test_images = [] -crop_ratios = [(2, 1), (1, 1), (1, 2)] - -for i in range(10): - # Create noisy image - arr = np.array(img) - noise = np.random.normal(0, 5, arr.shape) - noisy_arr = np.clip(arr + noise, 0, 255).astype(np.uint8) - noisy_img = Image.fromarray(noisy_arr) - - # Create crops for this noisy image - img_w, img_h = noisy_img.size - for w_ratio, h_ratio in crop_ratios: - if img_w / img_h > w_ratio / h_ratio: - crop_h = img_h - crop_w = int(crop_h * w_ratio / h_ratio) - else: - crop_w = img_w - crop_h = int(crop_w * h_ratio / w_ratio) - - x0 = np.random.randint(0, img_w - crop_w + 1) - y0 = np.random.randint(0, img_h - crop_h + 1) - cropped = noisy_img.crop((x0, y0, x0 + crop_w, y0 + crop_h)) - test_images.append( - {"image": cropped, "ratio": f"{w_ratio}:{h_ratio}", "noise_id": i} - ) -speeds = [] - -# run inference on each variant -for test_case in test_images: - start = time.time() - results = model(test_case["image"]) # Pass the cropped image - end = time.time() - speeds.append(end - start) - print( - f"Noise #{test_case['noise_id']}, Aspect ratio {test_case['ratio']} – time taken: {end - start} seconds" - ) - # Convert the cropped image to numpy for visualization - vis_img = np.array(test_case["image"]) - # Plot directly on the cropped image - plotted = results[0].plot(img=vis_img) - imgcat(Image.fromarray(plotted)) +results = model(img) # Pass the cropped image -print(speeds) +# Convert the cropped image to numpy for visualization +vis_img = np.array(img) +# Plot directly on the cropped image +plotted = results[0].plot(img=vis_img) +imgcat(Image.fromarray(plotted)) diff --git a/tool_server/pyproject.toml b/tool_server/pyproject.toml index b4241c49..a2bb6f18 100644 --- a/tool_server/pyproject.toml +++ b/tool_server/pyproject.toml @@ -18,6 +18,9 @@ dependencies = [ "ipdb>=0.13.13", "numpy>=1.26.4", "pillow>=11.2.1", + "python-dotenv>=1.1.0", + "uvicorn>=0.34.2", + "fastapi>=0.115.12", ] [tool.uv.sources] diff --git a/tool_server/start_server.py b/tool_server/start_server.py index 1a7ebc04..6134e0de 100644 --- a/tool_server/start_server.py +++ b/tool_server/start_server.py @@ -20,7 +20,7 @@ container_id = ( subprocess.check_output( # Use the absolute path here - f"docker run -d --gpus 0 -v {absolute_triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models", + f"docker run -d --gpus 0 -v {absolute_triton_repo_path}:/models -p 0.0.0.0:8000:8000 {tag} tritonserver --model-repository=/models", shell=True, ) .decode("utf-8") diff --git a/tool_server/training_server.py b/tool_server/training_server.py new file mode 100644 index 00000000..d935d5b3 --- /dev/null +++ b/tool_server/training_server.py @@ -0,0 +1,211 @@ +import base64 +import io +import logging +import os +import time +from typing import Optional + +# --- Force CPU Usage for this Server --- +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +logger = logging.getLogger(__name__) # Get logger early for info message +logger.info("CUDA_VISIBLE_DEVICES set to -1. Server will attempt to use CPU for YOLO.") +# --- End Force CPU --- + +import uvicorn +from dotenv import load_dotenv +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse +from PIL import Image +from pydantic import BaseModel +from ultralytics import YOLO + +# --- Configuration & Initialization --- + +# Load environment variables (pointing to the *actual* YOLO/Triton backend) +load_dotenv() +API_IP = str(os.getenv("API_IP")) +API_PORT = int(os.getenv("API_PORT")) +BACKEND_URL = f"http://{API_IP}:{API_PORT}/yolo" + +# Set up logging +logging.basicConfig(level=logging.INFO) + +# Global variable to hold the loaded YOLO model +yolo_model: Optional[YOLO] = None + +# --- Pydantic Models --- + + +class DetectionRequest(BaseModel): + """Request body for the /detect endpoint.""" + + image_base64: str + + +class DetectionResponse(BaseModel): + """Successful response body for the /detect endpoint.""" + + text_data: str + image_data_base64: Optional[str] = None + + +class ErrorResponse(BaseModel): + """Error response body.""" + + error: str + + +# --- FastAPI App --- + +app = FastAPI(title="YOLO Detection API Server") + + +@app.on_event("startup") +async def startup_event(): + """Load the YOLO model on server startup (will use CPU).""" + global yolo_model + logger.info( + f"Attempting to load YOLO model targeting backend: {BACKEND_URL} (Forcing CPU)" + ) # Added CPU note + start_time = time.time() + try: + # Model will initialize on CPU due to CUDA_VISIBLE_DEVICES=-1 + yolo_model = YOLO(BACKEND_URL, task="detect") + # Perform a dummy inference to ensure connection/initialization on CPU + dummy_img = Image.new("RGB", (64, 64), color="red") + _ = yolo_model(dummy_img, verbose=False) + end_time = time.time() + logger.info( + f"YOLO model loaded successfully on CPU in {end_time - start_time:.2f} seconds." # Added CPU note + ) + except Exception as e: + logger.error(f"Failed to load YOLO model on startup: {e}", exc_info=True) + yolo_model = None + + +@app.post( + "/detect", + response_model=DetectionResponse, + responses={500: {"model": ErrorResponse}}, + summary="Perform object detection on an image", +) +async def detect_objects_api(request: DetectionRequest): + """ + Accepts a base64 encoded image, performs YOLO detection using the + pre-loaded model targeting the backend service, and returns results. + """ + global yolo_model + if yolo_model is None: + logger.error("YOLO model is not loaded. Cannot process request.") + raise HTTPException( + status_code=503, detail="Model service unavailable" + ) # 503 Service Unavailable + + logger.info("Received detection request.") + t_start = time.time() + + try: + # 1. Decode Base64 Image + try: + image_bytes = base64.b64decode(request.image_base64) + image = Image.open(io.BytesIO(image_bytes)).convert("RGB") # Ensure RGB + except Exception as e: + logger.error(f"Failed to decode/load image from base64: {e}") + raise HTTPException(status_code=400, detail=f"Invalid image data: {e}") + t_decoded = time.time() + + # 2. Run YOLO Inference + try: + # Use the globally loaded model + result = yolo_model(image, conf=0.3, verbose=False)[ + 0 + ] # verbose=False is quieter + except Exception as e: + logger.error(f"YOLO inference failed: {e}", exc_info=True) + # Consider more specific error codes if YOLO/Triton provide them + raise HTTPException(status_code=500, detail=f"Inference failed: {e}") + t_inferred = time.time() + + # 3. Process Results + boxes = [[int(round(x)) for x in box] for box in result.boxes.xyxy.tolist()] + labels = [result.names[int(cls)] for cls in result.boxes.cls.tolist()] + detections = [ + {"bbox_2d": box, "label": label} for box, label in zip(boxes, labels) + ] + + # 4. Format Output String + if not detections: + dets_string = "No objects detected." + annotated_image = None + plot_img_array = None + else: + dets_string = "" + for index, det in enumerate(detections): + dets_string += f"{index + 1}. {det}" + if index < len(detections) - 1: + dets_string += "\n" + # Generate annotated image array + plot_img_array = result.plot(conf=False, labels=True) + annotated_image = Image.fromarray(plot_img_array[..., ::-1]) # BGR->RGB + + t_processed = time.time() + + # 5. Encode Annotated Image (if any) + image_data_base64: Optional[str] = None + if annotated_image: + try: + with io.BytesIO() as buffer: + # Save as PNG (generally lossless) + annotated_image.save(buffer, format="PNG") + img_bytes = buffer.getvalue() + image_data_base64 = base64.b64encode(img_bytes).decode("utf-8") + except Exception as e: + logger.error(f"Failed to encode annotated image: {e}") + # Proceed without annotated image if encoding fails + t_encoded = time.time() + + logger.info( + f"Detection successful. Timings (s): " + f"Decode: {t_decoded - t_start:.3f}, " + f"Inference: {t_inferred - t_decoded:.3f}, " + f"Process: {t_processed - t_inferred:.3f}, " + f"Encode: {t_encoded - t_processed:.3f}, " + f"Total: {t_encoded - t_start:.3f}" + ) + + return DetectionResponse( + text_data=dets_string, + image_data_base64=image_data_base64, + ) + + except HTTPException as http_exc: + # Re-raise HTTPExceptions (like 400 Bad Request) + raise http_exc + except Exception as e: + # Catch-all for unexpected server errors during processing + logger.error(f"Unexpected error during detection request: {e}", exc_info=True) + # Return a generic 500 error response + return JSONResponse( + status_code=500, content={"error": f"Internal server error: {e}"} + ) + + +# --- Run Server --- + +if __name__ == "__main__": + # Set default port if not specified in environment + server_port = int(os.getenv("DETECTION_API_PORT", 8001)) + num_workers = int(os.getenv("DETECTION_API_WORKERS", 6)) # Default to 6 workers + logger.info( + f"Starting YOLO detection server on port {server_port} with {num_workers} workers (CPU forced)..." # Added CPU note + ) + # Note: Using uvicorn.run() with workers > 1 might have limitations + # compared to running via the command line with a process manager like gunicorn. + # See Uvicorn documentation for details on multi-process modes. + uvicorn.run( + "training_server:app", # Need to specify app string for reload/workers + host="0.0.0.0", + port=server_port, + workers=num_workers, + # reload=False # Ensure reload is False when using workers programmatically + ) diff --git a/tool_server/uv.lock b/tool_server/uv.lock index a6125dcf..f440049b 100644 --- a/tool_server/uv.lock +++ b/tool_server/uv.lock @@ -62,6 +62,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5", size = 7597 }, ] +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, +] + +[[package]] +name = "anyio" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/95/7d/4c1bd541d4dffa1b52bd83fb8527089e097a106fc90b467a7313b105f840/anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028", size = 190949 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916 }, +] + [[package]] name = "asttokens" version = "3.0.0" @@ -150,6 +173,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 }, ] +[[package]] +name = "click" +version = "8.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + [[package]] name = "clip" version = "1.0" @@ -255,6 +290,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, ] +[[package]] +name = "fastapi" +version = "0.115.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/55/ae499352d82338331ca1e28c7f4a63bfd09479b16395dce38cf50a39e2c2/fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681", size = 295236 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/b3/b51f09c2ba432a576fe63758bddc81f78f0c6309d9e5c10d194313bf021e/fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d", size = 95164 }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -420,6 +469,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/36/0c03e2d80db69e2472cf81c6123aa7d14741de7cf790117291a703ae6ae1/grpcio-1.67.1-cp312-cp312-win_amd64.whl", hash = "sha256:5721e66a594a6c4204458004852719b38f3d5522082be9061d6510b455c90afc", size = 4346574 }, ] +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515 }, +] + [[package]] name = "huggingface-hub" version = "0.30.2" @@ -1079,6 +1137,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552 }, ] +[[package]] +name = "pydantic" +version = "2.11.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/ab/5250d56ad03884ab5efd07f734203943c8a8ab40d551e208af81d0257bf2/pydantic-2.11.4.tar.gz", hash = "sha256:32738d19d63a226a52eed76645a98ee07c1f410ee41d93b4afbfa85ed8111c2d", size = 786540 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/12/46b65f3534d099349e38ef6ec98b1a5a81f42536d17e0ba382c28c67ba67/pydantic-2.11.4-py3-none-any.whl", hash = "sha256:d9615eaa9ac5a063471da949c8fc16376a84afb5024688b3ff885693506764eb", size = 443900 }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000 }, + { url = "https://files.pythonhosted.org/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996 }, + { url = "https://files.pythonhosted.org/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957 }, + { url = "https://files.pythonhosted.org/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199 }, + { url = "https://files.pythonhosted.org/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296 }, + { url = "https://files.pythonhosted.org/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109 }, + { url = "https://files.pythonhosted.org/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028 }, + { url = "https://files.pythonhosted.org/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044 }, + { url = "https://files.pythonhosted.org/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881 }, + { url = "https://files.pythonhosted.org/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034 }, + { url = "https://files.pythonhosted.org/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187 }, + { url = "https://files.pythonhosted.org/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628 }, + { url = "https://files.pythonhosted.org/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866 }, + { url = "https://files.pythonhosted.org/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894 }, +] + [[package]] name = "pygments" version = "2.19.1" @@ -1118,6 +1216,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, ] +[[package]] +name = "python-dotenv" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/88/2c/7bb1416c5620485aa793f2de31d3df393d3686aa8a8506d11e10e13c5baf/python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5", size = 39920 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/18/98a99ad95133c6a6e2005fe89faedf294a748bd5dc803008059409ac9b1e/python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d", size = 20256 }, +] + [[package]] name = "python-rapidjson" version = "1.20" @@ -1286,6 +1393,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, ] +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, +] + [[package]] name = "stack-data" version = "0.6.3" @@ -1300,6 +1416,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 }, ] +[[package]] +name = "starlette" +version = "0.46.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037 }, +] + [[package]] name = "sympy" version = "1.13.1" @@ -1372,6 +1500,7 @@ version = "0.1.0" source = { virtual = "." } dependencies = [ { name = "clip" }, + { name = "fastapi" }, { name = "imgcat" }, { name = "ipdb" }, { name = "mobileclip" }, @@ -1381,14 +1510,17 @@ dependencies = [ { name = "onnxruntime-gpu" }, { name = "onnxslim" }, { name = "pillow" }, + { name = "python-dotenv" }, { name = "tensorrt" }, { name = "tritonclient", extra = ["all"] }, { name = "ultralytics" }, + { name = "uvicorn" }, ] [package.metadata] requires-dist = [ { name = "clip", git = "https://github.com/ultralytics/CLIP.git" }, + { name = "fastapi", specifier = ">=0.115.12" }, { name = "imgcat", specifier = ">=0.6.0" }, { name = "ipdb", specifier = ">=0.13.13" }, { name = "mobileclip", git = "https://github.com/THU-MIG/yoloe.git?subdirectory=third_party%2Fml-mobileclip" }, @@ -1398,9 +1530,11 @@ requires-dist = [ { name = "onnxruntime-gpu", specifier = ">=1.21.1" }, { name = "onnxslim", specifier = ">=0.1.50" }, { name = "pillow", specifier = ">=11.2.1" }, + { name = "python-dotenv", specifier = ">=1.1.0" }, { name = "tensorrt", specifier = ">=10.9.0.34" }, { name = "tritonclient", extras = ["all"], specifier = ">=2.56.0" }, { name = "ultralytics", specifier = ">=8.3.112" }, + { name = "uvicorn", specifier = ">=0.34.2" }, ] [[package]] @@ -1519,6 +1653,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/54/b1ae86c0973cc6f0210b53d508ca3641fb6d0c56823f288d108bc7ab3cc8/typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c", size = 45806 }, ] +[[package]] +name = "typing-inspection" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/5c/e6082df02e215b846b4b8c0b887a64d7d08ffaba30605502639d44c06b82/typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122", size = 76222 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 }, +] + [[package]] name = "tzdata" version = "2025.2" @@ -1576,6 +1722,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/11/cc635220681e93a0183390e26485430ca2c7b5f9d33b15c74c2861cb8091/urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813", size = 128680 }, ] +[[package]] +name = "uvicorn" +version = "0.34.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a6/ae/9bbb19b9e1c450cf9ecaef06463e40234d98d95bf572fab11b4f19ae5ded/uvicorn-0.34.2.tar.gz", hash = "sha256:0e929828f6186353a80b58ea719861d2629d766293b6d19baf086ba31d4f3328", size = 76815 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/4b/4cef6ce21a2aaca9d852a6e84ef4f135d99fcd74fa75105e2fc0c8308acd/uvicorn-0.34.2-py3-none-any.whl", hash = "sha256:deb49af569084536d269fe0a6d67e3754f104cf03aba7c11c40f01aadf33c403", size = 62483 }, +] + [[package]] name = "wcwidth" version = "0.2.13" diff --git a/uv.lock b/uv.lock index e925b579..8adcfdb0 100644 --- a/uv.lock +++ b/uv.lock @@ -790,6 +790,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/cf/1f7649b8b9a3543e042d3f348e398a061923ac05b507f3f4d95f11938aa9/cryptography-44.0.2-cp39-abi3-win_amd64.whl", hash = "sha256:5f6f90b72d8ccadb9c6e311c775c8305381db88374c65fa1a68250aa8a9cb3a6", size = 3210957 }, ] +[[package]] +name = "cuda-bindings" +version = "12.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywin32", marker = "sys_platform == 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/16/621f2ff6e4c6a0c1d57f5a0a373d1fb9d10eb9a7f05052cc64eba2e7dab2/cuda_bindings-12.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0865c9b75ee8f0535044c3f0f06ca34a37131192b573ab59e20a9e058da1ead4", size = 10904424 }, + { url = "https://files.pythonhosted.org/packages/59/11/aee1afd60a5d6af67994dd88697912be22366a6e548e52e6cd2defdbe678/cuda_bindings-12.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0e6a889c87238e6cd55e9b25ce4fd1d90fe2d4169982860fed5f0bc3230795e", size = 11235285 }, + { url = "https://files.pythonhosted.org/packages/c1/c7/eedad18aeb461e9a3c1f8e2ea856caa50202a572b024912cb561f847a054/cuda_bindings-12.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:d0123d841cb3053d227e18b08ea7680d0b5ca64fab4664a2b80b7c83c8edf1ee", size = 11224401 }, + { url = "https://files.pythonhosted.org/packages/4e/82/dc34a092d9111524eea70671d41d72dd3a5452ef70c424680bee1daf9c45/cuda_bindings-12.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e264ea93027c7448b9efa134729c12217ca9096051114ee7a9425d49b5a14222", size = 10722116 }, + { url = "https://files.pythonhosted.org/packages/78/f2/b5c3f07f743e74c1f5c42bb2fc6e735f3adac8b526f60ef731d861663dd9/cuda_bindings-12.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:099f27e79e754346fa51517168787cda395fb437b31fbf20771c002f30adc0c9", size = 11039795 }, + { url = "https://files.pythonhosted.org/packages/d5/89/d1f3c70651cdeb7c276c0503aea34c1d0c22f8bc66de73887f5ce40c600a/cuda_bindings-12.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:630290148879b47f5e34629ee15061414caaf2f73ea284175a73b30427ad94fd", size = 11190771 }, +] + +[[package]] +name = "cuda-python" +version = "12.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-bindings" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/2c/02bb311b996ffb91d05f8c1fb79131bf50855f7410dd33d09f800fe78c58/cuda_python-12.8.0-py3-none-any.whl", hash = "sha256:3fca3a03c247d6aa1c414989dfe0dd21e9500307b8573f72216ed57d99344c5a", size = 11930 }, +] + [[package]] name = "cupy-cuda12x" version = "13.4.0" @@ -1232,6 +1259,49 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "gevent" +version = "25.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "platform_python_implementation == 'CPython' and sys_platform == 'win32'" }, + { name = "greenlet", marker = "platform_python_implementation == 'CPython'" }, + { name = "zope-event" }, + { name = "zope-interface" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/e5/a2d9c2d5bfb575973bca7733b23e7f8649f1079c18140a8680a551f3963e/gevent-25.4.2.tar.gz", hash = "sha256:7ffba461458ed28a85a01285ea0e0dc14f883204d17ce5ed82fa839a9d620028", size = 6342241 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/67/3c9a560d3b64510dc053714375b3d9f2c3d98192dc85b78a6e6f8b9a284b/gevent-25.4.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:5940174c7d1ffc7bb4b0ea9f2908f4f361eb03ada9e145d3590b8df1e61c379b", size = 2969979 }, + { url = "https://files.pythonhosted.org/packages/39/ee/594a40e09d9d56b76a04265ea37b825ec8e7b98cd41e8012eda413f233e6/gevent-25.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7ae7ad4ff9c4492d4b633702e35153509b07dc6ffd20f1577076d7647c9caba", size = 1805780 }, + { url = "https://files.pythonhosted.org/packages/d6/87/0707bfae4cc3728eb8d5fc29018b5ac3e0e1f8efca237d267d1d3abc7153/gevent-25.4.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d68fdf9bff0068367126983d7d85765124c292b4bc3d4d19ed8138335d8426a7", size = 1885718 }, + { url = "https://files.pythonhosted.org/packages/09/c6/4f35473d46ca8cfbffeee5e6f89ac29370280b3f34682ed8f0fea907f987/gevent-25.4.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ff92408011d78e4ffe297331ff30cded39a3e22845ba237516c646f6a485a241", size = 1845102 }, + { url = "https://files.pythonhosted.org/packages/7a/9b/d2269957be2867802d10bcb28e17eba64783067057d55e91e57207294c05/gevent-25.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a7c70ab6d33dfeb43bfe982c636609d8f90506dacaaa1f409a3c43c66d578fb1", size = 2084973 }, + { url = "https://files.pythonhosted.org/packages/6b/59/9a069d16d8b6b7ef82b0d241de9041b1341c9f132fbd096b80d6d1bc2345/gevent-25.4.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8e740bc08ba4c34951f4bb6351dbe04209416e12d620691fb57e115b218a7818", size = 1822891 }, + { url = "https://files.pythonhosted.org/packages/96/0d/815808f04cef2410a93521814e51de7554874012fc49c5ca7197f86ac340/gevent-25.4.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c535d96ded6e26b37fadda9242a49fea6308754da5945173940614b7520c07b4", size = 2115665 }, + { url = "https://files.pythonhosted.org/packages/42/b4/15e5f9c06d50843c0e7c87d580acc2ac4e47fef0195c2d3f73c3bd54e3f0/gevent-25.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:c62bf14557d2cb54f5e3c1ba0a3b3f4b69bf0441081c32d63b205763b495b251", size = 1679652 }, + { url = "https://files.pythonhosted.org/packages/7d/1d/195936c1e0c5b1dc89a8b534c05d080d24d760f6913632cbb13d9430c907/gevent-25.4.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:f735f57bc19d0f8bbc784093cfb7953a9ad66612b05c3ff876ec7951a96d7edd", size = 2996686 }, + { url = "https://files.pythonhosted.org/packages/52/2a/a82de55db10ca17e210a61548a421d65d144045a62958d172537d4ea6f26/gevent-25.4.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63aecf1e43b8d01086ea574ed05f7272ed40c48dd41fa3d061e3c5ca900abcdd", size = 1809379 }, + { url = "https://files.pythonhosted.org/packages/77/73/3508d539c96e435d883aa07c67ad5859505af33346795c8c575501d3ebda/gevent-25.4.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f12e570777027f807dc7dc3ea1945ea040befaf1c9485deb6f24d7110009fc12", size = 1887353 }, + { url = "https://files.pythonhosted.org/packages/4d/40/911e4eca7958bea73d3889433e780b59413f3d7bbd4d24cadc0a2f276528/gevent-25.4.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:44acca4196d4a174c2b4817642564526898f42f72992dc1818b834b2bbf17582", size = 1848809 }, + { url = "https://files.pythonhosted.org/packages/59/eb/ccf5a2d7cb8ed2814b69fbe9cf46a8875f275fa0e5984889b1cbb0a67492/gevent-25.4.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75d2fdd24f3948c085d341281648014760f5cb23de9b29f710083e6911b2e605", size = 2084966 }, + { url = "https://files.pythonhosted.org/packages/7d/19/a1aadd6f3da55f18bb10877ccda7245be0c3b5e6acdc3c882fe54f412e01/gevent-25.4.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:0cc1d6093f482547ac522ab1a985429d8c12494518eeca354c956f0ff6de7a94", size = 1824458 }, + { url = "https://files.pythonhosted.org/packages/0f/70/ee8b5a4df0a6f587c44a102ad46356d626d652e35f46eeec05c5ba1575de/gevent-25.4.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:fe4a3e3fa3a16ed9b12b6ff0922208ef83287e066e696b82b96d33723d8207f2", size = 2116628 }, + { url = "https://files.pythonhosted.org/packages/13/c6/50ee863dd09dd31f61892b847b684fde730473487bcae3240acd9e3e412c/gevent-25.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:8b90913360b1af058b279160679d804d4917a8661f128b2f7625f8665c39450f", size = 1678856 }, + { url = "https://files.pythonhosted.org/packages/54/d8/e29cc7f90ae7aa9e8f5298ca5a157bab34bfbc65d070385b28f4d72af1ac/gevent-25.4.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:b0a656eccd9cb115d01c9bbe55bfe84cf20c8422c495503f41aef747b193c33d", size = 3007128 }, +] + +[[package]] +name = "geventhttpclient" +version = "2.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "brotli" }, + { name = "certifi" }, + { name = "gevent" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/98/1ee9fbab4ae97d5f0f05035059a56a61a9c966331e6c837f974b402fdf63/geventhttpclient-2.0.2.tar.gz", hash = "sha256:8135a85200b170def7293d01dd1557931fcd1bec1ac78c52ad7cedd22368b9ba", size = 73821 } + [[package]] name = "gguf" version = "0.10.0" @@ -1326,6 +1396,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/cb/002424d4f5af1425f9cfe7dcee3ed795ed1367bf0f185a6c4bf81385e1d6/gradio_client-1.7.2-py3-none-any.whl", hash = "sha256:50d61b4db3e87639430a121a7cde4303055486ed72a5035edae94b4fbe6a0e6b", size = 322052 }, ] +[[package]] +name = "greenlet" +version = "3.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/74/907bb43af91782e0366b0960af62a8ce1f9398e4291cac7beaeffbee0c04/greenlet-3.2.1.tar.gz", hash = "sha256:9f4dd4b4946b14bb3bf038f81e1d2e535b7d94f1b2a59fdba1293cd9c1a0a4d7", size = 184475 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/d1/e4777b188a04726f6cf69047830d37365b9191017f54caf2f7af336a6f18/greenlet-3.2.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:0ba2811509a30e5f943be048895a983a8daf0b9aa0ac0ead526dfb5d987d80ea", size = 270381 }, + { url = "https://files.pythonhosted.org/packages/59/e7/b5b738f5679247ddfcf2179c38945519668dced60c3164c20d55c1a7bb4a/greenlet-3.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4245246e72352b150a1588d43ddc8ab5e306bef924c26571aafafa5d1aaae4e8", size = 637195 }, + { url = "https://files.pythonhosted.org/packages/6c/9f/57968c88a5f6bc371364baf983a2e5549cca8f503bfef591b6dd81332cbc/greenlet-3.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7abc0545d8e880779f0c7ce665a1afc3f72f0ca0d5815e2b006cafc4c1cc5840", size = 651381 }, + { url = "https://files.pythonhosted.org/packages/40/81/1533c9a458e9f2ebccb3ae22f1463b2093b0eb448a88aac36182f1c2cd3d/greenlet-3.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6dcc6d604a6575c6225ac0da39df9335cc0c6ac50725063fa90f104f3dbdb2c9", size = 646110 }, + { url = "https://files.pythonhosted.org/packages/06/66/25f7e4b1468ebe4a520757f2e41c2a36a2f49a12e963431b82e9f98df2a0/greenlet-3.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2273586879affca2d1f414709bb1f61f0770adcabf9eda8ef48fd90b36f15d12", size = 648070 }, + { url = "https://files.pythonhosted.org/packages/d7/4c/49d366565c4c4d29e6f666287b9e2f471a66c3a3d8d5066692e347f09e27/greenlet-3.2.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ff38c869ed30fff07f1452d9a204ece1ec6d3c0870e0ba6e478ce7c1515acf22", size = 603816 }, + { url = "https://files.pythonhosted.org/packages/04/15/1612bb61506f44b6b8b6bebb6488702b1fe1432547e95dda57874303a1f5/greenlet-3.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e934591a7a4084fa10ee5ef50eb9d2ac8c4075d5c9cf91128116b5dca49d43b1", size = 1119572 }, + { url = "https://files.pythonhosted.org/packages/cc/2f/002b99dacd1610e825876f5cbbe7f86740aa2a6b76816e5eca41c8457e85/greenlet-3.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:063bcf7f8ee28eb91e7f7a8148c65a43b73fbdc0064ab693e024b5a940070145", size = 1147442 }, + { url = "https://files.pythonhosted.org/packages/c0/ba/82a2c3b9868644ee6011da742156247070f30e952f4d33f33857458450f2/greenlet-3.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7132e024ebeeeabbe661cf8878aac5d2e643975c4feae833142592ec2f03263d", size = 296207 }, + { url = "https://files.pythonhosted.org/packages/77/2a/581b3808afec55b2db838742527c40b4ce68b9b64feedff0fd0123f4b19a/greenlet-3.2.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:e1967882f0c42eaf42282a87579685c8673c51153b845fde1ee81be720ae27ac", size = 269119 }, + { url = "https://files.pythonhosted.org/packages/b0/f3/1c4e27fbdc84e13f05afc2baf605e704668ffa26e73a43eca93e1120813e/greenlet-3.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e77ae69032a95640a5fe8c857ec7bee569a0997e809570f4c92048691ce4b437", size = 637314 }, + { url = "https://files.pythonhosted.org/packages/fc/1a/9fc43cb0044f425f7252da9847893b6de4e3b20c0a748bce7ab3f063d5bc/greenlet-3.2.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3227c6ec1149d4520bc99edac3b9bc8358d0034825f3ca7572165cb502d8f29a", size = 651421 }, + { url = "https://files.pythonhosted.org/packages/8a/65/d47c03cdc62c6680206b7420c4a98363ee997e87a5e9da1e83bd7eeb57a8/greenlet-3.2.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ddda0197c5b46eedb5628d33dad034c455ae77708c7bf192686e760e26d6a0c", size = 645789 }, + { url = "https://files.pythonhosted.org/packages/2f/40/0faf8bee1b106c241780f377b9951dd4564ef0972de1942ef74687aa6bba/greenlet-3.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de62b542e5dcf0b6116c310dec17b82bb06ef2ceb696156ff7bf74a7a498d982", size = 648262 }, + { url = "https://files.pythonhosted.org/packages/e0/a8/73305f713183c2cb08f3ddd32eaa20a6854ba9c37061d682192db9b021c3/greenlet-3.2.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c07a0c01010df42f1f058b3973decc69c4d82e036a951c3deaf89ab114054c07", size = 606770 }, + { url = "https://files.pythonhosted.org/packages/c3/05/7d726e1fb7f8a6ac55ff212a54238a36c57db83446523c763e20cd30b837/greenlet-3.2.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:2530bfb0abcd451ea81068e6d0a1aac6dabf3f4c23c8bd8e2a8f579c2dd60d95", size = 1117960 }, + { url = "https://files.pythonhosted.org/packages/bf/9f/2b6cb1bd9f1537e7b08c08705c4a1d7bd4f64489c67d102225c4fd262bda/greenlet-3.2.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1c472adfca310f849903295c351d297559462067f618944ce2650a1878b84123", size = 1145500 }, + { url = "https://files.pythonhosted.org/packages/e4/f6/339c6e707062319546598eb9827d3ca8942a3eccc610d4a54c1da7b62527/greenlet-3.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:24a496479bc8bd01c39aa6516a43c717b4cee7196573c47b1f8e1011f7c12495", size = 295994 }, + { url = "https://files.pythonhosted.org/packages/f1/72/2a251d74a596af7bb1717e891ad4275a3fd5ac06152319d7ad8c77f876af/greenlet-3.2.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:175d583f7d5ee57845591fc30d852b75b144eb44b05f38b67966ed6df05c8526", size = 629889 }, + { url = "https://files.pythonhosted.org/packages/29/2e/d7ed8bf97641bf704b6a43907c0e082cdf44d5bc026eb8e1b79283e7a719/greenlet-3.2.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ecc9d33ca9428e4536ea53e79d781792cee114d2fa2695b173092bdbd8cd6d5", size = 635261 }, + { url = "https://files.pythonhosted.org/packages/1e/75/802aa27848a6fcb5e566f69c64534f572e310f0f12d41e9201a81e741551/greenlet-3.2.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f56382ac4df3860ebed8ed838f268f03ddf4e459b954415534130062b16bc32", size = 632523 }, + { url = "https://files.pythonhosted.org/packages/56/09/f7c1c3bab9b4c589ad356503dd71be00935e9c4db4db516ed88fc80f1187/greenlet-3.2.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc45a7189c91c0f89aaf9d69da428ce8301b0fd66c914a499199cfb0c28420fc", size = 628816 }, + { url = "https://files.pythonhosted.org/packages/79/e0/1bb90d30b5450eac2dffeaac6b692857c4bd642c21883b79faa8fa056cf2/greenlet-3.2.1-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:51a2f49da08cff79ee42eb22f1658a2aed60c72792f0a0a95f5f0ca6d101b1fb", size = 593687 }, + { url = "https://files.pythonhosted.org/packages/c5/b5/adbe03c8b4c178add20cc716021183ae6b0326d56ba8793d7828c94286f6/greenlet-3.2.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:0c68bbc639359493420282d2f34fa114e992a8724481d700da0b10d10a7611b8", size = 1105754 }, + { url = "https://files.pythonhosted.org/packages/39/93/84582d7ef38dec009543ccadec6ab41079a6cbc2b8c0566bcd07bf1aaf6c/greenlet-3.2.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:e775176b5c203a1fa4be19f91da00fd3bff536868b77b237da3f4daa5971ae5d", size = 1125160 }, + { url = "https://files.pythonhosted.org/packages/01/e6/f9d759788518a6248684e3afeb3691f3ab0276d769b6217a1533362298c8/greenlet-3.2.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:d6668caf15f181c1b82fb6406f3911696975cc4c37d782e19cb7ba499e556189", size = 269897 }, +] + [[package]] name = "groovy" version = "0.1.2" @@ -3426,6 +3530,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546 }, ] +[[package]] +name = "python-rapidjson" +version = "1.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/2a/2510836a65a1fc40c923393611896c3c8ad1e2f583ed0c32cf0bb48cc378/python_rapidjson-1.20.tar.gz", hash = "sha256:115f08c86d2df7543c02605e77c84727cdabc4b08310d2f097e953efeaaa73eb", size = 238158 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/d1/40616f40499f8f61e83135aa078a0ba7d392e7ea63c016c7cc544ecb7344/python_rapidjson-1.20-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6056fcc8caeb9b04775bf655568bba362c7670ab792c1b438671bb056db954cd", size = 230104 }, + { url = "https://files.pythonhosted.org/packages/ea/2f/d28f4da4df83cfeb60fb7b84396a9c3678a0ac615012dc234d5b962fbaaf/python_rapidjson-1.20-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:225bd4cbabfe7910261cbcebb8b811d4ff98e90cdd17c233b916c6aa71a9553f", size = 211105 }, + { url = "https://files.pythonhosted.org/packages/b3/60/ebc521afbdb626bb571a815378831f685213cb6b98ffe08176fe3191c5a3/python_rapidjson-1.20-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:026077b663acf93a3f2b1adb87282e611a30214b8ae8001b7e4863a3b978e646", size = 1650309 }, + { url = "https://files.pythonhosted.org/packages/19/da/4c375b90c54091e93a600fca06a9f3b8456b0e09050e862e998fc22b6385/python_rapidjson-1.20-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:884e1dd4c0770ed424737941af4d5dc9014995f9c33595f151af13f83ce282c3", size = 1700043 }, + { url = "https://files.pythonhosted.org/packages/bc/6e/2718413e7bc300523c5d4eaa25418059d8b17effa9aef2f2ae370493b861/python_rapidjson-1.20-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f55531c8197cb7a21a5ef0ffa46f2b8fc8c5fe7c6fd08bdbd2063ae65d2ff65", size = 1700523 }, + { url = "https://files.pythonhosted.org/packages/32/fe/d96e996f9c5140d3ce93d440f871a1b336f1c14fae27b64d4872fc58d45d/python_rapidjson-1.20-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c60121d155562dc694c05ed7df4e39e42ee1d3adff2a060c64a004498e6451f7", size = 1598383 }, + { url = "https://files.pythonhosted.org/packages/46/32/ef3a381641b803e1b67c9b9c360d161b650620605768652e704fb35ad2b9/python_rapidjson-1.20-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3a6620eed0b04196f37fab7048c1d672d03391bb29d7f09ee8fee8dea33f11f4", size = 2454134 }, + { url = "https://files.pythonhosted.org/packages/2f/50/771826d3f217b7c597f14df0dfa943d9e6f2f14749d974de4402f56ce39a/python_rapidjson-1.20-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:ddb63eff401ce7cf20cdd5e21942fc23fbe0e1dc1d96d7ae838645fb1f74fb47", size = 2585576 }, + { url = "https://files.pythonhosted.org/packages/64/95/f3e7ed53c9ab27a99c876c42b7d1994312e6fd2c2d8131ce849bd4275be8/python_rapidjson-1.20-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:05e28c3dbb4a0d74ec13af9668ef2b9f302edf83cf7ce1d8316a95364720eec0", size = 2599382 }, + { url = "https://files.pythonhosted.org/packages/bc/4c/34778932d0145fdc7087274cd4c0fa421a96acbc96bf9860cbdf3e389dcd/python_rapidjson-1.20-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b733978ecd84fc5df9a778ce821dc1f3113f7bfc2493cac0bb17efb4ae0bb8fa", size = 2537066 }, + { url = "https://files.pythonhosted.org/packages/50/16/dfef47ec507d5a5d00281b8db8526d5c36b715afeeae0ceeef4030f1640f/python_rapidjson-1.20-cp312-cp312-win32.whl", hash = "sha256:d87041448cec00e2db5d858625a76dc1b59eef6691a039acff6d92ad8581cfc1", size = 128358 }, + { url = "https://files.pythonhosted.org/packages/bc/97/42a550a79ab90ab37fcd8b519cd71bba4b96b85679218100d63b437770c0/python_rapidjson-1.20-cp312-cp312-win_amd64.whl", hash = "sha256:5d3be149ce5475f9605f01240487541057792abad94d3fd0cd56af363cf5a4dc", size = 149067 }, + { url = "https://files.pythonhosted.org/packages/18/04/47d9d10c3fa6e57af9462792088187605a07d88ad6f6f2e193fb01eff0fc/python_rapidjson-1.20-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:daee815b4c20ca6e4dbc6bde373dd3f65b53813d775f1c94b765b33b402513a7", size = 229315 }, + { url = "https://files.pythonhosted.org/packages/9a/3a/0c4e0af51d7356d9efdef1bf1785d9d9f9e0789a7d2844cc3e9b35ef383f/python_rapidjson-1.20-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:083df379c769b30f9bc40041c91fd9d8f7bb8ca2b3c7170258842aced2098e05", size = 211111 }, + { url = "https://files.pythonhosted.org/packages/83/e1/e253de9a774d021f9a6947f845628fae8237f441c63198e8a72e5906d31f/python_rapidjson-1.20-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9399ad75a2e3377f9e6208caabe73eb9354cd01b732407475ccadcd42c577df", size = 1650131 }, + { url = "https://files.pythonhosted.org/packages/3e/93/8f723c7f7be055086d6bec2ba9e5ef13e749c3fb3ad5a3dc1d740acee889/python_rapidjson-1.20-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:599ab208ccf6172d6cfac1abe048c837e62612f91f97d198e32773c45346a0b4", size = 1699873 }, + { url = "https://files.pythonhosted.org/packages/7d/2e/eb7255601b81a5b70f2bff05caab136e191b66825c16db3e7db1bdaa8314/python_rapidjson-1.20-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf3c0e2a5b97b0d07311f15f0dce4434e43dec865c3794ad1b10d968460fd665", size = 1700484 }, + { url = "https://files.pythonhosted.org/packages/90/54/23d8b595dd4fdbdaa6c5f723a4df7a7be78aa702aa0b6dac6c964e6e6d30/python_rapidjson-1.20-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8064b8edb57ddd9e3ffa539cf2ec2f03515751fb0698b40ba5cb66a2123af19", size = 1598344 }, + { url = "https://files.pythonhosted.org/packages/3d/3a/3628e199a826e7bc598633ce895516981602ab1d8fce76359005f90ca488/python_rapidjson-1.20-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bc79d7f00f7538e027960ca6bcd1e03ed99fcf660d4d882d1c22f641155d0db0", size = 2454206 }, + { url = "https://files.pythonhosted.org/packages/ed/19/eef8629f73b1af21fa778d140e68e72076fe5746357426d6716a0c411dd2/python_rapidjson-1.20-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:87aa0b01b8c20984844f1440b8ff6bdb32de911a1750fed344b9daed33b4b52b", size = 2585553 }, + { url = "https://files.pythonhosted.org/packages/d8/9d/217e56c74a65cfaf4441b26b6206b924b41fb339f98776a74e60dd287b46/python_rapidjson-1.20-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4099cb9eae8a0ce19c09e02729eb6d69d5180424f13a2641a6c407d053e47a82", size = 2599513 }, + { url = "https://files.pythonhosted.org/packages/54/f6/4d40189f14e4fa5526a91aad9944864c8a4eebc0257e0314a331f3c64170/python_rapidjson-1.20-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4c680cd2b4de760ff6875de71fe6a87bd610aa116593d62e4f81a563be86ae18", size = 2537192 }, + { url = "https://files.pythonhosted.org/packages/ee/30/f3f40abfd8d7f0586b88ccfcd747f2e227fe589c16fbb485b1e238d8e641/python_rapidjson-1.20-cp313-cp313-win32.whl", hash = "sha256:9e431a7afc77aa874fed537c9f6bf5fcecaef124ebeae2a2379d3b9e9adce74b", size = 128362 }, + { url = "https://files.pythonhosted.org/packages/94/df/7126352e55cb72a5ca99630bd44ffb11bbf61ee35f4e1f34d203a77597c5/python_rapidjson-1.20-cp313-cp313-win_amd64.whl", hash = "sha256:7444bc7e6a04c03d6ed748b5dab0798fa2b3f2b303be8c38d3af405b2cac6d63", size = 149072 }, +] + [[package]] name = "pytz" version = "2025.1" @@ -3580,6 +3716,8 @@ dependencies = [ { name = "torch" }, { name = "torchvision" }, { name = "transformers" }, + { name = "tritonclient", extra = ["all"] }, + { name = "ultralytics" }, { name = "unsloth" }, { name = "verifiers" }, { name = "vllm" }, @@ -3619,6 +3757,8 @@ requires-dist = [ { name = "torch", specifier = "==2.5.1" }, { name = "torchvision", specifier = "==0.20.1" }, { name = "transformers", specifier = "==4.49.0" }, + { name = "tritonclient", extras = ["all"], specifier = ">=2.51.0" }, + { name = "ultralytics", specifier = ">=8.3.120" }, { name = "unsloth", specifier = ">=2025.3.19" }, { name = "verifiers", editable = "../verifiers" }, { name = "vllm", specifier = "==0.7.3" }, @@ -4032,6 +4172,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/c8/b3f566db71461cabd4b2d5b39bcc24a7e1c119535c8361f81426be39bb47/scipy-1.15.2-cp313-cp313t-win_amd64.whl", hash = "sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db", size = 40477705 }, ] +[[package]] +name = "seaborn" +version = "0.13.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy" }, + { name = "pandas" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/59/a451d7420a77ab0b98f7affa3a1d78a313d2f7281a57afb1a34bae8ab412/seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7", size = 1457696 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914 }, +] + [[package]] name = "semantic-version" version = "2.10.0" @@ -4589,6 +4743,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f5/75/aac76f24dd17eb2245973ec1dd995759ce85ed91e5bb045fabb3c83ab1d6/triton_windows-3.2.0.post17-cp313-cp313-win_amd64.whl", hash = "sha256:539dd7ba8b7cc238930c1f4cb6e7819c22d1b8798fde361b78115b0fdb98a147", size = 40039344 }, ] +[[package]] +name = "tritonclient" +version = "2.51.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "python-rapidjson" }, + { name = "urllib3" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/a6/301bd2f431346adac05ff3c062bbcec0a93b567f1d3ef0d3ccf353a5bcd6/tritonclient-2.51.0-py3-none-any.whl", hash = "sha256:eef99681b0a18ee72808d887d2324a38a81fa1250924e595db46256b83f13668", size = 98012 }, + { url = "https://files.pythonhosted.org/packages/87/0b/57eae443655212c73ae3586b280e1b1c81ba1668afc94109b1efac8c23c4/tritonclient-2.51.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:c485bb0123bdf310f90bc8b03d3489b28df2ffed55b30c7eee0b795b48113d52", size = 13956700 }, + { url = "https://files.pythonhosted.org/packages/70/bd/eb64fe810b8728f5f7936fe4d156062847d850c55923289dad8e281ee3d6/tritonclient-2.51.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:ee6f5a508409f6c95069f4d77d34e97bef84fb4a1aedb5d82ad0ad311ad128d5", size = 13325829 }, +] + +[package.optional-dependencies] +all = [ + { name = "aiohttp" }, + { name = "cuda-python" }, + { name = "geventhttpclient" }, + { name = "grpcio" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "python-rapidjson" }, +] + [[package]] name = "trl" version = "0.15.0.dev0" @@ -4740,6 +4921,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/72/6cb6728e2738c05bbe9bd522d6fc79f86b9a28402f38663e85a28fddd4a0/ujson-5.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:4573fd1695932d4f619928fd09d5d03d917274381649ade4328091ceca175539", size = 42212 }, ] +[[package]] +name = "ultralytics" +version = "8.3.120" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy" }, + { name = "opencv-python" }, + { name = "pandas" }, + { name = "pillow" }, + { name = "psutil" }, + { name = "py-cpuinfo" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "scipy" }, + { name = "seaborn" }, + { name = "torch" }, + { name = "torchvision" }, + { name = "tqdm" }, + { name = "ultralytics-thop" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/c8/921621be09aed3c498d0db807261a9737d04efe84f8cb729de3874dfe2d8/ultralytics-8.3.120.tar.gz", hash = "sha256:5b709c2a66fc1580dfbf8d6be56727b941d0d3d5906582f9613e72b90b486e53", size = 863199 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/bc/3f390c44ef15deb1af6235b349b6953c6409f45f02d49b1e22b6f940871c/ultralytics-8.3.120-py3-none-any.whl", hash = "sha256:7ac3bf90850eb7b943c3f1c8451eca271f8277c51d9af9cb34933c7a23cab9ad", size = 1004601 }, +] + +[[package]] +name = "ultralytics-thop" +version = "2.0.14" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c2/d8/e43a8bfcb03ff036119d098a7ea27be9f0adb715543ed6bd83b16cda83dc/ultralytics_thop-2.0.14.tar.gz", hash = "sha256:38ebfdbd3cd8dafdc3d26ec3a7d4f604fbeed5e69a74e61a48060b39736c945c", size = 28793 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/10/251f036b4c5d77249f9a119cc89dafe8745dc1ad1f1a5f06b6a3988ca454/ultralytics_thop-2.0.14-py3-none-any.whl", hash = "sha256:720b421e2459179fee21ec8f730d242a20774cd4b0a00a58d02351a39ec3881c", size = 26517 }, +] + [[package]] name = "unsloth" version = "2025.3.19" @@ -5283,3 +5503,38 @@ sdist = { url = "https://files.pythonhosted.org/packages/3f/50/bad581df71744867e wheels = [ { url = "https://files.pythonhosted.org/packages/b7/1a/7e4798e9339adc931158c9d69ecc34f5e6791489d469f5e50ec15e35f458/zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931", size = 9630 }, ] + +[[package]] +name = "zope-event" +version = "5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/46/c2/427f1867bb96555d1d34342f1dd97f8c420966ab564d58d18469a1db8736/zope.event-5.0.tar.gz", hash = "sha256:bac440d8d9891b4068e2b5a2c5e2c9765a9df762944bda6955f96bb9b91e67cd", size = 17350 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/42/f8dbc2b9ad59e927940325a22d6d3931d630c3644dae7e2369ef5d9ba230/zope.event-5.0-py3-none-any.whl", hash = "sha256:2832e95014f4db26c47a13fdaef84cef2f4df37e66b59d8f1f4a8f319a632c26", size = 6824 }, +] + +[[package]] +name = "zope-interface" +version = "7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/93/9210e7606be57a2dfc6277ac97dcc864fd8d39f142ca194fdc186d596fda/zope.interface-7.2.tar.gz", hash = "sha256:8b49f1a3d1ee4cdaf5b32d2e738362c7f5e40ac8b46dd7d1a65e82a4872728fe", size = 252960 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/0b/c7516bc3bad144c2496f355e35bd699443b82e9437aa02d9867653203b4a/zope.interface-7.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:086ee2f51eaef1e4a52bd7d3111a0404081dadae87f84c0ad4ce2649d4f708b7", size = 208959 }, + { url = "https://files.pythonhosted.org/packages/a2/e9/1463036df1f78ff8c45a02642a7bf6931ae4a38a4acd6a8e07c128e387a7/zope.interface-7.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:21328fcc9d5b80768bf051faa35ab98fb979080c18e6f84ab3f27ce703bce465", size = 209357 }, + { url = "https://files.pythonhosted.org/packages/07/a8/106ca4c2add440728e382f1b16c7d886563602487bdd90004788d45eb310/zope.interface-7.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6dd02ec01f4468da0f234da9d9c8545c5412fef80bc590cc51d8dd084138a89", size = 264235 }, + { url = "https://files.pythonhosted.org/packages/fc/ca/57286866285f4b8a4634c12ca1957c24bdac06eae28fd4a3a578e30cf906/zope.interface-7.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e7da17f53e25d1a3bde5da4601e026adc9e8071f9f6f936d0fe3fe84ace6d54", size = 259253 }, + { url = "https://files.pythonhosted.org/packages/96/08/2103587ebc989b455cf05e858e7fbdfeedfc3373358320e9c513428290b1/zope.interface-7.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cab15ff4832580aa440dc9790b8a6128abd0b88b7ee4dd56abacbc52f212209d", size = 264702 }, + { url = "https://files.pythonhosted.org/packages/5f/c7/3c67562e03b3752ba4ab6b23355f15a58ac2d023a6ef763caaca430f91f2/zope.interface-7.2-cp312-cp312-win_amd64.whl", hash = "sha256:29caad142a2355ce7cfea48725aa8bcf0067e2b5cc63fcf5cd9f97ad12d6afb5", size = 212466 }, + { url = "https://files.pythonhosted.org/packages/c6/3b/e309d731712c1a1866d61b5356a069dd44e5b01e394b6cb49848fa2efbff/zope.interface-7.2-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:3e0350b51e88658d5ad126c6a57502b19d5f559f6cb0a628e3dc90442b53dd98", size = 208961 }, + { url = "https://files.pythonhosted.org/packages/49/65/78e7cebca6be07c8fc4032bfbb123e500d60efdf7b86727bb8a071992108/zope.interface-7.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:15398c000c094b8855d7d74f4fdc9e73aa02d4d0d5c775acdef98cdb1119768d", size = 209356 }, + { url = "https://files.pythonhosted.org/packages/11/b1/627384b745310d082d29e3695db5f5a9188186676912c14b61a78bbc6afe/zope.interface-7.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:802176a9f99bd8cc276dcd3b8512808716492f6f557c11196d42e26c01a69a4c", size = 264196 }, + { url = "https://files.pythonhosted.org/packages/b8/f6/54548df6dc73e30ac6c8a7ff1da73ac9007ba38f866397091d5a82237bd3/zope.interface-7.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb23f58a446a7f09db85eda09521a498e109f137b85fb278edb2e34841055398", size = 259237 }, + { url = "https://files.pythonhosted.org/packages/b6/66/ac05b741c2129fdf668b85631d2268421c5cd1a9ff99be1674371139d665/zope.interface-7.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a71a5b541078d0ebe373a81a3b7e71432c61d12e660f1d67896ca62d9628045b", size = 264696 }, + { url = "https://files.pythonhosted.org/packages/0a/2f/1bccc6f4cc882662162a1158cda1a7f616add2ffe322b28c99cb031b4ffc/zope.interface-7.2-cp313-cp313-win_amd64.whl", hash = "sha256:4893395d5dd2ba655c38ceb13014fd65667740f09fa5bb01caa1e6284e48c0cd", size = 212472 }, +]