Skip to content

Commit baf5e91

Browse files
authored
Merge pull request #11 from rossaai/main
[Fix] Clean the variable names to avoid conflict
2 parents 48248e0 + c0830e5 commit baf5e91

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

comfyui_to_python.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import random
88
import sys
9+
import re
910
from typing import Dict, List, Any, Callable, Tuple
1011

1112
import black
@@ -217,7 +218,7 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
217218
continue
218219

219220
class_type, import_statement, class_code = self.get_class_info(class_type)
220-
initialized_objects[class_type] = class_type.lower().strip()
221+
initialized_objects[class_type] = self.clean_variable_name(class_type)
221222
if class_type in self.base_node_class_mappings.keys():
222223
import_statements.add(import_statement)
223224
if class_type not in self.base_node_class_mappings.keys():
@@ -234,9 +235,9 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
234235
inputs['unique_id'] = random.randint(1, 2**64)
235236

236237
# Create executed variable and generate code
237-
executed_variables[idx] = f'{class_type.lower().strip()}_{idx}'
238+
executed_variables[idx] = f'{self.clean_variable_name(class_type)}_{idx}'
238239
inputs = self.update_inputs(inputs, executed_variables)
239-
240+
240241
if is_special_function:
241242
special_functions_code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, **inputs))
242243
else:
@@ -329,6 +330,21 @@ def assemble_python_code(self, import_statements: set, speical_functions_code: L
329330

330331
return final_code
331332

333+
def clean_variable_name(self, class_type: str) -> str:
334+
clean_name = class_type.lower().strip()
335+
336+
# Convert to lowercase and replace spaces with underscores
337+
clean_name = clean_name.lower().replace("-", "_").replace(" ", "_")
338+
339+
# Remove characters that are not letters, numbers, or underscores
340+
clean_name = re.sub(r'[^a-z0-9_]', '', clean_name)
341+
342+
# Ensure that it doesn't start with a number
343+
if clean_name[0].isdigit():
344+
clean_name = "_" + clean_name
345+
346+
return clean_name
347+
332348
def get_class_info(self, class_type: str) -> Tuple[str, str, str]:
333349
"""Generates and returns necessary information about class type.
334350
@@ -339,10 +355,11 @@ def get_class_info(self, class_type: str) -> Tuple[str, str, str]:
339355
Tuple[str, str, str]: Updated class type, import statement string, class initialization code.
340356
"""
341357
import_statement = class_type
358+
variable_name = self.clean_variable_name(class_type)
342359
if class_type in self.base_node_class_mappings.keys():
343-
class_code = f'{class_type.lower().strip()} = {class_type.strip()}()'
360+
class_code = f'{variable_name} = {class_type.strip()}()'
344361
else:
345-
class_code = f'{class_type.lower().strip()} = NODE_CLASS_MAPPINGS["{class_type}"]()'
362+
class_code = f'{variable_name} = NODE_CLASS_MAPPINGS["{class_type}"]()'
346363

347364
return class_type, import_statement, class_code
348365

0 commit comments

Comments
 (0)