# predict.py (Cog predictor for PatchFusion; forked from zhyever/PatchFusion)
from cog import BasePredictor, Input, Path

import os
import sys

# Make the repo's packages importable regardless of the working directory.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))  # Directory containing predict.py
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)  # Prepend so local packages win over installed ones

EXTERNAL_DIR = os.path.join(PROJECT_ROOT, "external")  # Assuming 'external' is a subdirectory
PATCHFUSION_DIR = os.path.join(PROJECT_ROOT, "patchfusion")  # Assuming 'patchfusion' is a subdirectory
if EXTERNAL_DIR not in sys.path and os.path.exists(EXTERNAL_DIR):
    sys.path.insert(0, EXTERNAL_DIR)
if PATCHFUSION_DIR not in sys.path and os.path.exists(PATCHFUSION_DIR):
    sys.path.insert(0, PATCHFUSION_DIR)
sys.path.append(os.path.dirname(PROJECT_ROOT))  # Add the parent of the repo root for imports

import time  # For timing (optional)

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from mmengine import print_log  # For logging (optional; standard print also works)
from mmengine.config import Config  # For loading model configs

from estimator.models.builder import build_model
from depth_anything.transform import Resize as ResizeDA  # Dataset-style resize used by Depth-Anything
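# Expected repo layout (an assumption, inferred from the paths used in setup()
# and predict() below):
#   ./configs/patchfusion_depthanything/depthanything_vitl_patchfusion_u4k.py  (model config)
#   ./ckps/patchfusion.pth                                                     (local checkpoint)
#   ./output/                                                                  (created on demand for results)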
class Predictor(BasePredictor):
    def setup(self):
        """Load the PatchFusion model and preprocessing transforms."""
        start_time = time.time()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # --- 1. Load config ---
        config_path = "configs/patchfusion_depthanything/depthanything_vitl_patchfusion_u4k.py"
        print(f"Loading config from: {config_path}")
        self.cfg = Config.fromfile(config_path)

        # --- 2. Local checkpoint path ---
        local_ckp_path = "./ckps/patchfusion.pth"
        if not os.path.exists(local_ckp_path):
            raise FileNotFoundError(f"Checkpoint file not found at: {local_ckp_path}")
        self.cfg.ckp_path = local_ckp_path

        # --- 3. Build model ---
        print("Building PatchFusion model architecture...")
        model = build_model(self.cfg.model)
        print_log(f"Checkpoint path: {self.cfg.ckp_path}. Loading from a local file.", logger='current')

        # --- 4. Load model weights ---
        checkpoint = torch.load(self.cfg.ckp_path, weights_only=False)
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        if hasattr(model, 'load_dict'):  # Prefer the model's custom loader when available
            load_info = model.load_dict(state_dict)
        else:
            load_info = model.load_state_dict(state_dict, strict=True)
        print_log(load_info, logger='current')
        self.model = model.to(self.device).eval()

        # --- 5. Preprocessing transform (PIL image -> float tensor in [0, 1]) ---
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        # --- 6. Dataset-like ResizeDA transform for image_lr ---
        # The low-resolution branch expects the network's processing size.
        network_process_size = (392, 518)  # (height, width)
        net_h, net_w = network_process_size
        self.resize = ResizeDA(  # ResizeDA matches the 'depth-anything' preprocessing
            net_w, net_h,
            keep_aspect_ratio=False,
            ensure_multiple_of=14,  # ViT patch size
            resize_method="minimal",
        )
        print(f"Setup finished in {time.time() - start_time:.2f} seconds.")
    def predict(
        self,
        image: Path = Input(description="Input image for depth estimation"),
        patch_split_x: int = Input(description="Patch split along width (minimum 2)", default=3, ge=2),
        patch_split_y: int = Input(description="Patch split along height (minimum 2)", default=3, ge=2),
        num_process: int = Input(description="Number of inference processes (minimum 1)", default=1, ge=1),
    ) -> Path:
        """Run depth estimation with PatchFusion and return the path to the depth map (WebP)."""
        print(f"Processing image: {image}")
        start_time = time.time()
        try:
            image_pil = Image.open(str(image)).convert('RGB')
            original_width, original_height = image_pil.size

            # PatchFusion expects [height, width] ordering for the patch split.
            patch_split_num = [patch_split_y, patch_split_x]

            # Snap dimensions down to multiples of 2 * patch_split so the tiling divides evenly.
            downscale_factor_width = 2 * patch_split_num[1]
            downscale_factor_height = 2 * patch_split_num[0]
            new_width = (original_width // downscale_factor_width) * downscale_factor_width
            new_height = (original_height // downscale_factor_height) * downscale_factor_height
            downscaled_size = (new_width, new_height)
            print(f"Original size: ({original_width}, {original_height}), "
                  f"downscaled size: {downscaled_size}, patch_split_num: {patch_split_num}")

            resized_image_pil = image_pil.resize(downscaled_size, Image.BICUBIC)
            image_hr = self.transform(resized_image_pil).unsqueeze(0).to(self.device)
            image_lr = self.resize(image_hr)  # Low-resolution branch input

            tile_cfg = dict(
                image_raw_shape=[new_height, new_width],
                patch_split_num=patch_split_num,
            )
            with torch.no_grad():
                result, _ = self.model(
                    'infer', image_lr, image_hr,
                    cai_mode='r32',  # PatchFusion consistency-aware inference mode
                    process_num=num_process,
                    tile_cfg=tile_cfg,
                )

            # Log-scale the depth, then rescale to 0-255 for an 8-bit grayscale image.
            depth_array_float32 = result.squeeze().detach().cpu().numpy().astype(np.float32)
            depth_array_log_scaled = np.log1p(depth_array_float32)
            min_log_depth = depth_array_log_scaled.min()
            max_log_depth = depth_array_log_scaled.max()
            if max_log_depth > min_log_depth:
                depth_array_log_rescaled_0_1 = (depth_array_log_scaled - min_log_depth) / (max_log_depth - min_log_depth)
                depth_array_rescaled_0_255 = (depth_array_log_rescaled_0_1 * 255).astype(np.uint8)
            else:
                # Degenerate case: constant depth map.
                depth_array_rescaled_0_255 = np.zeros_like(depth_array_log_scaled, dtype=np.uint8)

            # Invert so near objects appear bright.
            rescaled_depth_image_pil = Image.fromarray(255 - depth_array_rescaled_0_255, mode='L')

            output_dir = "output"
            os.makedirs(output_dir, exist_ok=True)
            input_filename = os.path.splitext(os.path.basename(str(image)))[0]
            patch_split_str = f"{patch_split_x}x{patch_split_y}"  # e.g. "3x3"
            output_depth_map_filename = f"{input_filename}_pf{patch_split_str}_p{num_process}.webp"
            output_depth_map_path = os.path.join(output_dir, output_depth_map_filename)
            rescaled_depth_image_pil.save(output_depth_map_path, format="WebP", lossless=True, quality=100)

            print(f"Prediction finished in {time.time() - start_time:.2f} seconds.")
            return Path(output_depth_map_path)
        finally:
            # Release cached GPU memory between predictions.
            torch.cuda.empty_cache()