@classmethod
def IS_CHANGED(cls, **kwargs):
    """Return a change token for the node's current inputs.

    Keyword Args:
        mode: Loader mode. Anything other than 'single_image' (including a
            missing key) yields NaN.
        path, label, pattern: Forwarded to
            ``WAS_Load_Image_Batch.BatchImageLoader`` to resolve the file list.
        index: Position in the sorted file list of the image to fingerprint.

    Returns:
        The SHA-256 of the selected file (as produced by ``get_sha256``) in
        'single_image' mode, or ``float("NaN")`` whenever the result cannot be
        determined. NaN never compares equal to a previous value; presumably
        that is how the host is told to always re-run the node -- TODO confirm
        against the host's IS_CHANGED contract.

    NOTE(review): constructing ``BatchImageLoader`` is not side-effect free.
    When it has to rescan (cache miss, or the label was last used with a
    different path/pattern) it rewrites the label's stored counter and cache
    entries. That is harmless for 'single_image', which ignores the counter,
    but this method must not be described as read-only.
    """
    unknown = float("NaN")

    # `.get` instead of indexing: a node that does not pass `mode` at all
    # must fall back to "always changed" rather than raise KeyError.
    if kwargs.get('mode') != 'single_image':
        return unknown

    try:
        loader = WAS_Load_Image_Batch.BatchImageLoader(kwargs['path'], kwargs['label'], kwargs['pattern'])
        paths = loader.image_paths

        if not paths:
            cstr(f"IS_CHANGED: No images found for path '{kwargs['path']}' pattern '{kwargs['pattern']}'.").warning.print()
            return unknown

        target_index = kwargs['index']
        if target_index < 0 or target_index >= len(paths):
            cstr(f"IS_CHANGED: Invalid index {target_index} for single_image mode. List length: {len(paths)}.").warning.print()
            return unknown

        abs_image_path = os.path.abspath(paths[target_index])

        # Cached lists can name files that were deleted after the scan.
        if not os.path.exists(abs_image_path):
            cstr(f"IS_CHANGED: File not found at expected path '{abs_image_path}'. Assuming changed.").warning.print()
            return unknown

        return get_sha256(abs_image_path)

    except Exception as e:
        # Top-level boundary for the host callback: report and treat any
        # failure (missing kwargs, DB error, unreadable file) as "changed".
        cstr(f"Error in IS_CHANGED for single_image mode: {e}").error.print()
        return unknown


# LOAD IMAGE BATCH

