Skip to content

Commit

Permalink
updated vision node
Browse files Browse the repository at this point in the history
  • Loading branch information
theshubzworld authored Dec 4, 2024
1 parent 57e825a commit 558a17a
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 97 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,13 @@ If you encounter any issues or have questions:
---
**Note**: This node requires a Together AI account and API key. You can get one at [Together AI's website](https://together.ai).
**Updated README.md to reflect automatic mode switching based on image connection**
The node now automatically switches between Vision Mode and Text-Only Mode based on the presence of an image input connection. When an image is connected, the node will generate detailed image descriptions. When no image is connected, the node will function as a text generation model.
**Flexible Processing Modes**
- **Image + Text Mode**: When an image is connected, generates descriptions and responses about the image
- **Text-Only Mode**: When no image is connected, functions as a text generation model
- Seamlessly switches between modes based on input connections
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
together
together
206 changes: 110 additions & 96 deletions together_vision_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from together import Together
import torch
import logging
import re
import time
from typing import Optional

# Set up logging
logging.basicConfig(level=logging.INFO)
Expand All @@ -22,12 +25,13 @@ class TogetherVisionNode:

def __init__(self):
self.client = None
self.last_request_time = 0
self.min_request_interval = 1.0 # Minimum seconds between requests

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"model_name": (["Free (Llama-Vision-Free)", "Paid (Llama-3.2-11B-Vision)"],),
"api_key": ("STRING", {"default": "", "multiline": False}),
"system_prompt": ("STRING", {
Expand Down Expand Up @@ -62,130 +66,136 @@ def INPUT_TYPES(cls):
"max": 2.0,
"step": 0.1
})
},
"optional": {
"image": ("IMAGE",)
}
}

RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("description",)
FUNCTION = "process_image"
CATEGORY = "image/text"
OUTPUT_NODE = True

def encode_image(self, image_tensor):
def encode_image(self, image_tensor: torch.Tensor) -> str:
"""
Converts an image tensor to base64 string.
Converts an image tensor to base64 string with improved error handling.
"""
try:
logger.info(f"Image tensor type: {type(image_tensor)}")
logger.info(f"Image tensor shape: {image_tensor.shape if hasattr(image_tensor, 'shape') else 'No shape'}")

# Handle different input types
if isinstance(image_tensor, torch.Tensor):
logger.info("Converting torch tensor to numpy array")
image_array = image_tensor.cpu().numpy()
elif isinstance(image_tensor, np.ndarray):
logger.info("Input is already a numpy array")
image_array = image_tensor
else:
raise ValueError(f"Unsupported image type: {type(image_tensor)}")

logger.info(f"Numpy array shape: {image_array.shape}")
logger.info(f"Numpy array dtype: {image_array.dtype}")
# Validate array shape
if not (2 <= len(image_array.shape) <= 4):
raise ValueError(f"Invalid image shape: {image_array.shape}")

# Ensure image is in the right format (HWC)
# Handle batch dimension
if len(image_array.shape) == 4:
logger.info("Removing batch dimension")
image_array = image_array[0] # Take first image if batched
image_array = image_array[0]

# Ensure correct channel format
if len(image_array.shape) == 3:
if image_array.shape[0] in [3, 4]: # If channels first (CHW)
logger.info("Converting CHW to HWC format")
image_array = np.transpose(image_array, (1, 2, 0)) # Convert to HWC

logger.info(f"After format conversion - Shape: {image_array.shape}")
if image_array.shape[0] in [3, 4]: # CHW to HWC
image_array = np.transpose(image_array, (1, 2, 0))

# Handle different channel numbers
if image_array.shape[-1] == 4: # RGBA
logger.info("Converting RGBA to RGB")
image_array = image_array[..., :3] # Convert to RGB
# Convert RGBA to RGB if needed
if image_array.shape[-1] == 4:
image_array = image_array[..., :3]

# Convert to uint8 if needed
if image_array.dtype == np.float32 or image_array.dtype == np.float64:
logger.info("Converting to uint8")
# Normalize and convert to uint8
if image_array.dtype in [np.float32, np.float64]:
image_array = (image_array * 255).clip(0, 255).astype(np.uint8)

logger.info(f"Final array shape: {image_array.shape}, dtype: {image_array.dtype}")
# Create PIL Image
elif image_array.dtype != np.uint8:
raise ValueError(f"Unsupported image dtype: {image_array.dtype}")

# Create PIL Image and validate
pil_image = Image.fromarray(image_array)
logger.info(f"PIL Image size: {pil_image.size}, mode: {pil_image.mode}")

if pil_image.size[0] * pil_image.size[1] == 0:
raise ValueError("Invalid image dimensions")

# Convert to base64
buffered = io.BytesIO()
pil_image.save(buffered, format="PNG")
base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
logger.info(f"Base64 string length: {len(base64_str)}")

return base64_str

return base64.b64encode(buffered.getvalue()).decode("utf-8")

except Exception as e:
logger.error(f"Error in encode_image: {str(e)}", exc_info=True)
raise
logger.error(f"Image encoding error: {str(e)}")
raise ValueError(f"Failed to encode image: {str(e)}")

