-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
201 lines (182 loc) · 7.26 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision.transforms import PILToTensor, v2
from utils import bilinear_interpolate
# load data and put it into torch dataset
class SINR_DS(torch.utils.data.Dataset):
    """Torch dataset serving SINR predictors for species-occurrence rows.

    Depending on `predictors`, each item combines:
      - "loc":   a sinusoidal encoding of (lon, lat),
      - "env":   bioclim raster features sampled by bilinear interpolation,
      - "sent2": a 4-channel (RGB + NIR) Sentinel-2 patch.

    `dataset_file` is a DataFrame with at least "lon", "lat", "speciesId"
    (and "patchID" when "sent2" is used).
    """

    def __init__(self, params, dataset_file, predictors, bioclim_path, sent_data_path):
        super().__init__()
        self.data = dataset_file
        self.bioclim_path = bioclim_path
        # test_data is not used by the dataset itself, but the model needs this object
        with open(params.local.test_data_path, "r") as f:
            data_test = pd.read_csv(f, sep=";", header="infer", low_memory=False)
        # Collapse the test CSV to one row per observation site/day, with the
        # list of species observed there.
        grouped = (
            data_test.groupby(["patchID", "dayOfYear", "lon", "lat"])
            .agg({"speciesId": lambda x: list(x)})
            .reset_index()
        )
        # Keyed "lon/lat/dayOfYear/patchID" -> list of species ids.
        self.test_data = {
            f"{entry['lon']}/{entry['lat']}/{entry['dayOfYear']}/{entry['patchID']}": entry[
                "speciesId"
            ]
            for _, entry in grouped.iterrows()
        }
        self.predictors = predictors
        if "sent2" in predictors:
            self.to_tensor = PILToTensor()
        if "env" in predictors:
            # The raster we are loading is already cropped to Europe and normalized.
            # NaNs are replaced with 0.0, the mean value post-normalization.
            context_feats = np.load(bioclim_path).astype(np.float32)
            self.raster = torch.nan_to_num(torch.from_numpy(context_feats), nan=0.0)
        self.sent_data_path = sent_data_path
        self.transforms = v2.Compose(
            [
                v2.RandomHorizontalFlip(p=0.5),
                v2.RandomVerticalFlip(p=0.5),
            ]
        )

    def __len__(self):
        return len(self.data)

    def _normalize_loc_to_uniform(self, lon, lat):
        """Normalizes lon and lat between [-1,1] (bounds are the dataset's extent)."""
        lon = (lon - (-10.53904)) / (34.55792 - (-10.53904))
        lat = (lat - 34.56858) / (71.18392 - 34.56858)
        lon = lon * 2 - 1
        lat = lat * 2 - 1
        return lon, lat

    def _encode_loc(self, lon, lat):
        """Expects lon and lat to be scaled between [-1,1]; returns the
        4-feature sinusoidal encoding [sin(pi*lon), cos(pi*lon), sin(pi*lat), cos(pi*lat)]."""
        features = [
            np.sin(np.pi * lon),
            np.cos(np.pi * lon),
            np.sin(np.pi * lat),
            np.cos(np.pi * lat),
        ]
        return np.stack(features, axis=-1)

    def sample_encoded_locs(self, size):
        """Samples `size` random locations uniformly in [-1,1]^2, along with
        environmental features when "env" is among the predictors."""
        lon = np.random.rand(size) * 2 - 1
        lat = np.random.rand(size) * 2 - 1
        loc_enc = torch.tensor(self._encode_loc(lon, lat), dtype=torch.float32)
        if "env" not in self.predictors:
            return loc_enc
        env_enc = bilinear_interpolate(
            torch.stack([torch.tensor(lon), torch.tensor(lat)], dim=1), self.raster
        )
        if "loc" in self.predictors:
            return torch.cat([loc_enc, env_enc], dim=1).float()
        return env_enc.float()

    def get_env_raster(self, lon, lat):
        """Rescales lon/lat to [-1,1] and gets env raster values through bilinear interpolation.
        The normalization bounds are the bounds that were used to crop the bioclim raster to Europe.
        They are independent of the bounds used in _normalize_loc_to_uniform."""
        lat = (lat - 34) / (72 - 34)
        lon = (lon - (-11)) / (35 - (-11))
        lon = lon * 2 - 1
        lat = lat * 2 - 1
        return bilinear_interpolate(torch.tensor([[lon, lat]]), self.raster)

    def get_loc_env(self, lon, lat):
        """Given lon and lat, concatenate the location encoding (4 features)
        with the environmental features (20 raster channels)."""
        lon_norm, lat_norm = self._normalize_loc_to_uniform(lon, lat)
        loc_enc = torch.tensor(
            self._encode_loc(lon_norm, lat_norm), dtype=torch.float32
        )
        env_enc = self.get_env_raster(lon, lat).float()
        return torch.cat((loc_enc, env_enc.view(20)))

    def get_env(self, lon, lat):
        """Get env raster values, turn into float32 and reshape to a flat 20-vector."""
        return self.get_env_raster(lon, lat).float().view(20)

    def get_lon(self, lon, lat):
        """Create the location embedding alone, as a float32 tensor."""
        lon_norm, lat_norm = self._normalize_loc_to_uniform(lon, lat)
        return torch.tensor(self._encode_loc(lon_norm, lat_norm), dtype=torch.float32)

    def encode(self, lon, lat):
        """Three different options to combine loc and env embeddings."""
        if "env" in self.predictors:
            if "loc" in self.predictors:
                return self.get_loc_env(lon, lat)
            else:
                return self.get_env(lon, lat)
        else:
            return self.get_lon(lon, lat)

    def _sent_patch_path(self, band, pid):
        """Build the on-disk path of a Sentinel-2 patch: the directory layout is
        keyed by the last digit pairs of the patch id."""
        pid = str(pid)
        return f"{self.sent_data_path}{band}/{pid[-2:]}/{pid[-4:-2]}/{pid}.jpeg"

    def get_gbif_sent2(self, pid):
        """Get the Sentinel-2 RGB+NIR image for patch_id as a 4-channel tensor
        scaled to [0,1], with random horizontal/vertical flips applied."""
        # Context managers make sure the underlying image files are closed
        # (Image.open keeps the file handle alive otherwise).
        with Image.open(self._sent_patch_path("rgb", pid)) as rgb, Image.open(
            self._sent_patch_path("nir", pid)
        ) as nir:
            img = torch.cat([self.to_tensor(rgb), self.to_tensor(nir)], dim=0) / 255
        return self.transforms(img)

    def __getitem__(self, idx):
        """Combines previous methods to return the item based on the predictor combination.
        The steps, in which the dataset constructs a datapoint, are a bit convoluted."""
        data_dict = self.data.iloc[idx]
        lon, lat = tuple(data_dict[["lon", "lat"]].to_numpy())
        if "sent2" in self.predictors:
            return (
                self.encode(lon, lat),
                self.get_gbif_sent2(data_dict["patchID"]),
                torch.tensor(data_dict["speciesId"]),
            )
        else:
            return self.encode(lon, lat), torch.tensor(data_dict["speciesId"])
def create_datasets(params):
    """Build the full SINR dataset plus train/val dataloaders.

    Reads the occurrence CSV, wraps it in a SINR_DS, splits it 90/10 at
    random, and returns (dataset, train_loader, val_loader). The train
    loader shuffles; the validation loader does not.
    """
    occurrences = pd.read_csv(
        params.local.dataset_file_path, sep=";", header="infer", low_memory=False
    )
    dataset = SINR_DS(
        params,
        occurrences,
        params.dataset.predictors,
        bioclim_path=params.local.bioclim_path,
        sent_data_path=params.local.sent_data_path,
    )
    train_split, val_split = torch.utils.data.random_split(dataset, [0.9, 0.1])
    # Both loaders share everything except the shuffle flag.
    common = dict(
        batch_size=params.dataset.batchsize,
        num_workers=params.dataset.num_workers,
    )
    train_loader = torch.utils.data.DataLoader(train_split, shuffle=True, **common)
    val_loader = torch.utils.data.DataLoader(val_split, shuffle=False, **common)
    return dataset, train_loader, val_loader