6
6
import os
7
7
import random
8
8
import sys
9
+ import re
9
10
from typing import Dict , List , Any , Callable , Tuple
10
11
11
12
import black
@@ -217,7 +218,7 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
217
218
continue
218
219
219
220
class_type , import_statement , class_code = self .get_class_info (class_type )
220
- initialized_objects [class_type ] = class_type . lower (). strip ( )
221
+ initialized_objects [class_type ] = self . clean_variable_name ( class_type )
221
222
if class_type in self .base_node_class_mappings .keys ():
222
223
import_statements .add (import_statement )
223
224
if class_type not in self .base_node_class_mappings .keys ():
@@ -234,9 +235,9 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
234
235
inputs ['unique_id' ] = random .randint (1 , 2 ** 64 )
235
236
236
237
# Create executed variable and generate code
237
- executed_variables [idx ] = f'{ class_type . lower (). strip ( )} _{ idx } '
238
+ executed_variables [idx ] = f'{ self . clean_variable_name ( class_type )} _{ idx } '
238
239
inputs = self .update_inputs (inputs , executed_variables )
239
-
240
+
240
241
if is_special_function :
241
242
special_functions_code .append (self .create_function_call_code (initialized_objects [class_type ], class_def .FUNCTION , executed_variables [idx ], is_special_function , ** inputs ))
242
243
else :
@@ -329,6 +330,21 @@ def assemble_python_code(self, import_statements: set, speical_functions_code: L
329
330
330
331
return final_code
331
332
333
+ def clean_variable_name (self , class_type : str ) -> str :
334
+ clean_name = class_type .lower ().strip ()
335
+
336
+ # Convert to lowercase and replace spaces with underscores
337
+ clean_name = clean_name .lower ().replace ("-" , "_" ).replace (" " , "_" )
338
+
339
+ # Remove characters that are not letters, numbers, or underscores
340
+ clean_name = re .sub (r'[^a-z0-9_]' , '' , clean_name )
341
+
342
+ # Ensure that it doesn't start with a number
343
+ if clean_name [0 ].isdigit ():
344
+ clean_name = "_" + clean_name
345
+
346
+ return clean_name
347
+
332
348
def get_class_info (self , class_type : str ) -> Tuple [str , str , str ]:
333
349
"""Generates and returns necessary information about class type.
334
350
@@ -339,10 +355,11 @@ def get_class_info(self, class_type: str) -> Tuple[str, str, str]:
339
355
Tuple[str, str, str]: Updated class type, import statement string, class initialization code.
340
356
"""
341
357
import_statement = class_type
358
+ variable_name = self .clean_variable_name (class_type )
342
359
if class_type in self .base_node_class_mappings .keys ():
343
- class_code = f'{ class_type . lower (). strip () } = { class_type .strip ()} ()'
360
+ class_code = f'{ variable_name } = { class_type .strip ()} ()'
344
361
else :
345
- class_code = f'{ class_type . lower (). strip () } = NODE_CLASS_MAPPINGS["{ class_type } "]()'
362
+ class_code = f'{ variable_name } = NODE_CLASS_MAPPINGS["{ class_type } "]()'
346
363
347
364
return class_type , import_statement , class_code
348
365
0 commit comments