# Module-level, process-lifetime cache of scanned file lists, keyed by
# "<directory>:::<pattern>". Read and written by BatchImageLoader.__init__.
_batch_image_paths_memory_cache = {}
mode="single_image", see if image == None: cstr(f"No valid image was found for the next ID. Did you remove images from the source directory?").error.print() return (None, None) - else: + else: # random mode random.seed(seed) + # Ensure index is within the bounds of the new list newindex = int(random.random() * len(fl.image_paths)) image, filename = fl.get_image_by_id(newindex) if image == None: @@ -5320,8 +5376,13 @@ def load_batch_images(self, path, pattern='*', index=0, mode="single_image", see return (None, None) - # Update history - update_history_images(new_paths) + # Update history - Check if update_history_images processes all new_paths, if so it might be a bottleneck + # If update_history_images processes all paths, consider only passing the filename actually loaded to it + # I've commented this out in the previous attempt as a potential bottleneck. + # Re-evaluate if this line is needed and if it's the cause of slowness. + # If you uncomment this, ensure update_history_images is efficient with large lists. 
class BatchImageLoader:
    """Resolve the image files of a directory/pattern pair and step through them.

    The sorted list of matching files is cached twice: in the module-level
    ``_batch_image_paths_memory_cache`` dict and, newline-joined, in the WDB
    category 'Batch Images Paths Cache Text'. Both use the key
    ``"<directory>:::<pattern>"``. Per-label state lives in the WDB categories
    'Batch Counters' (incremental position) and 'Batch Label Info' (the
    directory/pattern the label was last used with).

    NOTE(review): a cached list is refreshed only by a rescan, and a rescan
    happens only on a cache miss or when the label's stored directory/pattern
    differ from the arguments. Files added or removed in between are not
    noticed.
    """

    def __init__(self, directory_path, label, pattern):
        """Populate ``image_paths`` and ``index`` for ``label``.

        Args:
            directory_path: Directory to search.
            label: Name under which the counter and label info are stored.
            pattern: Glob pattern joined onto ``directory_path``.

        Cached paths are used only when the label was last used with the same
        directory and pattern; otherwise the directory is rescanned and the
        label's counter is reset to 0.
        """
        self.WDB = WDB
        self.image_paths = []
        self.index = 0
        self.label = label
        self.directory_path = directory_path
        self.pattern = pattern

        cache_key = f"{directory_path}:::{pattern}"

        # Read the label record once; require a real dict so a malformed
        # record is treated as "parameters changed" instead of crashing.
        label_info = self.WDB.get('Batch Label Info', label)
        label_matches = (
            isinstance(label_info, dict)
            and label_info.get('path') == directory_path
            and label_info.get('pattern') == pattern
        )

        if label_matches and self._load_cached_paths(cache_key):
            self.index = self._coerce_index(self.WDB.get('Batch Counters', label))
            return

        if not label_matches and cache_key in _batch_image_paths_memory_cache:
            cstr(f"Memory cache hit for '{label}', but label parameters mismatch or new. Proceeding to DB/rescan.").msg.print()

        self._rescan(cache_key)

    @staticmethod
    def _coerce_index(value):
        """Return ``value`` as an int, or 0 when it is missing or not numeric."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    def _load_cached_paths(self, cache_key):
        """Fill ``image_paths`` from the memory or DB cache; return True on a hit.

        An empty cached list is treated as a miss so that a directory which
        was empty at first scan is looked at again on the next call.
        """
        cached = _batch_image_paths_memory_cache.get(cache_key)
        if cached:
            # Copy: image_paths is a public attribute and must not alias the
            # list shared by every loader using this key.
            self.image_paths = list(cached)
            cstr(f"Loaded image paths for '{self.label}' from memory cache.").msg.print()
            return True

        cached_text = self.WDB.get('Batch Images Paths Cache Text', cache_key)
        if isinstance(cached_text, str):
            paths = [p for p in cached_text.strip().split('\n') if p]
            if paths:
                self.image_paths = paths
                _batch_image_paths_memory_cache[cache_key] = list(paths)
                cstr(f"Loaded image paths for '{self.label}' from database text cache.").msg.print()
                return True

        return False

    def _rescan(self, cache_key):
        """Scan the directory, reset the label's counter, and refresh both caches."""
        cstr(f"Cache miss or parameters changed for '{self.label}', rescanning directory.").msg.print()
        self.load_images(self.directory_path, self.pattern)
        self.image_paths.sort()
        self.index = 0

        _batch_image_paths_memory_cache[cache_key] = list(self.image_paths)
        # Stored as one newline-joined text block.
        # NOTE(review): a path containing '\n' would not round-trip.
        self.WDB.insert('Batch Images Paths Cache Text', cache_key, "\n".join(self.image_paths))
        self.WDB.insert('Batch Counters', self.label, 0)
        self.WDB.insert('Batch Label Info', self.label, {'path': self.directory_path, 'pattern': self.pattern})

    def load_images(self, directory_path, pattern):
        """Replace ``image_paths`` with the absolute paths of matching image files.

        The directory part is glob-escaped, ``pattern`` is not. Only names
        whose lower-cased form ends with one of ``ALLOWED_EXT`` are kept. The
        result is unsorted; ``__init__`` sorts it afterwards.
        """
        search = os.path.join(glob.escape(directory_path), pattern)
        self.image_paths = [
            os.path.abspath(file_name)
            for file_name in glob.glob(search, recursive=True)
            if file_name.lower().endswith(ALLOWED_EXT)
        ]

    def get_image_by_id(self, image_id):
        """Open the image at position ``image_id``.

        Returns:
            ``(image, basename)`` on success, ``(None, None)`` when the list
            is empty, the index is out of range, or the file cannot be opened.
        """
        if not self.image_paths or image_id < 0 or image_id >= len(self.image_paths):
            cstr(f"Invalid image index `{image_id}` or image path list is empty").error.print()
            return (None, None)

        image_path = self.image_paths[image_id]
        try:
            i = Image.open(image_path)
            i = ImageOps.exif_transpose(i)
        except Exception as e:
            cstr(f"Error opening image '{image_path}': {e}").error.print()
            return (None, None)
        return (i, os.path.basename(image_path))

    def get_next_image(self):
        """Open the image at the stored counter and advance the counter.

        The advanced counter (wrapping to 0 at the end of the list) is
        written to 'Batch Counters' before the file is opened, whether or not
        the open succeeds. The caller builds a new loader on every call and
        ``__init__`` re-reads the counter from WDB, so an advance kept only in
        memory would be lost and an unreadable file would be retried forever.

        Returns:
            ``(image, basename)`` on success, ``(None, None)`` when the list
            is empty or the file cannot be opened.
        """
        if not self.image_paths:
            cstr("Image path list is empty. Cannot get next image.").error.print()
            return (None, None)

        total = len(self.image_paths)
        current = self._coerce_index(self.index)
        if current < 0 or current >= total:
            current = 0
            cstr(f'{cstr.color.YELLOW}{self.label}{cstr.color.END} Index reset to 0').msg.print()

        image_path = self.image_paths[current]

        self.index = (current + 1) % total
        self.WDB.insert('Batch Counters', self.label, self.index)

        try:
            i = Image.open(image_path)
            i = ImageOps.exif_transpose(i)
        except Exception as e:
            cstr(f"Error opening image '{image_path}' for index {current}: {e}. Skipping this image.").error.print()
            cstr(f'{cstr.color.YELLOW}{self.label}{cstr.color.END} Index incremented to {self.index} after error.').msg.print()
            return (None, None)

        cstr(f'{cstr.color.YELLOW}{self.label}{cstr.color.END} Index: {self.index}').msg.print()
        return (i, os.path.basename(image_path))

    def get_current_image(self):
        """Return the basename of the file at the stored counter, without advancing.

        A missing or non-numeric counter is read as 0. When the list is empty
        or the counter is out of range, a descriptive placeholder string is
        returned instead of a filename.
        """
        current = self._coerce_index(self.index)
        if not self.image_paths or current < 0 or current >= len(self.image_paths):
            cstr(f"Warning: Image path list is empty or index invalid ({self.index}) in get_current_image.").warning.print()
            return f"NO_IMAGE_OR_INVALID_INDEX_{self.index}_LIST_LEN_{len(self.image_paths) if self.image_paths else 0}"

        return os.path.basename(self.image_paths[current])