From 5b86d32fb0548f077990a78b14a5ac999ac9b53f Mon Sep 17 00:00:00 2001 From: Ruben Date: Fri, 15 Nov 2019 11:47:04 +0100 Subject: [PATCH] Torch.multinomial on GPU --- models/FCModel.py | 4 ++-- models/OldModel.py | 4 ++-- models/ShowTellModel.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/models/FCModel.py b/models/FCModel.py index f6f9a44b..6f43d7f8 100644 --- a/models/FCModel.py +++ b/models/FCModel.py @@ -177,10 +177,10 @@ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): it = it.view(-1).long() else: if temperature == 1.0: - prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) + prob_prev = torch.exp(logprobs.data) # .cpu() # fetch prev distribution: shape Nx(M+1) else: # scale logprobs by temperature - prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() + prob_prev = torch.exp(torch.div(logprobs.data, temperature)) # .cpu() it = torch.multinomial(prob_prev, 1).cuda() sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing diff --git a/models/OldModel.py b/models/OldModel.py index 4a654034..10bd5a31 100644 --- a/models/OldModel.py +++ b/models/OldModel.py @@ -148,10 +148,10 @@ def sample(self, fc_feats, att_feats, opt={}): it = it.view(-1).long() else: if temperature == 1.0: - prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) + prob_prev = torch.exp(logprobs.data) # .cpu() # fetch prev distribution: shape Nx(M+1) else: # scale logprobs by temperature - prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() + prob_prev = torch.exp(torch.div(logprobs.data, temperature)) # .cpu() it = torch.multinomial(prob_prev, 1).cuda() sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing diff --git a/models/ShowTellModel.py b/models/ShowTellModel.py index e466bef7..7f85dc87 100644 --- a/models/ShowTellModel.py +++ b/models/ShowTellModel.py @@ -147,10 +147,10 @@ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): it = it.view(-1).long() else: if temperature == 1.0: - prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) + prob_prev = torch.exp(logprobs.data) # .cpu() # fetch prev distribution: shape Nx(M+1) else: # scale logprobs by temperature - prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() + prob_prev = torch.exp(torch.div(logprobs.data, temperature)) # .cpu() it = torch.multinomial(prob_prev, 1).cuda() sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing