Skip to content

训练展示的指标较好,但是实际的推理的结果很差。 #84

@ILZFLNO02

Description

@ILZFLNO02

训练代码如下 run.sh脚本,跑的参数基本不变

# run.sh — SimpleNet training invocation.
# Fixes vs. the pasted version: removed the stray leading quote before
# `datapath`, and restored the `\` line continuations (without them each
# option line is executed as a separate shell command).
datapath='/oeecs/SimpleNet/datas'
datasets=('ls')
dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '"${dataset}"; done))

python main.py \
    --gpu 1 \
    --seed 0 \
    --log_group simplenet_mvtec \
    --log_project MVTecAD_Results \
    --results_path results \
    --run_name run \
    net \
    -b wideresnet50 \
    -le layer2 \
    -le layer3 \
    --pretrain_embed_dimension 1536 \
    --target_embed_dimension 1536 \
    --patchsize 3 \
    --meta_epochs 20 \
    --embedding_size 256 \
    --gan_epochs 4 \
    --noise_std 0.015 \
    --dsc_hidden 1024 \
    --dsc_layers 2 \
    --dsc_margin .5 \
    --pre_proj 1 \
    dataset \
    --batch_size 8 \
    --resize 329 \
    --imagesize 288 "${dataset_flags[@]}" mvtec $datapath

数据是MvTecAD中的wall plugs

推理脚本如下

