diff --git a/RenAIssance_Transformer_OCR_Utsav_Rai/code/app/__pycache__/app_streamlit.cpython-39.pyc b/RenAIssance_Transformer_OCR_Utsav_Rai/code/app/__pycache__/app_streamlit.cpython-39.pyc new file mode 100644 index 00000000..d3d8648c Binary files /dev/null and b/RenAIssance_Transformer_OCR_Utsav_Rai/code/app/__pycache__/app_streamlit.cpython-39.pyc differ diff --git a/RenAIssance_Transformer_OCR_Utsav_Rai/code/app/app_streamlit.py b/RenAIssance_Transformer_OCR_Utsav_Rai/code/app/app_streamlit.py index 7770a801..c5958cf3 100644 --- a/RenAIssance_Transformer_OCR_Utsav_Rai/code/app/app_streamlit.py +++ b/RenAIssance_Transformer_OCR_Utsav_Rai/code/app/app_streamlit.py @@ -1,9 +1,11 @@ import sys import os -# Add CRAFT directory to sys.path for craft imports -CRAFT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'CRAFT')) -if CRAFT_DIR not in sys.path: - sys.path.insert(0, CRAFT_DIR) + +APP_DIR = os.path.dirname(os.path.abspath(__file__)) +CRAFT_DIR = os.path.abspath(os.path.join(APP_DIR, "..", "CRAFT")) +for path in (APP_DIR, CRAFT_DIR): + if os.path.isdir(path) and path not in sys.path: + sys.path.insert(0, path) import torch import torch.backends.cudnn as cudnn from collections import OrderedDict @@ -17,7 +19,6 @@ from PIL import Image, ImageEnhance import cv2 import numpy as np -import os import math from transformers import TrOCRProcessor, VisionEncoderDecoderModel import streamlit as st @@ -25,6 +26,21 @@ st.set_page_config(layout="wide") + +def resolve_existing_path(env_var, *candidates): + override = os.getenv(env_var) + if override: + return override + + for candidate in candidates: + if os.path.exists(candidate): + return candidate + + raise FileNotFoundError( + f"Could not resolve a path for {env_var or 'required asset'}. " + f"Tried: {', '.join(candidates)}" + ) + def copyStateDict(state_dict): if list(state_dict.keys())[0].startswith("module"): start_idx = 1 @@ -39,7 +55,11 @@ def copyStateDict(state_dict): @st.cache_resource def load_craft_model(): # Define the path to the pre-trained CRAFT model weights - trained_model_path = '../../weights/craft_mlt_25k.pth' + trained_model_path = resolve_existing_path( + "RENAISSANCE_CRAFT_MODEL_PATH", + os.path.join(APP_DIR, "weights", "craft_mlt_25k.pth"), + os.path.abspath(os.path.join(APP_DIR, "..", "..", "weights", "craft_mlt_25k.pth")), + ) # Initialize the CRAFT model net = CRAFT() # initialize @@ -57,7 +77,11 @@ def load_craft_model(): refine = True # Set to True if using refine_net if refine: from refinenet import RefineNet - refiner_model_path = '../../weights/craft_refiner_CTW1500.pth' # Update the path + refiner_model_path = resolve_existing_path( + "RENAISSANCE_CRAFT_REFINER_PATH", + os.path.join(APP_DIR, "weights", "craft_refiner_CTW1500.pth"), + os.path.abspath(os.path.join(APP_DIR, "..", "..", "weights", "craft_refiner_CTW1500.pth")), + ) refine_net = RefineNet() refine_net.load_state_dict(copyStateDict(torch.load(refiner_model_path, map_location=device))) refine_net.to(device) @@ -109,9 +133,17 @@ def test_net(net, image, text_threshold, link_threshold, low_text, *, cuda, poly @st.cache_resource def load_ocr_model(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # Update path to point to the correct location of the OCR weights - model_path = "../../models" - processor_path = "../../models" + model_path = resolve_existing_path( + "RENAISSANCE_OCR_MODEL_DIR", + os.path.join(APP_DIR, "models"), + os.path.abspath(os.path.join(APP_DIR, "..", "..", "models")), + ) + processor_path = resolve_existing_path( + "RENAISSANCE_OCR_PROCESSOR_DIR", + model_path, + os.path.join(APP_DIR, "models"), + os.path.abspath(os.path.join(APP_DIR, "..", "..", "models")), + ) processor = TrOCRProcessor.from_pretrained(processor_path) model = VisionEncoderDecoderModel.from_pretrained(model_path).to(device) return processor, model, device @@ -771,4 +803,4 @@ def get_virtual_page(pdf_document, virtual_index, dpi, **kwargs): st.write("No image to display.") else: - st.info("Please upload a PDF file from the left panel.") \ No newline at end of file + st.info("Please upload a PDF file from the left panel.")