diff --git a/qwen-batch-single-pass-v2.py b/qwen-batch-single-pass-v2.py new file mode 100644 index 0000000..2b26be9 --- /dev/null +++ b/qwen-batch-single-pass-v2.py @@ -0,0 +1,90 @@ +import os +import re +import torch +import time +import argparse +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +# Argument Parsing +parser = argparse.ArgumentParser(description='Image Captioning Script') +parser.add_argument('--imgdir', type=str, default='path/to/img/dir', help='Directory containing images') +parser.add_argument('--exist', type=str, default='replace', choices=['skip', 'add', 'replace'], help='Handling of existing captions') +parser.add_argument('--prompt', type=str, default='describe this image in detail, in less than 35 words', help='Prompt to use for image captioning') +parser.add_argument('--sub', type=lambda x: (str(x).lower() == 'true'), default=False, help='Search for images in subdirectories') +args = parser.parse_args() + +# Function to check for unwanted elements in the caption +def has_unwanted_elements(caption): + patterns = [r'.*?', r'.*?'] + return any(re.search(pattern, caption) for pattern in patterns) + +# Function to clean up the caption +def clean_caption(caption): + caption = re.sub(r'(.*?)', r'\1', caption) + caption = re.sub(r'.*?', '', caption) + return caption.strip() + +# Supported image types +image_types = ['.png', '.jpg', '.jpeg', '.bmp', '.gif'] + +# Initialize the model and tokenizer +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat-Int4", device_map="cuda", trust_remote_code=True, use_flash_attn=True).eval() + +# Function to get files recursively from a directory +def get_files_from_directory(directory, image_types, search_subdirectories=False): + if search_subdirectories: + return [os.path.join(dp, f) for dp, dn, filenames in os.walk(directory) for f in filenames if os.path.splitext(f)[1].lower() in image_types] + else: + return [f for f in os.listdir(directory) if os.path.splitext(f)[1].lower() in image_types] + +# Get the list of image files in the specified directory, possibly including subdirectories +files = get_files_from_directory(args.imgdir, image_types, args.sub) + +# Initialize the progress bar +pbar = tqdm(total=len(files), desc="Captioning", dynamic_ncols=True, position=0, leave=True) +start_time = time.time() + +print("Captioning phase:") +for i in range(len(files)): + filename = files[i] + image_path = os.path.join(args.imgdir, filename) + + # Handle based on the argument 'exist' + txt_filename = os.path.splitext(filename)[0] + '.txt' + txt_path = os.path.join(args.imgdir, txt_filename) + + if args.exist == 'skip' and os.path.exists(txt_path): + pbar.update(1) + continue + elif args.exist == 'add' and os.path.exists(txt_path): + with open(txt_path, 'r', encoding='utf-8') as f: + existing_content = f.read() + + # Generate the caption using the model + query = tokenizer.from_list_format([ + {'image': image_path}, + {'text': args.prompt}, + ]) + response, _ = model.chat(tokenizer, query=query, history=None) + + # Clean up the caption if necessary + if has_unwanted_elements(response): + response = clean_caption(response) + + # Write the caption to the corresponding .txt file + with open(txt_path, 'w', encoding='utf-8') as f: + if args.exist == 'add' and os.path.exists(txt_path): + f.write(existing_content + " " + response) + else: + f.write(response) + + # Update progress bar with some additional information about the process + elapsed_time = time.time() - start_time + images_per_sec = (i + 1) / elapsed_time + estimated_time_remaining = (len(files) - i - 1) / images_per_sec + pbar.set_postfix({"Time Elapsed": f"{elapsed_time:.2f}s", "ETA": f"{estimated_time_remaining:.2f}s", "Speed": f"{images_per_sec:.2f} img/s"}) + pbar.update(1) + +pbar.close() \ No newline at end of file diff --git a/qwen-batch-single-pass.py b/qwen-batch-single-pass.py new file mode 100644 index 0000000..69f5890 --- /dev/null +++ b/qwen-batch-single-pass.py @@ -0,0 +1,82 @@ +import os +import re +import torch +import time +import argparse +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +# Argument Parsing +parser = argparse.ArgumentParser(description='Image Captioning Script') +parser.add_argument('--imgdir', type=str, default='path/to/img/dir', help='Directory containing images') +parser.add_argument('--exist', type=str, default='replace', choices=['skip', 'add', 'replace'], help='Handling of existing captions') +parser.add_argument('--prompt', type=str, default='describe this image in detail, in less than 35 words', help='Prompt to use for image captioning') +args = parser.parse_args() + +# Function to check for unwanted elements in the caption +def has_unwanted_elements(caption): + patterns = [r'.*?', r'.*?'] + return any(re.search(pattern, caption) for pattern in patterns) + +# Function to clean up the caption +def clean_caption(caption): + caption = re.sub(r'(.*?)', r'\1', caption) + caption = re.sub(r'.*?', '', caption) + return caption.strip() + +# Supported image types +image_types = ['.png', '.jpg', '.jpeg', '.bmp', '.gif'] + +# Initialize the model and tokenizer +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat-Int4", device_map="cuda", trust_remote_code=True, use_flash_attn=True).eval() + +# Get the list of image files in the specified directory +files = [f for f in os.listdir(args.imgdir) if os.path.splitext(f)[1].lower() in image_types] + +# Initialize the progress bar +pbar = tqdm(total=len(files), desc="Captioning", dynamic_ncols=True, position=0, leave=True) +start_time = time.time() + +print("Captioning phase:") +for i in range(len(files)): + filename = files[i] + image_path = os.path.join(args.imgdir, filename) + + # Handle based on the argument 'exist' + txt_filename = os.path.splitext(filename)[0] + '.txt' + txt_path = os.path.join(args.imgdir, txt_filename) + + if args.exist == 'skip' and os.path.exists(txt_path): + pbar.update(1) + continue + elif args.exist == 'add' and os.path.exists(txt_path): + with open(txt_path, 'r', encoding='utf-8') as f: + existing_content = f.read() + + # Generate the caption using the model + query = tokenizer.from_list_format([ + {'image': image_path}, + {'text': args.prompt}, + ]) + response, _ = model.chat(tokenizer, query=query, history=None) + + # Clean up the caption if necessary + if has_unwanted_elements(response): + response = clean_caption(response) + + # Write the caption to the corresponding .txt file + with open(txt_path, 'w', encoding='utf-8') as f: + if args.exist == 'add' and os.path.exists(txt_path): + f.write(existing_content + " " + response) + else: + f.write(response) + + # Update progress bar with some additional information about the process + elapsed_time = time.time() - start_time + images_per_sec = (i + 1) / elapsed_time + estimated_time_remaining = (len(files) - i - 1) / images_per_sec + pbar.set_postfix({"Time Elapsed": f"{elapsed_time:.2f}s", "ETA": f"{estimated_time_remaining:.2f}s", "Speed": f"{images_per_sec:.2f} img/s"}) + pbar.update(1) + +pbar.close() \ No newline at end of file