-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
60 lines (49 loc) · 2.04 KB
/
Copy pathevaluate.py
File metadata and controls
60 lines (49 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn.functional as F
from torchvision.models import ResNet18_Weights
from models.encoder import MetaResNet18Encoder
from models.metaoptnet_svm import MetaOptNetSVM
from data.dataset import DeepfakeDataset
from utils.few_shot_sampler import create_episode
import argparse
# Command-line interface: the trained checkpoint, the dataset location, and
# how many few-shot episodes to evaluate.
parser = argparse.ArgumentParser()
_CLI_OPTIONS = (
    ("--model_path", {"type": str, "required": True, "help": "Path to trained model (.pth)"}),
    ("--data_path", {"type": str, "required": True, "help": "Path to dataset"}),
    ("--episodes", {"type": int, "default": 10, "help": "Number of evaluation episodes"}),
)
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# Build the encoder, then overwrite its weights with the trained checkpoint.
encoder = MetaResNet18Encoder(weights=ResNet18_Weights.DEFAULT).to(device)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only pass trusted checkpoints; if the file is always a plain state
# dict, consider weights_only=True (available since torch 1.13).
state_dict = torch.load(args.model_path, map_location=device)

# Checkpoints saved by train.py may name backbone parameters "backbone.*",
# while this encoder expects "features.*". Rewrite only that leading prefix
# and keep every other key unchanged.
OLD_PREFIX = "backbone."
NEW_PREFIX = "features."
remapped = {
    (NEW_PREFIX + key[len(OLD_PREFIX):] if key.startswith(OLD_PREFIX) else key): value
    for key, value in state_dict.items()
}

# strict=False tolerates key mismatches. Ignoring them silently would let any
# unmatched encoder entries keep their initial ResNet18_Weights.DEFAULT values,
# and the "trained" model would be evaluated without any warning. Report every
# mismatch so a wrong or badly remapped checkpoint is visible.
load_result = encoder.load_state_dict(remapped, strict=False)
if load_result.missing_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} encoder key(s) not found in "
        f"checkpoint (left at initial values): {load_result.missing_keys}"
    )
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} checkpoint key(s) not used "
        f"by the encoder: {load_result.unexpected_keys}"
    )
encoder.eval()
# Dataset that few-shot episodes are sampled from, rooted at --data_path.
data_root = args.data_path
dataset = DeepfakeDataset(data_root)
# Head that classifies query features from the labelled support features.
classifier = MetaOptNetSVM()
# Evaluate over independently sampled few-shot episodes, pooling the number of
# correct predictions and the number of query samples across all of them.
total_correct = 0
total_queries = 0
for _ in range(args.episodes):
    support_x, support_y, query_x, query_y = create_episode(dataset)
    support_x, support_y = support_x.to(device), support_y.to(device)
    query_x, query_y = query_x.to(device), query_y.to(device)

    # Inference only: disable autograd for both feature extraction and the
    # SVM head so no computation graph is built.
    with torch.no_grad():
        support_features = encoder(support_x)
        query_features = encoder(query_x)
        preds = classifier(support_features, support_y, query_features)

    # Assumes preds holds per-class scores with classes on dim 1,
    # i.e. (n_query, n_way) -- TODO confirm against MetaOptNetSVM.
    pred_labels = preds.argmax(dim=1)
    total_correct += (pred_labels == query_y).sum().item()
    total_queries += query_y.size(0)
# Final accuracy, pooled over every query of every episode (not a mean of
# per-episode accuracies). Guard against an empty evaluation, e.g.
# --episodes 0, which would otherwise crash with a bare ZeroDivisionError.
if total_queries == 0:
    raise SystemExit(
        "No query samples were evaluated (is --episodes > 0?); "
        "cannot compute accuracy."
    )
accuracy = total_correct / total_queries * 100
print(f"✅ Final Evaluation Accuracy: {accuracy:.2f}%")