def get_api_key(self, provided_key):
"""
Get API key from input or environment variable.
"""
if provided_key:
return provided_key
return os.getenv('TOGETHER_API_KEY')
def get_api_key(self, provided_key: str) -> str:
"""Get API key with validation."""
api_key = provided_key or os.getenv('TOGETHER_API_KEY')
if not api_key:
raise ValueError("API key not provided. Please provide an API key or set TOGETHER_API_KEY environment variable.")
return api_key

def rate_limit_check(self):
"""Implement rate limiting."""
current_time = time.time()
time_since_last = current_time - self.last_request_time
if time_since_last < self.min_request_interval:
time.sleep(self.min_request_interval - time_since_last)
self.last_request_time = time.time()

def process_image(self, image, model_name, api_key, system_prompt, user_prompt,
temperature, top_p, top_k, repetition_penalty):
def process_image(self, model_name: str, api_key: str, system_prompt: str, user_prompt: str,
temperature: float, top_p: float, top_k: int, repetition_penalty: float,
image: Optional[torch.Tensor] = None) -> tuple:
"""
Process the image and generate description using Together API.
Process the image and generate description using Together API with improved stability.
"""
try:
logger.info("Starting image processing")

# Map friendly model names to actual model IDs
# Validate required inputs
if not user_prompt:
raise ValueError("User prompt cannot be empty")
if not system_prompt:
system_prompt = "You are an AI that describes images accurately and concisely."

# Rate limit check
self.rate_limit_check()

# Map model names
model_mapping = {
"Paid (Llama-3.2-11B-Vision)": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
"Free (Llama-Vision-Free)": "meta-llama/Llama-Vision-Free"
}
actual_model = model_mapping[model_name]
# Get API key and initialize client if needed

# Initialize API client
api_key = self.get_api_key(api_key)
if self.client is None:
logger.info("Initializing Together client")
self.client = Together(api_key=api_key)

# Convert image to base64
logger.info("Converting image to base64")
base64_image = self.encode_image(image)

# Create the messages array with system and user prompts
messages = [
{
"role": "system",
"content": system_prompt
},
{
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}"
# Prepare messages
messages = [{"role": "system", "content": system_prompt}]

# Handle image if provided
if image is not None:
try:
base64_image = self.encode_image(image)
messages.append({
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_image}"}
}
}
]
}
]
]
})
except Exception as img_error:
logger.error(f"Image processing failed: {str(img_error)}")
return (f"Error processing image: {str(img_error)}",)
else:
messages.append({"role": "user", "content": user_prompt})

try:
# Call the Together API with all parameters
logger.info(f"Calling Together API with model: {actual_model}")
# API call with timeout handling
response = self.client.chat.completions.create(
model=actual_model,
messages=messages,
Expand All @@ -197,18 +207,25 @@ def process_image(self, image, model_name, api_key, system_prompt, user_prompt,
stream=True
)

# Process the streamed response
# Process streamed response with timeout
description = ""
logger.info("Processing streamed response")
start_time = time.time()
timeout = 30 # 30 seconds timeout

for chunk in response:
logger.debug(f"Received chunk: {chunk}")
if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
content = chunk.choices[0].delta.content
if content is not None:
description += content

logger.info("Finished processing response")
if time.time() - start_time > timeout:
raise TimeoutError("Response generation timed out")

if not hasattr(chunk, 'choices') or not chunk.choices:
continue

delta = chunk.choices[0].delta
if hasattr(delta, 'content') and delta.content:
description += delta.content

if not description:
raise ValueError("No response generated")

return (description,)

except Exception as api_error:
Expand All @@ -217,12 +234,11 @@ def process_image(self, image, model_name, api_key, system_prompt, user_prompt,
wait_time = "1 hour"
model_type = "free" if "free" in model_name.lower() else "paid"

# Try to extract the wait time from the error message
time_match = re.search(r"try again in (?:about )?([^:]+)", error_msg)
if time_match:
wait_time = time_match.group(1)

error_message = f"""⚠️ Rate Limit Exceeded
return (f"""⚠️ Rate Limit Exceeded
The {model_name} has reached its rate limit.
Please try again in {wait_time}.
Expand All @@ -235,16 +251,14 @@ def process_image(self, image, model_name, api_key, system_prompt, user_prompt,
Rate Limits:
• Free Model: ~100 requests/day, 20-30 requests/hour
• Paid Model: Based on subscription tier"""

logger.warning(f"Rate limit exceeded for {model_type} model: {error_msg}")
return (error_message,)
• Paid Model: Based on subscription tier""",)
else:
raise api_error

except Exception as e:
logger.error(f"Error in process_image: {str(e)}", exc_info=True)
return (f"Error: {str(e)}",)
error_msg = str(e)
logger.error(f"Process error: {error_msg}")
return (f"Error: {error_msg}",)

# Node registration
NODE_CLASS_MAPPINGS = {
Expand Down

0 comments on commit 558a17a

Please sign in to comment.