-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtogether_image_node.py
221 lines (182 loc) · 8.13 KB
/
together_image_node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import os
import base64
import io
from PIL import Image
import numpy as np
import torch
from together import Together
import logging
from typing import Optional, Tuple
from dotenv import load_dotenv
# Set up logging: configure the root logger at INFO level (an import-time
# side effect of loading this module) and create the module-level logger
# that the node class below writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load environment variables (presumably from a local .env file, per
# python-dotenv) so that TOGETHER_API_KEY can be picked up via os.getenv()
# when the node is given an empty api_key.
load_dotenv()
class TogetherImageNode:
    """
    A custom node for ComfyUI that uses Together AI's FLUX model for image generation.

    The node takes a text prompt plus size/seed/count settings, requests
    base64-encoded images from the Together API and returns them as a single
    batched float tensor laid out as [batch, height, width, channels].
    """

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate_image"
    CATEGORY = "image"
    OUTPUT_NODE = True

    # Model identifier and step count sent with every request.
    _MODEL = "black-forest-labs/FLUX.1-schnell-Free"
    _STEPS = 4  # Fixed to 4 steps as per model requirements

    def __init__(self):
        # Lazily created API client; built on first use by get_client().
        self.client = None
        # Key the cached client was built with, so a changed key forces a rebuild.
        self._client_api_key = None
        # NOTE(review): these two look intended for request throttling, but no
        # method in this class reads them yet.
        self.last_request_time = 0
        self.min_request_interval = 1.0  # Minimum seconds between requests

    def get_client(self, api_key: str) -> "Optional[Together]":
        """Get or create the Together client.

        Args:
            api_key: Key entered on the node. When empty or whitespace-only,
                the ``TOGETHER_API_KEY`` environment variable is used instead.

        Returns:
            A client instance, or ``None`` when no key is available or the
            client could not be constructed (the reason is logged).
        """
        try:
            # If api_key from node is empty, try to get from the environment/.env
            final_api_key = (api_key or "").strip() or os.getenv('TOGETHER_API_KEY')
            if not final_api_key:
                logger.error(
                    "No API key found. Please provide an API key in the node "
                    "or set TOGETHER_API_KEY in .env file"
                )
                return None

            cached_key = getattr(self, "_client_api_key", None)
            # Reuse the cached client unless we know it was built with a
            # different key. A client with no recorded key (e.g. assigned
            # from outside) is always reused.
            if self.client is not None and cached_key in (None, final_api_key):
                return self.client

            # Pass the key explicitly instead of writing it to os.environ, so
            # it does not leak process-wide or become the sticky fallback.
            new_client = Together(api_key=final_api_key)
            self.client = new_client
            self._client_api_key = final_api_key
            return new_client
        except Exception as e:
            logger.error("Failed to initialize Together client: %s", e)
            return None

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's input sockets/widgets: names, types, defaults and limits."""
        return {
            "required": {
                "prompt": ("STRING", {
                    "default": "",
                    "multiline": True
                }),
                "api_key": ("STRING", {
                    "default": "",
                    "multiline": False
                }),
                "width": ("INT", {
                    "default": 1024,
                    "min": 512,
                    "max": 2048,
                    "step": 256
                }),
                "height": ("INT", {
                    "default": 1024,
                    "min": 512,
                    "max": 2048,
                    "step": 256
                }),
                "seed": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 0xffffffffffffffff
                }),
                "num_images": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 4,
                    "step": 1
                })
            }
        }

    @staticmethod
    def _validate_request(prompt: str, width: int, height: int, num_images: int) -> None:
        """Raise ValueError when the prompt, dimensions or image count are out of range."""
        if not prompt or not prompt.strip():
            raise ValueError("Prompt cannot be empty")
        if width < 512 or width > 2048 or height < 512 or height > 2048:
            raise ValueError(
                f"Invalid image dimensions. Width and height must be between 512 and 2048. Got {width}x{height}"
            )
        if num_images < 1 or num_images > 4:
            raise ValueError(f"Number of images must be between 1 and 4. Got {num_images}")

    @staticmethod
    def _decode_rgb_array(b64_string: str) -> np.ndarray:
        """Decode a base64-encoded image into an RGB array of shape [height, width, 3]."""
        image_data = base64.b64decode(b64_string)
        logger.info("Decoded base64 string length: %d bytes", len(image_data))
        with Image.open(io.BytesIO(image_data)) as image:
            logger.info("Image mode: %s, size: %s", image.mode, image.size)
            # Always normalise to RGB so RGBA / palette / grayscale payloads
            # produce the same 3-channel layout as plain RGB ones.
            image_np = np.array(image.convert('RGB'))
        logger.info("Numpy array shape: %s, dtype: %s", image_np.shape, image_np.dtype)
        return image_np

    @staticmethod
    def _array_to_tensor(image_np: np.ndarray) -> torch.Tensor:
        """Convert an [height, width, 3] array with 0-255 values to a float32 tensor in [0, 1].

        The layout is left unchanged (channels last).

        Raises:
            ValueError: If the array is not 3-dimensional with 3 channels.
        """
        if image_np.ndim != 3 or image_np.shape[2] != 3:
            raise ValueError(f"Unexpected image shape: {image_np.shape}")
        # Ensure uint8 type and correct range before scaling.
        pixels = np.clip(image_np, 0, 255).astype(np.uint8)
        return torch.from_numpy(pixels).float() / 255.0

    def b64_to_image(self, b64_string: str) -> torch.Tensor:
        """Convert base64 string to torch tensor image.

        Args:
            b64_string: Base64-encoded image file contents.

        Returns:
            A float tensor with values in [0, 1] laid out as
            [1, channels, height, width]. Note this is channels-first, unlike
            the channels-last batch returned by generate_image().

        Raises:
            ValueError: If the payload cannot be decoded or has an unexpected shape.
        """
        try:
            channels_last = self._array_to_tensor(self._decode_rgb_array(b64_string))
            # [height, width, channels] -> [1, channels, height, width]
            image_tensor = channels_last.permute(2, 0, 1).unsqueeze(0)
            logger.info("Final tensor shape: %s, dtype: %s", image_tensor.shape, image_tensor.dtype)
            return image_tensor
        except Exception as e:
            logger.error("Error converting base64 to image: %s", e)
            raise ValueError(f"Failed to convert base64 to image: {str(e)}") from e

    def generate_image(self, prompt: str, api_key: str, width: int, height: int,
                       seed: int, num_images: int) -> Tuple[torch.Tensor]:
        """
        Generate images using Together AI's FLUX model.

        Args:
            prompt: Text prompt; must not be empty or whitespace-only.
            api_key: Together API key; empty falls back to ``TOGETHER_API_KEY``.
            width: Image width in pixels, 512-2048.
            height: Image height in pixels, 512-2048.
            seed: Seed forwarded to the API.
            num_images: Number of images to request, 1-4.

        Returns:
            A 1-tuple holding a float tensor in [0, 1] with layout
            [batch, height, width, channels].

        Raises:
            ValueError: On invalid inputs, when no client can be created, when
                the API call fails, or when no returned image can be decoded.
        """
        # Validate before any network or client work so bad inputs fail fast.
        self._validate_request(prompt, width, height, num_images)

        try:
            # Get client with validation and .env fallback
            client = self.get_client(api_key)
            if client is None:
                # Raise instead of returning (None, message): the declared
                # output is a single IMAGE, so a 2-tuple would break callers.
                raise ValueError("Failed to initialize API client. Please check your API key.")

            response = client.images.generate(
                prompt=prompt,
                model=self._MODEL,
                width=width,
                height=height,
                steps=self._STEPS,
                n=num_images,
                seed=seed,
                response_format="b64_json"
            )

            if not response.data:
                raise ValueError("No images generated in response")

            image_tensors = []
            for img_data in response.data:
                # The attribute can exist but be empty, so check the value.
                b64_payload = getattr(img_data, 'b64_json', None)
                if not b64_payload:
                    logger.warning("Skipping image without base64 data")
                    continue
                try:
                    image_np = self._decode_rgb_array(b64_payload)
                    image_tensors.append(self._array_to_tensor(image_np))
                except Exception as img_error:
                    # Best effort: one bad image must not discard the others.
                    logger.error("Failed to process an image: %s", img_error)

            if not image_tensors:
                raise ValueError("Failed to process any images from response")

            # Stack [height, width, channels] images into one batch tensor.
            return (torch.stack(image_tensors),)
        except Exception as e:
            logger.error("Image generation failed: %s", e)
            raise ValueError(f"Failed to generate image: {str(e)}") from e
# Node registration: two module-level tables keyed by the same node name.
# The first maps that name to the implementing class, the second maps it to
# the human-readable title.
NODE_CLASS_MAPPINGS = {"Together Image 🎨": TogetherImageNode}
NODE_DISPLAY_NAME_MAPPINGS = {"Together Image 🎨": "Together Image Generator"}