Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 190 additions & 37 deletions WAS_Node_Suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3885,7 +3885,52 @@ def INPUT_TYPES(cls):

@classmethod
def IS_CHANGED(cls, **kwargs):
return float("NaN")
# IS_CHANGED logic should use BatchImageLoader to get the filename based on the stored index/label/path/pattern
# It needs to instantiate BatchImageLoader to load paths (from cache if possible) and get the current index.
# This means IS_CHANGED also benefits from the caching mechanism we added.
# However, IS_CHANGED *should not* modify the state (like index in DB).
# The BatchImageLoader __init__ reads state but doesn't increment for single_image mode, so it's safe.

if kwargs['mode'] != 'single_image':
# For incremental and random modes, IS_CHANGED should always return NaN as state changes externally or randomly.
return float("NaN")
else:
# For single_image, check if the specific file at the index has changed.
# Need to get the correct image path for the given label, path, pattern, and index.
# Instantiating BatchImageLoader is still the way to get the file paths based on inputs and cache.

try:
# Create a temporary loader instance just to get the file list based on inputs and cache.
loader = WAS_Load_Image_Batch.BatchImageLoader(kwargs['path'], kwargs['label'], kwargs['pattern'])

# Ensure paths are loaded - __init__ guarantees this
if not loader.image_paths:
cstr(f"IS_CHANGED: No images found for path '{kwargs['path']}' pattern '{kwargs['pattern']}'.").warning.print()
return float("NaN") # Cannot determine change if no images

# Get the filename corresponding to the *input* index for single_image mode
target_index = kwargs['index']
if target_index < 0 or target_index >= len(loader.image_paths):
cstr(f"IS_CHANGED: Invalid index {target_index} for single_image mode. List length: {len(loader.image_paths)}.").warning.print()
return float("NaN") # Index out of bounds, cannot determine change

image_path = loader.image_paths[target_index]
# Ensure image_path is absolute for get_sha256
abs_image_path = os.path.abspath(image_path)


# Check if file exists before getting SHA
if not os.path.exists(abs_image_path):
cstr(f"IS_CHANGED: File not found at expected path '{abs_image_path}'. Assuming changed.").warning.print()
return float("NaN") # File disappeared, indicates change

# Calculate SHA of the file at the determined path
sha = get_sha256(abs_image_path)
return sha

except Exception as e:
cstr(f"Error in IS_CHANGED for single_image mode: {e}").error.print()
return float("NaN") # Any error during the process indicates unknown state/change