```python
import os
import time

import torch
import cv2
import numpy as np
import backbones
import simplenet
from PIL import Image
from torchvision import transforms
import glob

# Path to a single image OR a directory of images to run inference on.
INPUT_PATH = r"C:\python_project\SimpleNet-main\datas\ls\test\type1\004_shift_2.png"

# Directory where heatmap/overlay/score files are written (created if missing).
OUTPUT_DIR = r"inference_results"

# Checkpoint produced by training.
MODEL_PATH = r"result_model/ckpt-ls.pth"

# Must match the backbone used during training (-b wideresnet50).
BACKBONE_NAME = "wideresnet50"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ImageNet channel statistics used to normalize inputs for the
# pretrained backbone (must match the training-time preprocessing).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def preprocess_image(image_path, resize=329, imagesize=288):
    """Load an image and apply the same preprocessing used during training.

    Args:
        image_path: Path to the input image file.
        resize: Shorter-side resize target; must match the training --resize.
        imagesize: Center-crop size; must match the training --imagesize.

    Returns:
        Tuple of (input_tensor, pil_image) where input_tensor has shape
        (1, 3, imagesize, imagesize) and is normalized with ImageNet
        statistics, or (None, None) if the image cannot be read.
    """
    # Only the file read/decode is expected to fail; keep the try minimal.
    try:
        image = Image.open(image_path).convert("RGB")
    except Exception as e:
        print(f"Error reading image {image_path}: {e}")
        return None, None

    transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(imagesize),
        transforms.ToTensor(),
        # Normalization expected by the ImageNet-pretrained backbone.
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    input_tensor = transform(image).unsqueeze(0)  # add batch dimension
    return input_tensor, image

def load_model(device, backbone_name, model_path):
    """Build a SimpleNet instance and restore trained weights.

    Args:
        device: torch.device (or device string) the model runs on.
        backbone_name: Backbone identifier understood by backbones.load,
            e.g. "wideresnet50".
        model_path: Path to the checkpoint (.pth) produced by training.

    Returns:
        The SimpleNet instance with discriminator (and, when present,
        pre-projection) weights loaded.
    """
    print(f"Initializing model with backbone: {backbone_name}...")
    # Backbone: ImageNet-pretrained feature extractor.
    backbone = backbones.load(backbone_name)
    backbone.name = backbone_name

    # SimpleNet head on top of the backbone.
    simplenet_inst = simplenet.SimpleNet(device)
    # These hyper-parameters must match the training configuration exactly.
    # Fix: the training run used --meta_epochs 20, but the original script
    # passed 40 here; keep the two in sync.
    simplenet_inst.load(
        backbone=backbone,
        layers_to_extract_from=["layer2", "layer3"],
        device=device,
        input_shape=(3, 288, 288),
        pretrain_embed_dimension=1536,
        target_embed_dimension=1536,
        patchsize=3,
        embedding_size=256,
        meta_epochs=20,
        aed_meta_epochs=1,
        gan_epochs=4,
        noise_std=0.015,
        dsc_layers=2,
        dsc_hidden=1024,
        dsc_margin=0.5,
        dsc_lr=0.0002,
        auto_noise=0,
        train_backbone=False,
        cos_lr=False,
        pre_proj=1,
        proj_layer_type=0,
        mix_noise=1,
    )

    # Restore trained weights.
    print(f"Loading weights from {model_path}...")
    state_dict = torch.load(model_path, map_location=device)
    if "discriminator" in state_dict:
        # Checkpoint stores per-submodule state dicts.
        simplenet_inst.discriminator.load_state_dict(state_dict["discriminator"])
        if "pre_projection" in state_dict:
            simplenet_inst.pre_projection.load_state_dict(state_dict["pre_projection"])
    else:
        # Fallback: a flat state dict for the whole model.
        simplenet_inst.load_state_dict(state_dict, strict=False)

    # NOTE(review): if SimpleNet.predict() does not itself switch to eval
    # mode, the discriminator may run in train mode here — verify, as that
    # can degrade inference quality.
    return simplenet_inst

def infer_single_image(simplenet_inst, image_path, output_dir, device):
    """Run anomaly inference on one image and save heatmap/overlay/score files.

    Args:
        simplenet_inst: Loaded SimpleNet model (see load_model).
        image_path: Path to the image to score.
        output_dir: Existing directory for the output PNG/txt files.
        device: torch.device the input tensor is moved to.
    """
    filename = os.path.basename(image_path)
    # Fix: the original printed a hard-coded placeholder instead of the
    # computed filename.
    print(f"Processing: {filename}")
    t1 = time.time()
    input_tensor, original_pil = preprocess_image(image_path)
    print(fr'前处理时间为“{time.time() - t1}')
    if input_tensor is None:
        return

    input_tensor = input_tensor.to(device)

    # Inference.
    t2 = time.time()
    scores, segmentations, _ = simplenet_inst.predict(input_tensor)
    print(fr'模型推理时间为“{time.time() - t2}')
    anomaly_score = scores[0]
    segmentation_map = segmentations[0]

    # Clamp negative responses before visualization.
    segmentation_map = np.maximum(segmentation_map, 0)

    # NOTE(review): this assumes predict() returns a map in [0, 1]; if it
    # returns raw scores the heatmap will saturate — confirm the scaling.
    heatmap = (segmentation_map * 255).clip(0, 255).astype(np.uint8)
    heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # Recreate the training-time resize+crop so the overlay aligns with
    # the heatmap coordinates.
    crop_transform = transforms.Compose([
        transforms.Resize(329),
        transforms.CenterCrop(288),
    ])
    cropped = np.array(crop_transform(original_pil))
    cropped = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(cropped, 0.6, heatmap_colored, 0.4, 0)

    # Save results.
    name_no_ext = os.path.splitext(filename)[0]
    cv2.imwrite(os.path.join(output_dir, f"{name_no_ext}_heatmap.png"), heatmap_colored)
    cv2.imwrite(os.path.join(output_dir, f"{name_no_ext}_overlay.png"), overlay)
    with open(os.path.join(output_dir, f"{name_no_ext}_score.txt"), "w") as f:
        f.write(f"Anomaly Score: {anomaly_score:.6f}")

    print(f"  -> Score: {anomaly_score:.4f}")

def main():
    """Entry point: load the model, then score one image or every image in a directory."""
    print(f"Input Path: {INPUT_PATH}")
    print(f"Output Directory: {OUTPUT_DIR}")

    # Ensure the output directory exists.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    device = torch.device(DEVICE)
    model = load_model(device, BACKBONE_NAME, MODEL_PATH)

    if os.path.isfile(INPUT_PATH):
        # Single file mode.
        print("Mode: Single File Inference")
        infer_single_image(model, INPUT_PATH, OUTPUT_DIR, device)
    elif os.path.isdir(INPUT_PATH):
        # Directory mode.
        print("Mode: Directory Batch Inference")
        image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tif', '*.tiff']
        image_files = []
        for ext in image_extensions:
            # Glob both lower- and upper-case extensions.
            image_files.extend(glob.glob(os.path.join(INPUT_PATH, ext)))
            image_files.extend(glob.glob(os.path.join(INPUT_PATH, ext.upper())))

        # De-duplicate (case-insensitive filesystems match both patterns).
        image_files = sorted(set(image_files))

        if not image_files:
            print(f"No images found in {INPUT_PATH}")
            return

        print(f"Found {len(image_files)} images.")
        for img_path in image_files:
            infer_single_image(model, img_path, OUTPUT_DIR, device)
    else:
        print(f"Error: Input path '{INPUT_PATH}' does not exist.")

# Fix: the guard must use the dunder names; `name == "main"` raises
# NameError (or is simply never true) and the script would not run.
if __name__ == "__main__":
    main()
```
部分推理结果的照片

Image

不知道是我的哪一步出现问题了

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions