-
Notifications
You must be signed in to change notification settings - Fork 100
训练展示的指标较好,但是实际的推理的结果很差。 #84
Description
训练代码如下 run.sh脚本,跑的参数基本不变
datapath='/oeecs/SimpleNet/datas'
datasets=('ls')
dataset_flags=(
python main.py
--gpu 1
--seed 0
--log_group simplenet_mvtec
--log_project MVTecAD_Results
--results_path results
--run_name run
net
-b wideresnet50
-le layer2
-le layer3
--pretrain_embed_dimension 1536
--target_embed_dimension 1536
--patchsize 3
--meta_epochs 20
--embedding_size 256
--gan_epochs 4
--noise_std 0.015
--dsc_hidden 1024
--dsc_layers 2
--dsc_margin .5
--pre_proj 1
dataset
--batch_size 8
--resize 329
--imagesize 288 "${dataset_flags[@]}" mvtec $datapath
数据是MvTecAD中的wall plugs
推理脚本如下
```python
import os
import time
import torch
import cv2
import numpy as np
import backbones
import simplenet
from PIL import Image
from torchvision import transforms
import glob
# Path to a single test image OR a directory of images (both modes handled in main()).
INPUT_PATH = r"C:\python_project\SimpleNet-main\datas\ls\test\type1\004_shift_2.png"
# Directory where heatmaps / overlays / score files are written (created if missing).
OUTPUT_DIR = r"inference_results"
# Checkpoint saved by the SimpleNet training run (contains discriminator / pre_projection).
MODEL_PATH = r"result_model/ckpt-ls.pth"
# Must match the backbone used at training time (run.sh: -b wideresnet50).
BACKBONE_NAME = "wideresnet50"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Standard ImageNet normalization statistics — must match the training pipeline.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def preprocess_image(image_path, resize=329, imagesize=288):
    """Load one image and apply the training-time preprocessing.

    Pipeline: RGB convert -> resize (shorter side) -> center crop ->
    tensor -> ImageNet normalization, then add a batch dimension.

    Returns:
        (input_tensor, pil_image) on success, (None, None) on any failure.
    """
    pipeline = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(imagesize),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    try:
        pil_image = Image.open(image_path).convert("RGB")
        batched = pipeline(pil_image).unsqueeze(0)  # (1, 3, H, W)
        return batched, pil_image
    except Exception as e:
        # Best-effort: report and let the caller skip this image.
        print(f"Error reading image {image_path}: {e}")
        return None, None
def load_model(device, backbone_name, model_path):
    """Build a SimpleNet instance and restore trained weights.

    Args:
        device: torch.device to place the model on.
        backbone_name: feature-extractor name; must equal the training `-b` flag.
        model_path: path to the checkpoint produced by training.

    Returns:
        The SimpleNet instance with discriminator / pre_projection weights loaded.

    NOTE(review): every keyword below must mirror the training run.sh exactly,
    otherwise the restored weights are applied to a mismatched architecture and
    inference quality collapses.
    """
    print(f"Initializing model with backbone: {backbone_name}...")
    # Backbone
    backbone = backbones.load(backbone_name)
    backbone.name = backbone_name
    # SimpleNet
    simplenet_inst = simplenet.SimpleNet(device)
    # Note: These parameters must match training configuration
    simplenet_inst.load(
        backbone=backbone,
        layers_to_extract_from=['layer2', 'layer3'],
        device=device,
        input_shape=(3, 288, 288),
        pretrain_embed_dimension=1536,
        target_embed_dimension=1536,
        patchsize=3,
        embedding_size=256,
        # FIX: training used --meta_epochs 20; this script had 40, a
        # train/inference configuration mismatch.
        meta_epochs=20,
        aed_meta_epochs=1,
        gan_epochs=4,
        noise_std=0.015,
        dsc_layers=2,
        dsc_hidden=1024,
        dsc_margin=0.5,
        dsc_lr=0.0002,
        auto_noise=0,
        train_backbone=False,
        cos_lr=False,
        pre_proj=1,
        proj_layer_type=0,
        mix_noise=1,
    )
    # Load weights
    print(f"Loading weights from {model_path}...")
    state_dict = torch.load(model_path, map_location=device)
    # Checkpoints saved by training store sub-module state dicts under keys;
    # fall back to a whole-model (non-strict) load otherwise.
    if 'discriminator' in state_dict:
        simplenet_inst.discriminator.load_state_dict(state_dict['discriminator'])
        if "pre_projection" in state_dict:
            simplenet_inst.pre_projection.load_state_dict(state_dict["pre_projection"])
    else:
        simplenet_inst.load_state_dict(state_dict, strict=False)
    return simplenet_inst
def infer_single_image(simplenet_inst, image_path, output_dir, device):
    """Run SimpleNet on one image and save heatmap, overlay and score files.

    Args:
        simplenet_inst: loaded SimpleNet model (see load_model).
        image_path: path of the image to score.
        output_dir: existing directory to write `<name>_heatmap/overlay/score` files.
        device: torch device the model lives on.
    """
    filename = os.path.basename(image_path)
    # FIX: original print dropped the {filename} placeholder and the computed
    # `filename` variable went unused.
    print(f"Processing: {filename}")
    t1 = time.time()
    input_tensor, original_pil = preprocess_image(image_path)
    print(fr'前处理时间为“{time.time() - t1}')
    if input_tensor is None:
        return
    input_tensor = input_tensor.to(device)
    # Inference
    t2 = time.time()
    scores, segmentations, _ = simplenet_inst.predict(input_tensor)
    print(fr'模型推理时间为“{time.time() - t2}')
    anomaly_score = scores[0]
    segmentation_map = segmentations[0]
    # Post-processing
    segmentation_map = np.maximum(segmentation_map, 0)
    # Scale for visualization.
    # NOTE(review): SimpleNet's raw segmentation values are discriminator scores,
    # not guaranteed to lie in [0, 1]; multiplying by 255 and clipping can
    # saturate the whole map. If heatmaps look uniform, min-max normalize the
    # map before scaling — verify against the training-side visualization code.
    heatmap = (segmentation_map * 255).clip(0, 255).astype(np.uint8)
    heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # Create overlay: crop the original exactly like the model input so the
    # heatmap aligns pixel-for-pixel.
    crop_transform = transforms.Compose([
        transforms.Resize(329),
        transforms.CenterCrop(288)
    ])
    cropped = np.array(crop_transform(original_pil))
    cropped = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(cropped, 0.6, heatmap_colored, 0.4, 0)
    # Save results
    name_no_ext = os.path.splitext(filename)[0]
    cv2.imwrite(os.path.join(output_dir, f"{name_no_ext}_heatmap.png"), heatmap_colored)
    cv2.imwrite(os.path.join(output_dir, f"{name_no_ext}_overlay.png"), overlay)
    with open(os.path.join(output_dir, f"{name_no_ext}_score.txt"), "w") as f:
        f.write(f"Anomaly Score: {anomaly_score:.6f}")
    print(f" -> Score: {anomaly_score:.4f}")
def main():
    """Entry point: load the model, then score one file or a whole directory."""
    print(f"Input Path: {INPUT_PATH}")
    print(f"Output Directory: {OUTPUT_DIR}")
    # Setup directories
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Setup Device
    device = torch.device(DEVICE)
    # Load Model
    model = load_model(device, BACKBONE_NAME, MODEL_PATH)

    # Single-file mode: score exactly one image and stop.
    if os.path.isfile(INPUT_PATH):
        print("Mode: Single File Inference")
        infer_single_image(model, INPUT_PATH, OUTPUT_DIR, device)
        return

    # Neither file nor directory: nothing to do.
    if not os.path.isdir(INPUT_PATH):
        print(f"Error: Input path '{INPUT_PATH}' does not exist.")
        return

    # Directory mode: gather images case-insensitively, dedupe, score each.
    print("Mode: Directory Batch Inference")
    found = set()
    for pattern in ['*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tif', '*.tiff']:
        # Case insensitive search
        found.update(glob.glob(os.path.join(INPUT_PATH, pattern)))
        found.update(glob.glob(os.path.join(INPUT_PATH, pattern.upper())))
    image_files = sorted(found)
    if not image_files:
        print(f"No images found in {INPUT_PATH}")
        return
    print(f"Found {len(image_files)} images.")
    for img_path in image_files:
        infer_single_image(model, img_path, OUTPUT_DIR, device)
# FIX: the dunder underscores were stripped in the paste; `if name == "main":`
# raises NameError. This is the standard script entry guard.
if __name__ == "__main__":
    main()
部分推理结果的照片
不知道是我的哪一步出现问题了