Commit 6c5b78e

Remove super call in __getitem__ which causes problems in Python 3.8
1 parent ffd5b86 commit 6c5b78e

File tree

7 files changed (+70 -79 lines changed):

  nets/attention_model.py
  problems/op/state_op.py
  problems/pctsp/state_pctsp.py
  problems/tsp/state_tsp.py
  problems/vrp/state_cvrp.py
  problems/vrp/state_sdvrp.py
  utils/beam_search.py
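The change is the same across all seven files: the __getitem__ override on these NamedTuple-based classes previously fell back to super().__getitem__(key) when the key was neither a tensor nor a slice, and according to the commit message that super call causes problems under Python 3.8. The fallback is removed and replaced by an assert, so only tensor and slice keys are accepted. A minimal sketch of the resulting pattern (MiniState below is an illustrative stand-in, not a class from the repo):

    import torch
    from typing import NamedTuple


    class MiniState(NamedTuple):
        # Illustrative stand-in for the repo's state tuples (StateTSP, StateCVRP, ...)
        ids: torch.Tensor
        lengths: torch.Tensor

        def __getitem__(self, key):
            # Patched behaviour: only tensor or slice keys are allowed; index every
            # field with the same key and build a new tuple via _replace.
            assert torch.is_tensor(key) or isinstance(key, slice)
            return self._replace(ids=self.ids[key], lengths=self.lengths[key])


    state = MiniState(ids=torch.arange(4), lengths=torch.zeros(4))
    kept = state[torch.tensor([0, 2])]   # e.g. keep batch entries 0 and 2
    print(kept.ids)                      # tensor([0, 2])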

Diff for: nets/attention_model.py (+8 -9)

@@ -29,15 +29,14 @@ class AttentionModelFixed(NamedTuple):
     logit_key: torch.Tensor
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):
-            return AttentionModelFixed(
-                node_embeddings=self.node_embeddings[key],
-                context_node_projected=self.context_node_projected[key],
-                glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
-                glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
-                logit_key=self.logit_key[key]
-            )
-        return super(AttentionModelFixed, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)
+        return AttentionModelFixed(
+            node_embeddings=self.node_embeddings[key],
+            context_node_projected=self.context_node_projected[key],
+            glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
+            glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
+            logit_key=self.logit_key[key]
+        )
 
 
 class AttentionModel(nn.Module):
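As the inline comments note, glimpse_key and glimpse_val keep the attention heads in dimension 0, so the batch entries being selected live in dimension 1 and must be indexed with [:, key]. A small shape sketch with hypothetical sizes:

    import torch

    n_heads, batch_size, graph_size, head_dim = 8, 512, 20, 16   # hypothetical sizes
    glimpse_key = torch.randn(n_heads, batch_size, graph_size, head_dim)

    keep = torch.tensor([0, 5, 7])       # e.g. the surviving batch/beam entries
    selected = glimpse_key[:, keep]      # index dim 1 (batch), keep all heads
    print(selected.shape)                # torch.Size([8, 3, 20, 16])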

Diff for: problems/op/state_op.py (+9 -10)

@@ -36,16 +36,15 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-                cur_total_prize=self.cur_total_prize[key],
-            )
-        return super(StateOP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+            cur_total_prize=self.cur_total_prize[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):

Diff for: problems/pctsp/state_pctsp.py (+10 -11)

@@ -36,17 +36,16 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_total_prize=self.cur_total_prize[key],
-                cur_total_penalty=self.cur_total_penalty[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StatePCTSP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_total_prize=self.cur_total_prize[key],
+            cur_total_penalty=self.cur_total_penalty[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):

Diff for: problems/tsp/state_tsp.py (+9 -10)

@@ -28,16 +28,15 @@ def visited(self):
         return mask_long2bool(self.visited_, n=self.loc.size(-2))
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                first_a=self.first_a[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
-            )
-        return super(StateTSP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            first_a=self.first_a[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
+        )
 
     @staticmethod
     def initialize(loc, visited_dtype=torch.uint8):

Diff for: problems/vrp/state_cvrp.py (+9 -10)

@@ -34,16 +34,15 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                used_capacity=self.used_capacity[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StateCVRP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            used_capacity=self.used_capacity[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):

Diff for: problems/vrp/state_sdvrp.py (+9 -10)

@@ -22,16 +22,15 @@ class StateSDVRP(NamedTuple):
     VEHICLE_CAPACITY = 1.0  # Hardcoded
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                used_capacity=self.used_capacity[key],
-                demands_with_depot=self.demands_with_depot[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StateSDVRP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            used_capacity=self.used_capacity[key],
+            demands_with_depot=self.demands_with_depot[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     @staticmethod
     def initialize(input):

Diff for: utils/beam_search.py (+16 -19)

@@ -70,15 +70,14 @@ def ids(self):
         return self.state.ids.view(-1)  # Need to flat as state has steps dimension
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                # ids=self.ids[key],
-                score=self.score[key] if self.score is not None else None,
-                state=self.state[key],
-                parent=self.parent[key] if self.parent is not None else None,
-                action=self.action[key] if self.action is not None else None
-            )
-        return super(BatchBeam, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            # ids=self.ids[key],
+            score=self.score[key] if self.score is not None else None,
+            state=self.state[key],
+            parent=self.parent[key] if self.parent is not None else None,
+            action=self.action[key] if self.action is not None else None
+        )
 
     # Do not use __len__ since this is used by namedtuple internally and should be number of fields
     # def __len__(self):

@@ -207,15 +206,13 @@ def __getitem__(self, key):
         assert not isinstance(key, slice), "CachedLookup does not support slicing, " \
                                            "you can slice the result of an index operation instead"
 
-        if torch.is_tensor(key):  # If tensor, idx all tensors by this tensor:
-
-            if self.key is None:
-                self.key = key
-                self.current = self.orig[key]
-            elif len(key) != len(self.key) or (key != self.key).any():
-                self.key = key
-                self.current = self.orig[key]
+        assert torch.is_tensor(key)  # If tensor, idx all tensors by this tensor:
 
-            return self.current
+        if self.key is None:
+            self.key = key
+            self.current = self.orig[key]
+        elif len(key) != len(self.key) or (key != self.key).any():
+            self.key = key
+            self.current = self.orig[key]
 
-        return super(CachedLookup, self).__getitem__(key)
+        return self.current
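For CachedLookup the behaviour is otherwise unchanged: the result of indexing self.orig is cached and only recomputed when a different key tensor is passed; non-tensor keys now trip the new assert instead of reaching the removed super call. A self-contained sketch mirroring the patched __getitem__ (the constructor and the wrapped object are assumptions for illustration; only the __getitem__ body comes from the diff above):

    import torch


    class CachedLookupSketch:
        # Stand-in with an assumed constructor; __getitem__ mirrors the patched code.
        def __init__(self, data):
            self.orig = data
            self.key = None
            self.current = None

        def __getitem__(self, key):
            assert not isinstance(key, slice), "CachedLookup does not support slicing, " \
                                               "you can slice the result of an index operation instead"

            assert torch.is_tensor(key)  # If tensor, idx all tensors by this tensor:

            if self.key is None:
                self.key = key
                self.current = self.orig[key]
            elif len(key) != len(self.key) or (key != self.key).any():
                self.key = key
                self.current = self.orig[key]

            return self.current


    lookup = CachedLookupSketch(torch.arange(10) * 10)
    key = torch.tensor([1, 3])
    print(lookup[key])                   # tensor([10, 30]) -- computed and cached
    print(lookup[key])                   # same key again: served from the cache
    print(lookup[torch.tensor([2])])     # different key: recomputed and re-cached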