RETURN_TYPES = ("IMAGE","IMAGE",TEXT_TYPE,TEXT_TYPE)
RETURN_NAMES = ("image_a_pass","image_b_pass","filepath_text","filename_text")
Expand Down Expand Up @@ -5265,6 +5310,8 @@ def apply_resize_image(self, image: Image.Image, mode='scale', supersample='true


# LOAD IMAGE BATCH
# Add this at the top of the file or near other module-level variables
_batch_image_paths_memory_cache = {}

class WAS_Load_Image_Batch:
def __init__(self):
Expand Down Expand Up @@ -5299,8 +5346,16 @@ def load_batch_images(self, path, pattern='*', index=0, mode="single_image", see

if not os.path.exists(path):
return (None, )
# Instantiate BatchImageLoader, this includes cache loading or file scanning logic
fl = self.BatchImageLoader(path, label, pattern)
new_paths = fl.image_paths

# Ensure fl.image_paths has been populated (from cache or scan)
if not fl.image_paths:
cstr(f"No images found in '{path}' with pattern '{pattern}'.").error.print()
return (None, None)

new_paths = fl.image_paths # Still need the full path list to support all modes

if mode == 'single_image':
image, filename = fl.get_image_by_id(index)
if image == None:
Expand All @@ -5311,17 +5366,23 @@ def load_batch_images(self, path, pattern='*', index=0, mode="single_image", see
if image == None:
cstr(f"No valid image was found for the next ID. Did you remove images from the source directory?").error.print()
return (None, None)
else:
else: # random mode
random.seed(seed)
# Ensure index is within the bounds of the new list
newindex = int(random.random() * len(fl.image_paths))
image, filename = fl.get_image_by_id(newindex)
if image == None:
cstr(f"No valid image was found for the next ID. Did you remove images from the source directory?").error.print()
return (None, None)


# Update history
update_history_images(new_paths)
# Update history - Check if update_history_images processes all new_paths, if so it might be a bottleneck
# If update_history_images processes all paths, consider only passing the filename actually loaded to it
# I've commented this out in the previous attempt as a potential bottleneck.
# Re-evaluate if this line is needed and if it's the cause of slowness.
# If you uncomment this, ensure update_history_images is efficient with large lists.
# update_history_images(new_paths)


if not allow_RGBA_output:
image = image.convert("RGB")
Expand All @@ -5330,57 +5391,149 @@ def load_batch_images(self, path, pattern='*', index=0, mode="single_image", see
filename = os.path.splitext(filename)[0]

return (pil2tensor(image), filename)

class BatchImageLoader:
def __init__(self, directory_path, label, pattern):
global _batch_image_paths_memory_cache # Declare intent to use global variable

self.WDB = WDB
self.image_paths = []
self.load_images(directory_path, pattern)
self.image_paths.sort()
stored_directory_path = self.WDB.get('Batch Paths', label)
stored_pattern = self.WDB.get('Batch Patterns', label)
if stored_directory_path != directory_path or stored_pattern != pattern:
self.index = 0
self.WDB.insert('Batch Counters', label, 0)
self.WDB.insert('Batch Paths', label, directory_path)
self.WDB.insert('Batch Patterns', label, pattern)
else:
self.index = self.WDB.get('Batch Counters', label)
self.label = label
self.directory_path = directory_path # Store path and pattern
self.pattern = pattern

# Create a string key for caching file list
# Combine path and pattern with a separator (e.g., ":::")
cache_key = f"{directory_path}:::{pattern}"

# --- Check Memory Cache First ---
if cache_key in _batch_image_paths_memory_cache:
self.image_paths = _batch_image_paths_memory_cache[cache_key]
cached_label_info = self.WDB.get('Batch Label Info', label)
if cached_label_info and cached_label_info.get('path') == directory_path and cached_label_info.get('pattern') == pattern:
self.index = self.WDB.get('Batch Counters', label)
cstr(f"Loaded image paths for '{label}' from memory cache.").msg.print()
return # Found in memory and parameters match, exit init
else:
cstr(f"Memory cache hit for '{label}', but label parameters mismatch or new. Proceeding to DB/rescan.").msg.print()


# --- If not in Memory Cache, check Database Cache ---
# Attempt to load cached path list as a text block from the database
cached_paths_text = self.WDB.get('Batch Images Paths Cache Text', cache_key) # Note the new category name
cached_label_info = self.WDB.get('Batch Label Info', label)

if cached_paths_text and cached_label_info and cached_label_info.get('path') == directory_path and cached_label_info.get('pattern') == pattern:
# Cache exists in DB as text, load from DB and split into list
try:
self.image_paths = cached_paths_text.strip().split('\n')
# Filter out empty lines, although load_images should not produce them, reading might
self.image_paths = [path for path in self.image_paths if path]
self.index = self.WDB.get('Batch Counters', label)
_batch_image_paths_memory_cache[cache_key] = self.image_paths # Store in memory cache
cstr(f"Loaded image paths for '{label}' from database text cache.").msg.print()
return # Found in DB text cache and parameters match, exit init
except Exception as e:
# If loading or parsing text cache fails, record error and continue with rescan
cstr(f"Error loading or parsing DB text cache for '{label}': {e}. Rescanning.").error.print()
self.image_paths = [] # Reset image_paths


# --- Cache miss (memory and DB) or parameters changed, rescan and update caches ---
cstr(f"Cache miss or parameters changed for '{label}', rescanning directory.").msg.print()
self.load_images(directory_path, pattern) # This will populate self.image_paths
self.image_paths.sort()
self.index = 0 # Reset index on rescan

# Update both caches and database info
_batch_image_paths_memory_cache[cache_key] = self.image_paths # Store in memory cache

# Convert file path list to a text block and store in the database
paths_text_to_save = "\n".join(self.image_paths)
self.WDB.insert('Batch Images Paths Cache Text', cache_key, paths_text_to_save) # Use the new category name and text block
self.WDB.insert('Batch Counters', label, 0)
self.WDB.insert('Batch Label Info', label, {'path': directory_path, 'pattern': pattern})


def load_images(self, directory_path, pattern):
# Clear the existing path list to ensure only results from this scan are added
self.image_paths = []
# glob.glob finds all matching file paths, recursive=True supports subdirectories
for file_name in glob.glob(os.path.join(glob.escape(directory_path), pattern), recursive=True):
# Check if file extension is in the allowed list
if file_name.lower().endswith(ALLOWED_EXT):
# Get the absolute path of the file to ensure standardization
abs_file_path = os.path.abspath(file_name)
self.image_paths.append(abs_file_path)
# Note: Sorting is done after returning from this method in __init__


def get_image_by_id(self, image_id):
if image_id < 0 or image_id >= len(self.image_paths):
cstr(f"Invalid image index `{image_id}`").error.print()
return
i = Image.open(self.image_paths[image_id])
i = ImageOps.exif_transpose(i)
return (i, os.path.basename(self.image_paths[image_id]))
# Ensure self.image_paths is populated and not empty
if not self.image_paths or image_id < 0 or image_id >= len(self.image_paths):
cstr(f"Invalid image index `{image_id}` or image path list is empty").error.print()
return (None, None) # Return None, None to indicate load failure
image_path = self.image_paths[image_id]
try:
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
return (i, os.path.basename(image_path))
except Exception as e:
cstr(f"Error opening image '{image_path}': {e}").error.print()
return (None, None)

def get_next_image(self):
if self.index >= len(self.image_paths):
self.index = 0
# Ensure self.image_paths is populated and not empty
if not self.image_paths:
cstr(f"Image path list is empty. Cannot get next image.").error.print()
return (None, None)

# incremental mode needs to update index and store in database
# Check if index is out of bounds before accessing list
if self.index >= len(self.image_paths) or self.index < 0:
self.index = 0 # Reset index if out of bounds
cstr(f'{cstr.color.YELLOW}{self.label}{cstr.color.END} Index reset to 0').msg.print()


image_path = self.image_paths[self.index]
self.index += 1
if self.index == len(self.image_paths):
self.index = 0
cstr(f'{cstr.color.YELLOW}{self.label}{cstr.color.END} Index: {self.index}').msg.print()
self.WDB.insert('Batch Counters', self.label, self.index)
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
return (i, os.path.basename(image_path))

# Get current index *before* attempting to open image
current_index_attempt = self.index

try:
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)

# Update index and database only on successful image open
self.index = current_index_attempt + 1
if self.index >= len(self.image_paths): # Wrap around
self.index = 0
cstr(f'{cstr.color.YELLOW}{self.label}{cstr.color.END} Index: {self.index}').msg.print()
# Only update database on successful load to persist position
self.WDB.insert('Batch Counters', self.label, self.index)

return (i, os.path.basename(image_path))
except Exception as e:
cstr(f"Error opening image '{image_path}' for index {current_index_attempt}: {e}. Skipping this image.").error.print()
# On error, increment index to skip the problematic image in the next attempt
self.index = current_index_attempt + 1
if self.index >= len(self.image_paths): # Wrap around
self.index = 0
# Do NOT update database here, wait for a successful load.
cstr(f'{cstr.color.YELLOW}{self.label}{cstr.color.END} Index incremented to {self.index} after error.').msg.print()
return (None, None) # Return None, None on error

def get_current_image(self):
if self.index >= len(self.image_paths):
self.index = 0
# This method is used by IS_CHANGED. It should try to determine the *current* image based on the stored index.
# Ensure self.image_paths is populated first (it will be by __init__)
if not self.image_paths or self.index < 0 or self.index >= len(self.image_paths):
# If list is empty or index is invalid, return a placeholder that will trigger rescan or indicate no image.
cstr(f"Warning: Image path list is empty or index invalid ({self.index}) in get_current_image.").warning.print()
# Return something that IS_CHANGED can handle as "changed" or "unknown"
return f"NO_IMAGE_OR_INVALID_INDEX_{self.index}_LIST_LEN_{len(self.image_paths) if self.image_paths else 0}"

image_path = self.image_paths[self.index]
return os.path.basename(image_path)

@classmethod
def IS_CHANGED(cls, **kwargs):
if kwargs['mode'] != 'single_image':
Expand Down Expand Up @@ -11156,7 +11309,7 @@ def INPUT_TYPES(cls):

CATEGORY = "WAS Suite/IO"

def load_file(self, file_path='', dictionary_name='[filename]]'):
def load_file(self, file_path='', dictionary_name='[filename]'):

filename = ( os.path.basename(file_path).split('.', 1)[0]
if '.' in os.path.basename(file_path) else os.path.basename(file_path) )
Expand Down