-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_eval_plotting.py
128 lines (100 loc) · 4.5 KB
/
model_eval_plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from xgcm import Grid
import pop_tools
import gcsfs
import fsspec as fs
import numpy as np
import xesmf as xe
import xarray as xr
import random
import matplotlib.pyplot as plt
import warnings
from xgcm import Grid
import importlib
import preprocessing
import os
import xrft
import gcm_filters
import random
warnings.filterwarnings("ignore")
importlib.reload(preprocessing)
from preprocessing import preprocess_data
from gcm_filtering import filter_inputs_dataset
from gcm_filtering import filter_inputs
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
def evaluate_model(model, test_loader, criterion, batch_size=1):
    """Run ``model`` over ``test_loader`` without gradients and collect results.

    Each item yielded by ``test_loader`` must be an ``(inputs, targets)`` pair.
    Inputs are expected channels-last, either batched ``(batch, y, x, vars)``
    or as a single unbatched sample ``(y, x, vars)``; they are rearranged to
    channels-first ``(batch, vars, y, x)`` before being passed to the model.
    The layout is derived from the tensor rank, so any loader batch size
    works regardless of the ``batch_size`` argument.

    No device transfer is performed on the inputs: they are fed to the model
    on whatever device the loader yields them. Outputs and targets are moved
    to the CPU before being stored.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        test_loader: Iterable of ``(inputs, targets)`` tensor pairs. It does
            not need to support ``len()``.
        criterion: Callable ``criterion(outputs, targets)`` returning a
            scalar tensor.
        batch_size: Retained for backward compatibility only. The value is
            ignored because the batch layout is read from each tensor.

    Returns:
        tuple: ``(average_loss, all_predictions, all_targets)`` where
        ``average_loss`` is the mean of the per-batch losses (each batch
        counts equally, irrespective of its size) and the two tensors are
        the model outputs and targets concatenated along dimension 0.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()  # Set the model to evaluation mode
    total_loss = 0.0
    num_batches = 0
    all_predictions = []
    all_targets = []

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for inputs, targets in test_loader:
            if inputs.dim() == 3:
                # Unbatched sample (y, x, vars): add the batch dimension so
                # both layouts share the same code path below.
                inputs = inputs.unsqueeze(0)
                targets = targets.unsqueeze(0)

            # Channels-last (batch, y, x, vars) -> channels-first (batch, vars, y, x)
            inputs = inputs.permute(0, 3, 1, 2).float()
            targets = targets.float()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            total_loss += loss.item()
            num_batches += 1

            all_predictions.append(outputs.cpu())
            all_targets.append(targets.cpu())

    if num_batches == 0:
        # Without this guard the division below fails with an opaque
        # ZeroDivisionError and torch.cat rejects the empty lists.
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    average_loss = total_loss / num_batches

    # Concatenate the per-batch results along the batch dimension
    all_predictions = torch.cat(all_predictions, dim=0)
    all_targets = torch.cat(all_targets, dim=0)

    return average_loss, all_predictions, all_targets
def _show_image(fig, ax, image, title, vmin, vmax):
    """Draw one image on ``ax`` with its own colorbar and a fixed color range."""
    if image.dim() == 3:
        # Channels-first (C, H, W) -> channels-last (H, W, C) for imshow
        data = image.permute(1, 2, 0)
    elif image.dim() == 2:
        data = image
    else:
        raise ValueError(
            f"Expected a 2D (H, W) or 3D (C, H, W) image, got {image.dim()} dimensions"
        )

    # detach() lets tensors that still track gradients be converted to numpy
    img = ax.imshow(data.detach().cpu().numpy(), vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.axis('off')
    fig.colorbar(img, ax=ax)


def plot_predictions_vs_targets(predictions, targets, num_samples=9):
    """
    Plot targets and predictions side-by-side with an individual colorbar each.

    Each row shows one sample: the target on the left and the prediction on
    the right. The prediction uses the target's minimum and maximum as its
    color limits so the two panels of a row are directly comparable.

    Args:
        predictions (Tensor): Model predictions, shape (N, C, H, W) or (N, H, W).
        targets (Tensor): Ground truth targets, shape (N, C, H, W) or (N, H, W).
        num_samples (int): Maximum number of samples to display. Default is 9.
            The value is capped at the number of available predictions and
            targets; if nothing is left to show the function returns without
            creating a figure.

    Raises:
        ValueError: If an individual sample is neither 2D nor 3D.
    """
    # Never index past the end of either tensor
    num_samples = min(num_samples, predictions.shape[0], targets.shape[0])
    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row, so
    # axes[i, j] indexing is valid when only one sample is plotted.
    fig, axes = plt.subplots(
        num_samples, 2, figsize=(10, 2 * num_samples), squeeze=False
    )

    for i in range(num_samples):
        target_image = targets[i]
        # The target's value range drives the color scale of both panels
        target_vmin = target_image.min().item()
        target_vmax = target_image.max().item()

        _show_image(fig, axes[i, 0], target_image, f'Target {i + 1}',
                    target_vmin, target_vmax)
        _show_image(fig, axes[i, 1], predictions[i], f'Prediction {i + 1}',
                    target_vmin, target_vmax)

    plt.tight_layout()
    plt.show()