Skip to content

Commit 451c27b

Browse files
authored
fix crash with return alternative on CUDA (OpenNMT#1733)
1 parent 72a461a commit 451c27b

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/decoding.cc

+5
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,14 @@ namespace ctranslate2 {
7070
StorageView& ids) {
7171
if (!decoder.output_layer_is_updated())
7272
return;
73+
ctranslate2::Device device = ids.device();
74+
if (device != Device::CPU)
75+
ids = ids.to(Device::CPU);
7376
auto* ids_data = ids.data<int32_t>();
7477
for (dim_t i = 0; i < ids.size(); ++i)
7578
ids_data[i] = decoder.to_original_word_id(ids_data[i]);
79+
if (ids.device() != device)
80+
ids = ids.to(device);
7681
}
7782

7883
template <typename T>

0 commit comments

Comments
 (0)