Commit fc1f9a8

Add memory-efficient trick with gradients
1 parent 7a2d62d commit fc1f9a8

10 files changed (+3914 -31 lines)

.gitignore (+1 -4)

@@ -1,6 +1,3 @@
 /.ipynb_checkpoints
 **/__pycache__
-/lkh_data/*
-/.vscode
-/checkpts/*
-/valsets/*
+/.vscode

attention_dynamic_model.py (+44 -7)

@@ -185,7 +185,32 @@ def get_projections(self, embeddings, context_vectors):

         return K_tanh, Q_context, K, V

-    def forward(self, inputs, return_pi=False):
+
+    def fwd_rein_loss(self, inputs, baseline, bl_vals, num_batch, return_pi=False):
+        """
+        Forward pass plus REINFORCE loss, computed in a memory-efficient way.
+        This sacrifices a bit of runtime but is far better memory-wise: it reorders
+        the terms in the gradient formula so that we don't keep gradients for the
+        whole sequence alive for a long time, which would consume a lot of memory.
+        """
+
+        on_training = self.training
+        self.eval()
+        with torch.no_grad():
+            cost, log_likelihood, seq = self(inputs, True)
+            bl_val = bl_vals[num_batch] if bl_vals is not None else baseline.eval(inputs, cost)
+            pre_cost = cost - bl_val.detach()
+            detached_loss = torch.mean(pre_cost * log_likelihood)
+
+        if on_training: self.train()
+        return detached_loss, self(inputs, return_pi, seq, pre_cost)
+
+    def forward(self, inputs, return_pi=False, pre_selects=None, pre_cost=None):
+        """
+        Forward method. Works as described in the paper. If pre_selects is not None
+        (which implies that pre_cost is not None either), fwd_rein_loss is calling
+        it; see that method for a description of why this is useful.
+        """

         self.batch_size = inputs[0].shape[0]

@@ -194,7 +219,10 @@ def forward(self, inputs, return_pi=False):
         sequences = []
         ll = torch.zeros(self.batch_size)

+        if pre_selects is not None:
+            pre_selects = pre_selects.transpose(0, 1)
         # Perform decoding steps
+        pre_select_idx = 0
         while not state.all_finished():

             state.i = torch.zeros(1, dtype=torch.int64)

@@ -222,22 +250,31 @@ def forward(self, inputs, return_pi=False):
             log_p = self.get_log_p(mha, K_tanh, mask)  # (batch_size, 1, n_nodes)

             # next step is to select node
-            selected = self._select_node(log_p.detach())  # (batch_size,)
+            if pre_selects is None:
+                selected = self._select_node(log_p.detach())  # (batch_size,)
+            else:
+                selected = pre_selects[pre_select_idx]

             state.step(selected.detach().cpu())

-            ll += self.get_likelihood_selection(log_p[:, 0, :].cpu(), selected.detach().cpu())
-
+            curr_ll = self.get_likelihood_selection(log_p[:, 0, :].cpu(), selected.detach().cpu())
+            if pre_selects is not None:
+                curr_loss = (curr_ll * pre_cost).sum() / self.batch_size
+                curr_loss.backward(retain_graph=True)
+                curr_ll = curr_ll.detach()
+            ll += curr_ll
+
             sequences.append(selected.detach().cpu())
+            pre_select_idx += 1
             # torch.cuda.empty_cache()
             # torch.cuda.empty_cache()

         pi = torch.stack(sequences, dim=1)  # (batch_size, len(outputs))
         cost = self.problem.get_costs((inputs[0].detach().cpu(), inputs[1].detach().cpu(), inputs[2].detach().cpu()), pi)
-        if return_pi:
-            return cost, ll, pi

-        return cost, ll
+        ret = [cost, ll]
+        if return_pi: ret.append(pi)
+        return ret

     def set_input_device(self, inp_tens):
         if self.dev is None: self.dev = get_dev_of_mod(self)
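
The docstring above only sketches the trick in words, so here is a minimal self-contained illustration of the same idea. Because the advantage (cost - baseline) is detached, the gradient of mean((cost - baseline) * sum_t log p_t) equals the sum over steps of the gradient of mean((cost - baseline) * log p_t): one can sample the tour under no_grad, then replay it and call backward once per decoding step, accumulating the identical total gradient while only ever holding one step's graph in memory. TinyPolicy and every other name below are illustrative stand-ins, not the repository's API.

    import torch
    import torch.nn as nn

    # Hypothetical toy policy, for illustration only: scores the
    # not-yet-visited nodes given a (batch, n_nodes) 0/1 visit mask.
    class TinyPolicy(nn.Module):
        def __init__(self, n_nodes=10, hidden=32):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(n_nodes, hidden), nn.ReLU(),
                                     nn.Linear(hidden, n_nodes))

        def forward(self, visited):
            logits = self.net(visited).masked_fill(visited.bool(), float('-inf'))
            return torch.log_softmax(logits, dim=-1)

    torch.manual_seed(0)
    batch, n_nodes = 4, 10
    policy = TinyPolicy(n_nodes)

    # Pass 1 (no_grad): sample the whole tour without building any graph.
    with torch.no_grad():
        visited = torch.zeros(batch, n_nodes)
        selections = []
        for _ in range(n_nodes):
            probs = policy(visited).exp()
            sel = probs.multinomial(1).squeeze(1)          # (batch,)
            selections.append(sel)
            visited = visited.scatter(1, sel.unsqueeze(1), 1.0)
    advantage = torch.randn(batch)  # stand-in for detached (cost - baseline)

    # Pass 2: replay the same selections, calling backward per step so each
    # step's graph is freed immediately; .grad accumulates the total gradient.
    visited = torch.zeros(batch, n_nodes)
    for sel in selections:
        log_p = policy(visited)                            # fresh graph for this step
        step_ll = log_p.gather(1, sel.unsqueeze(1)).squeeze(1)
        ((advantage * step_ll).sum() / batch).backward()   # frees this step's graph
        visited = visited.scatter(1, sel.unsqueeze(1), 1.0)
    # policy's .grad now equals one big backward over
    # mean((cost - baseline) * sum_t log p_t), at a fraction of the peak memory.

In the commit itself this per-step backward is the curr_loss.backward(retain_graph=True) call; retain_graph=True appears necessary because every decoding step backpropagates through the same shared encoder projections, while each step's private graph is still released as soon as the loop drops its references to it.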

backup_results_VRP_20_2021-08-31.csv (+41)

@@ -0,0 +1,41 @@
+epochs,train_loss,train_cost,val_cost
+0,0.057725433,8.126152,7.142528
+1,0.43523985,7.0198774,6.8331623
+2,-0.05244048,6.799927,6.692892
+3,-0.21513543,6.7059703,6.6481056
+4,-0.16910644,6.6476564,6.585553
+5,-0.21084756,6.606468,6.55772
+6,-0.20123357,6.5790243,6.533435
+7,-0.20451881,6.5594954,6.523542
+8,-0.17670833,6.5427947,6.505786
+9,-0.17904146,6.530585,6.4873476
+10,-0.19809744,6.519344,6.4940104
+11,-0.15416735,6.508296,6.478488
+12,-0.16621172,6.4995356,6.4869556
+13,-0.13755234,6.4940825,6.47636
+14,-0.11532383,6.4876075,6.460555
+15,-0.14074893,6.4802,6.4586205
+16,-0.13124135,6.474483,6.4573145
+17,-0.11284375,6.4697905,6.4460807
+18,-0.12861899,6.464823,6.437971
+19,-0.13243529,6.459164,6.4339194
+20,-0.13290825,6.4535875,6.438004
+21,-0.12190188,6.44997,6.443571
+22,-0.11239375,6.447536,6.420986
+23,-0.12801744,6.443129,6.431984
+24,-0.11823546,6.440166,6.42777
+25,-0.107929625,6.436949,6.418355
+26,-0.09892246,6.4335003,6.4274316
+27,-0.09260898,6.430896,6.412428
+28,-0.11238246,6.429055,6.4071865
+29,-0.11275884,6.424093,6.4058414
+30,-0.11065548,6.4216547,6.401159
+31,-0.10574161,6.4185834,6.3972263
+32,-0.1168384,6.415805,6.399093
+33,-0.11321704,6.4131637,6.389376
+34,-0.1172516,6.4110584,6.393464
+35,-0.11503467,6.410205,6.3980913
+36,-0.109255955,6.4086895,6.392976
+37,-0.10428008,6.4064713,6.39075
+38,-0.0974484,6.4040713,6.3854094
+39,-0.108119674,6.403427,6.39277

checkpts/.gitignore (+3)

@@ -0,0 +1,3 @@
+*
+*/
+!.gitignore
