diff --git a/README.md b/README.md
index 8ca5bda..f5353e0 100644
--- a/README.md
+++ b/README.md
@@ -181,3 +181,13 @@ If you encounter any issues or have questions:
 ---
 
 **Note**: This node requires a Together AI account and API key. You can get one at [Together AI's website](https://together.ai).
+
+**Updated README.md to reflect automatic mode switching based on image connection**
+
+The node now automatically switches between Vision Mode and Text-Only Mode based on the presence of an image input connection. When an image is connected, the node will generate detailed image descriptions. When no image is connected, the node will function as a text generation model.
+
+**Flexible Processing Modes**
+
+- **Image + Text Mode**: When an image is connected, generates descriptions and responses about the image
+- **Text-Only Mode**: When no image is connected, functions as a text generation model
+- Seamlessly switches between modes based on input connections
diff --git a/requirements.txt b/requirements.txt
index a85b2ea..0b91e06 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
-together
+together
\ No newline at end of file
diff --git a/together_vision_node.py b/together_vision_node.py
index 65c5c2a..be4c0ee 100644
--- a/together_vision_node.py
+++ b/together_vision_node.py
@@ -7,6 +7,9 @@ from together import Together
 import torch
 import logging
+import re
+import time
+from typing import Optional
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -22,12 +25,13 @@ class TogetherVisionNode:
 
     def __init__(self):
         self.client = None
+        self.last_request_time = 0
+        self.min_request_interval = 1.0  # Minimum seconds between requests
 
     @classmethod
     def INPUT_TYPES(cls):
         return {
             "required": {
-                "image": ("IMAGE",),
                 "model_name": (["Free (Llama-Vision-Free)", "Paid (Llama-3.2-11B-Vision)"],),
                 "api_key": ("STRING", {"default": "", "multiline": False}),
                 "system_prompt": ("STRING", {
@@ -62,130 +66,136 @@ def INPUT_TYPES(cls):
                     "max": 2.0,
                     "step": 0.1
                 })
+            },
+            "optional": {
+                "image": ("IMAGE",)
             }
         }
 
     RETURN_TYPES = ("STRING",)
+    RETURN_NAMES = ("description",)
     FUNCTION = "process_image"
     CATEGORY = "image/text"
+    OUTPUT_NODE = True
 
-    def encode_image(self, image_tensor):
+    def encode_image(self, image_tensor: torch.Tensor) -> str:
         """
-        Converts an image tensor to base64 string.
+        Converts an image tensor to base64 string with improved error handling.
         """
         try:
-            logger.info(f"Image tensor type: {type(image_tensor)}")
-            logger.info(f"Image tensor shape: {image_tensor.shape if hasattr(image_tensor, 'shape') else 'No shape'}")
-
             # Handle different input types
             if isinstance(image_tensor, torch.Tensor):
-                logger.info("Converting torch tensor to numpy array")
                 image_array = image_tensor.cpu().numpy()
             elif isinstance(image_tensor, np.ndarray):
-                logger.info("Input is already a numpy array")
                 image_array = image_tensor
             else:
                 raise ValueError(f"Unsupported image type: {type(image_tensor)}")
 
-            logger.info(f"Numpy array shape: {image_array.shape}")
-            logger.info(f"Numpy array dtype: {image_array.dtype}")
+            # Validate array shape
+            if not (2 <= len(image_array.shape) <= 4):
+                raise ValueError(f"Invalid image shape: {image_array.shape}")
 
-            # Ensure image is in the right format (HWC)
+            # Handle batch dimension
             if len(image_array.shape) == 4:
-                logger.info("Removing batch dimension")
-                image_array = image_array[0]  # Take first image if batched
+                image_array = image_array[0]
 
+            # Ensure correct channel format
             if len(image_array.shape) == 3:
-                if image_array.shape[0] in [3, 4]:  # If channels first (CHW)
-                    logger.info("Converting CHW to HWC format")
-                    image_array = np.transpose(image_array, (1, 2, 0))  # Convert to HWC
-
-            logger.info(f"After format conversion - Shape: {image_array.shape}")
+                if image_array.shape[0] in [3, 4]:  # CHW to HWC
+                    image_array = np.transpose(image_array, (1, 2, 0))
 
-            # Handle different channel numbers
-            if image_array.shape[-1] == 4:  # RGBA
-                logger.info("Converting RGBA to RGB")
-                image_array = image_array[..., :3]  # Convert to RGB
+            # Convert RGBA to RGB if needed
+            if image_array.shape[-1] == 4:
+                image_array = image_array[..., :3]
 
-            # Convert to uint8 if needed
-            if image_array.dtype == np.float32 or image_array.dtype == np.float64:
-                logger.info("Converting to uint8")
+            # Normalize and convert to uint8
+            if image_array.dtype in [np.float32, np.float64]:
                 image_array = (image_array * 255).clip(0, 255).astype(np.uint8)
-
-            logger.info(f"Final array shape: {image_array.shape}, dtype: {image_array.dtype}")
-
-            # Create PIL Image
+            elif image_array.dtype != np.uint8:
+                raise ValueError(f"Unsupported image dtype: {image_array.dtype}")
+
+            # Create PIL Image and validate
             pil_image = Image.fromarray(image_array)
-            logger.info(f"PIL Image size: {pil_image.size}, mode: {pil_image.mode}")
-
+            if pil_image.size[0] * pil_image.size[1] == 0:
+                raise ValueError("Invalid image dimensions")
+
             # Convert to base64
             buffered = io.BytesIO()
             pil_image.save(buffered, format="PNG")
-            base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-            logger.info(f"Base64 string length: {len(base64_str)}")
-
-            return base64_str
-
+            return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
         except Exception as e:
-            logger.error(f"Error in encode_image: {str(e)}", exc_info=True)
-            raise
+            logger.error(f"Image encoding error: {str(e)}")
+            raise ValueError(f"Failed to encode image: {str(e)}")
 
-    def get_api_key(self, provided_key):
-        """
-        Get API key from input or environment variable.
-        """
-        if provided_key:
-            return provided_key
-        return os.getenv('TOGETHER_API_KEY')
+    def get_api_key(self, provided_key: str) -> str:
+        """Get API key with validation."""
+        api_key = provided_key or os.getenv('TOGETHER_API_KEY')
+        if not api_key:
+            raise ValueError("API key not provided. Please provide an API key or set TOGETHER_API_KEY environment variable.")
+        return api_key
+
+    def rate_limit_check(self):
+        """Implement rate limiting."""
+        current_time = time.time()
+        time_since_last = current_time - self.last_request_time
+        if time_since_last < self.min_request_interval:
+            time.sleep(self.min_request_interval - time_since_last)
+        self.last_request_time = time.time()
 
-    def process_image(self, image, model_name, api_key, system_prompt, user_prompt,
-                      temperature, top_p, top_k, repetition_penalty):
+    def process_image(self, model_name: str, api_key: str, system_prompt: str, user_prompt: str,
+                      temperature: float, top_p: float, top_k: int, repetition_penalty: float,
+                      image: Optional[torch.Tensor] = None) -> tuple:
         """
-        Process the image and generate description using Together API.
+        Process the image and generate description using Together API with improved stability.
         """
         try:
-            logger.info("Starting image processing")
-
-            # Map friendly model names to actual model IDs
+            # Validate required inputs
+            if not user_prompt:
+                raise ValueError("User prompt cannot be empty")
+            if not system_prompt:
+                system_prompt = "You are an AI that describes images accurately and concisely."
+
+            # Rate limit check
+            self.rate_limit_check()
+
+            # Map model names
            model_mapping = {
                 "Paid (Llama-3.2-11B-Vision)": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
                 "Free (Llama-Vision-Free)": "meta-llama/Llama-Vision-Free"
             }
             actual_model = model_mapping[model_name]
-
-            # Get API key and initialize client if needed
+
+            # Initialize API client
             api_key = self.get_api_key(api_key)
             if self.client is None:
-                logger.info("Initializing Together client")
                 self.client = Together(api_key=api_key)
 
-            # Convert image to base64
-            logger.info("Converting image to base64")
-            base64_image = self.encode_image(image)
-
-            # Create the messages array with system and user prompts
-            messages = [
-                {
-                    "role": "system",
-                    "content": system_prompt
-                },
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": user_prompt},
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/png;base64,{base64_image}"
+            # Prepare messages
+            messages = [{"role": "system", "content": system_prompt}]
+
+            # Handle image if provided
+            if image is not None:
+                try:
+                    base64_image = self.encode_image(image)
+                    messages.append({
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": user_prompt},
+                            {
+                                "type": "image_url",
+                                "image_url": {"url": f"data:image/png;base64,{base64_image}"}
                             }
-                        }
-                    ]
-                }
-            ]
+                        ]
+                    })
+                except Exception as img_error:
+                    logger.error(f"Image processing failed: {str(img_error)}")
+                    return (f"Error processing image: {str(img_error)}",)
+            else:
+                messages.append({"role": "user", "content": user_prompt})
 
             try:
-                # Call the Together API with all parameters
-                logger.info(f"Calling Together API with model: {actual_model}")
+                # API call with timeout handling
                 response = self.client.chat.completions.create(
                     model=actual_model,
                     messages=messages,
@@ -197,18 +207,25 @@ def process_image(self, image, model_name, api_key, system_prompt, user_prompt,
                     stream=True
                 )
 
-                # Process the streamed response
+                # Process streamed response with timeout
                 description = ""
-                logger.info("Processing streamed response")
+                start_time = time.time()
+                timeout = 30  # 30 seconds timeout
+
                 for chunk in response:
-                    logger.debug(f"Received chunk: {chunk}")
-                    if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
-                        if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
-                            content = chunk.choices[0].delta.content
-                            if content is not None:
-                                description += content
-
-                logger.info("Finished processing response")
+                    if time.time() - start_time > timeout:
+                        raise TimeoutError("Response generation timed out")
+
+                    if not hasattr(chunk, 'choices') or not chunk.choices:
+                        continue
+
+                    delta = chunk.choices[0].delta
+                    if hasattr(delta, 'content') and delta.content:
+                        description += delta.content
+
+                if not description:
+                    raise ValueError("No response generated")
+
                 return (description,)
 
             except Exception as api_error:
@@ -217,12 +234,11 @@ def process_image(self, image, model_name, api_key, system_prompt, user_prompt,
                     wait_time = "1 hour"
                     model_type = "free" if "free" in model_name.lower() else "paid"
 
-                    # Try to extract the wait time from the error message
                     time_match = re.search(r"try again in (?:about )?([^:]+)", error_msg)
                     if time_match:
                         wait_time = time_match.group(1)
 
-                    error_message = f"""⚠️ Rate Limit Exceeded
+                    return (f"""⚠️ Rate Limit Exceeded
 
 The {model_name} has reached its rate limit.
 Please try again in {wait_time}.
@@ -235,16 +251,14 @@ def process_image(self, image, model_name, api_key, system_prompt, user_prompt,
 
 Rate Limits:
 • Free Model: ~100 requests/day, 20-30 requests/hour
-• Paid Model: Based on subscription tier"""
-
-                    logger.warning(f"Rate limit exceeded for {model_type} model: {error_msg}")
-                    return (error_message,)
+• Paid Model: Based on subscription tier""",)
                 else:
                     raise api_error
 
         except Exception as e:
-            logger.error(f"Error in process_image: {str(e)}", exc_info=True)
-            return (f"Error: {str(e)}",)
+            error_msg = str(e)
+            logger.error(f"Process error: {error_msg}")
+            return (f"Error: {error_msg}",)
 
 # Node registration
 NODE_CLASS_MAPPINGS = {