-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
93 lines (73 loc) · 2.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
def bilinear_interpolate(loc_ip, data):
    """Bilinearly interpolate raster data at normalized ``[lon, lat]`` locations.

    Adapted from: https://github.com/elijahcole/sinr/blob/main/utils.py

    Args:
        loc_ip: N x 2 tensor whose rows are ``[lon, lat]`` entries, each
            expected to lie in [-1, 1]. The caller's tensor is not modified.
        data: H x W x C (height x width x channel) raster tensor.

    Returns:
        N x C tensor of interpolated features.

    Raises:
        AssertionError: if ``data`` is None, or if the normalized locations
            contain NaN.
    """
    assert data is not None
    height, width = data.shape[0], data.shape[1]

    # Map [-1, 1] -> [0, 1]. The arithmetic allocates a new tensor, so the
    # in-place edits below never touch ``loc_ip``.
    loc = (loc_ip + 1) / 2.0
    # Flip the latitude axis so lat = +1 lands on row 0 (top of the raster)
    # and lat = -1 on the last row; longitude is left as-is, so lon = -1 maps
    # to column 0 and lon = +1 to the last column.
    loc[:, 1] = 1 - loc[:, 1]
    assert not torch.any(torch.isnan(loc))

    # Scale into pixel space: x in [0, W - 1], y in [0, H - 1].
    loc[:, 0] *= width - 1
    loc[:, 1] *= height - 1

    loc_floor = torch.floor(loc)
    loc_int = loc_floor.long()  # integer pixel coordinates (top-left corner)
    xx = loc_int[:, 0]
    yy = loc_int[:, 1]
    # Right / bottom neighbours, clamped so points on the last column / row
    # reuse that column / row instead of indexing out of bounds.
    xx_plus = torch.clamp(xx + 1, max=width - 1)
    yy_plus = torch.clamp(yy + 1, max=height - 1)

    # Fractional offsets inside the cell, shaped N x 1 to broadcast over C.
    loc_delta = loc - loc_floor
    dx = loc_delta[:, 0].unsqueeze(1)
    dy = loc_delta[:, 1].unsqueeze(1)

    return (
        data[yy, xx, :] * (1 - dx) * (1 - dy)
        + data[yy, xx_plus, :] * dx * (1 - dy)
        + data[yy_plus, xx, :] * (1 - dx) * dy
        + data[yy_plus, xx_plus, :] * dx * dy
    )
class DummyParams:
    """Empty attribute container used for the nested sections of DefaultParams.

    Instances start with no attributes; fields are attached by plain
    assignment after construction (e.g. ``params.dataset.batchsize = 2048``).
    """
class DefaultParams:
    """Default parameters for creating experiments without the hydra wrapper.

    Top-level training/model settings are plain attributes; dataset settings
    live under ``self.dataset`` and machine-local paths/flags under
    ``self.local`` (both ``DummyParams`` containers).

    Args:
        sinr: If True, override the defaults with the plain SINR setup
            (model ``"sinr"``, predictors ``"loc_env"``, 15 epochs).
            Defaults to False, which keeps the ``"sat_sinr_mf"`` setup.
    """

    def __init__(self, sinr=False):
        self.dataset = DummyParams()
        self.local = DummyParams()

        # Training / model hyperparameters.
        self.pos_weight = 2048
        self.lr = 5e-4
        self.l2_dec = 0
        self.epochs = 7
        self.model = "sat_sinr_mf"
        self.sinr_layers = 8
        self.sinr_hidden = 512
        self.dropout = 0.3
        self.tag = "DefaultParams"
        self.embedder = "cnn_default"
        self.validate = False
        # NOTE(review): this is the string "None", not the None object —
        # presumably mirrors how the hydra config encodes "no checkpoint";
        # confirm against the consumer before changing it.
        self.checkpoint = "None"

        # Dataset settings.
        self.dataset.predictors = "loc_env_sent2"
        self.dataset.batchsize = 2048
        self.dataset.use_ds_samples = True
        self.dataset.num_workers = 16

        # Machine-local paths and flags; empty paths must be filled in by
        # the caller.
        self.local.sent_data_path = ""
        self.local.bioclim_path = ""
        self.local.dataset_file_path = ""
        self.local.cp_dir_path = ""
        self.local.logs_dir_path = ""
        self.local.test_data_path = ""
        self.local.gpu = True

        if sinr:
            # Plain SINR overrides; pos_weight keeps the shared value of 2048.
            # "loc" was noted in the original as an alternative predictor set.
            self.dataset.predictors = "loc_env"
            self.epochs = 15
            self.model = "sinr"