Skip to content

Commit b317033

Browse files
issue/411: adjust mudnn handle use to prevent mudnn handle mismatching
2 parents 1d06439 + f34d4e3 commit b317033

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/infiniop/devices/moore/moore_common.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ class Handle::Internal {
2121
_block_size[3],
2222
_grid_size[3];
2323

24+
int _device_id;
2425
template <typename T>
2526
using Fn = std::function<infiniStatus_t(T)>;
2627

2728
public:
28-
Internal(int);
29+
Internal(int device_id);
30+
2931
infiniStatus_t useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const;
3032
infiniStatus_t useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const;
3133

src/infiniop/devices/moore/moore_handle.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
1111
return _internal;
1212
}
1313

14-
Handle::Internal::Internal(int device_id) {
14+
Handle::Internal::Internal(int device_id)
15+
: _device_id(device_id) {
1516
musaDeviceProp prop;
1617
musaGetDeviceProperties(&prop, device_id);
1718
_warp_size = prop.warpSize;
@@ -45,7 +46,7 @@ infiniStatus_t Handle::Internal::useMudnn(musaStream_t stream, const Fn<::musa::
4546
if (opt_handle.has_value()) {
4647
handle = std::move(*opt_handle);
4748
} else {
48-
handle = std::make_unique<::musa::dnn::Handle>();
49+
handle = std::make_unique<::musa::dnn::Handle>(_device_id);
4950
}
5051
CHECK_MUDNN(handle->SetStream(stream));
5152
CHECK_STATUS(f(*handle));

0 commit comments

Comments
 (0)