
Commit 6900f7d

pavanbalaji authored and meta-codesync[bot] committed
Pass default values for hints more cleanly
Summary: Reduces the possibility of garbage values.

Reviewed By: siyengar

Differential Revision: D85972057

fbshipit-source-id: 4786a97ef5833eddbb9e26fc05d3e30145a5cce0
1 parent 7475b39 commit 6900f7d
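For readers skimming the diff below, here is a minimal, self-contained sketch of the pattern this commit applies; DemoComm, the bare hint key, and the placement of kDefaultMaxEventPoolSize are illustrative stand-ins, not the actual TorchComm API. Defaults move out of else branches in init() and into default member initializers on a nested config struct, so every hint-configurable field holds a well-defined value even when no hint is supplied.

// Hedged sketch only: illustrative names, not the real TorchComm classes.
#include <cstddef>
#include <map>
#include <string>

constexpr std::size_t kDefaultMaxEventPoolSize = 1000;

class DemoComm {
 public:
  void init(const std::map<std::string, std::string>& hints) {
    // Override the default only when the hint is present; no else branch is
    // needed because the member already carries its default value.
    if (auto it = hints.find("max_event_pool_size"); it != hints.end()) {
      configs_.max_event_pool_size = std::stoull(it->second);
    }
  }

 private:
  struct Configs {
    std::size_t max_event_pool_size{kDefaultMaxEventPoolSize};
  };
  Configs configs_;
};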

File tree

5 files changed: +19 -20 lines changed

comms/torchcomms/nccl/TorchCommNCCL.cpp

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ void TorchCommNCCL::init(
     max_event_pool_size_ =
         std::stoull(options_.hints.at("torchcomm::nccl::max_event_pool_size"));
   } else {
-    max_event_pool_size_ = kMaxEventPoolSize;
+    max_event_pool_size_ = kDefaultMaxEventPoolSize;
   }
 
   // Give up our internal reference to the store object here. The caller

comms/torchcomms/nccl/TorchCommNCCL.hpp

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 namespace torch {
 namespace comms {
 
-constexpr size_t kMaxEventPoolSize = 1000;
+constexpr size_t kDefaultMaxEventPoolSize = 1000;
 
 // Custom exception class for better error handling
 class NCCLException : public std::exception {

comms/torchcomms/ncclx/TorchCommNCCLX.cpp

Lines changed: 3 additions & 9 deletions
@@ -151,25 +151,19 @@ void TorchCommNCCLX::init(
       "Failed to allocate barrier buffer");
 
   if (options_.hints.contains("torchcomm::ncclx::max_event_pool_size")) {
-    max_event_pool_size_ =
+    configs_.max_event_pool_size_ =
         std::stoull(options_.hints.at("torchcomm::ncclx::max_event_pool_size"));
-  } else {
-    max_event_pool_size_ = kMaxEventPoolSize;
   }
 
   if (options_.hints.contains(
           "torchcomm::ncclx::garbage_collect_interval_ms")) {
-    garbage_collect_interval_ms_ = std::stoull(
+    configs_.garbage_collect_interval_ms_ = std::stoull(
         options_.hints.at("torchcomm::ncclx::garbage_collect_interval_ms"));
-  } else {
-    garbage_collect_interval_ms_ = kGarbageCollectIntervalMs;
   }
 
   if (options_.hints.contains("torchcomm::ncclx::enable_cuda_graph_support")) {
-    enable_cuda_graph_support_ = string_to_bool(
+    configs_.enable_cuda_graph_support_ = string_to_bool(
         options_.hints.at("torchcomm::ncclx::enable_cuda_graph_support"));
-  } else {
-    enable_cuda_graph_support_ = kEnableCudaGraphSupport;
   }
 
   // Give up our internal reference to the store object here. The caller
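As a standalone illustration of the hint-override idiom in this hunk (assuming a plain std::map<std::string, std::string> as a stand-in for options_.hints, and omitting the repo's string_to_bool helper), the sketch below shows that a key left out of the hints map simply leaves the default in place, while std::stoull parses any value that is supplied.

// Sketch under the assumptions stated above; not the real TorchCommNCCLX::init().
#include <cstddef>
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, std::string> hints{
      {"torchcomm::ncclx::max_event_pool_size", "256"}};

  std::size_t max_event_pool_size = 1000;         // default already in place
  std::size_t garbage_collect_interval_ms = 100;  // no hint given, stays 100

  if (hints.contains("torchcomm::ncclx::max_event_pool_size")) {
    // std::stoull throws std::invalid_argument if the hint is not numeric.
    max_event_pool_size =
        std::stoull(hints.at("torchcomm::ncclx::max_event_pool_size"));
  }
  if (hints.contains("torchcomm::ncclx::garbage_collect_interval_ms")) {
    garbage_collect_interval_ms = std::stoull(
        hints.at("torchcomm::ncclx::garbage_collect_interval_ms"));
  }

  std::cout << max_event_pool_size << ' ' << garbage_collect_interval_ms
            << '\n';  // prints: 256 100
}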

comms/torchcomms/ncclx/TorchCommNCCLX.hpp

Lines changed: 11 additions & 6 deletions
@@ -28,9 +28,9 @@
 namespace torch {
 namespace comms {
 
-constexpr size_t kMaxEventPoolSize = 1000;
-constexpr size_t kGarbageCollectIntervalMs = 100;
-constexpr bool kEnableCudaGraphSupport = true;
+constexpr size_t kDefaultMaxEventPoolSize = 1000;
+constexpr size_t kDefaultGarbageCollectIntervalMs = 100;
+constexpr bool kDefaultEnableCudaGraphSupport = true;
 
 // Custom exception class for better error handling
 class NCCLException : public std::exception {
@@ -332,9 +332,14 @@ class TorchCommNCCLX : public TorchCommBackend,
   int comm_size_{};
   int rank_{};
   CommOptions options_;
-  size_t max_event_pool_size_{};
-  size_t garbage_collect_interval_ms_{};
-  bool enable_cuda_graph_support_{};
+
+  struct Configs {
+    size_t max_event_pool_size_{kDefaultMaxEventPoolSize};
+    size_t garbage_collect_interval_ms_{kDefaultGarbageCollectIntervalMs};
+    bool enable_cuda_graph_support_{kDefaultEnableCudaGraphSupport};
+  };
+  Configs configs_;
+
   cudaStream_t internal_stream_{};
   void* barrier_buffer_{}; // Pre-allocated CUDA buffer for barrier operations
   enum class InitializationState {
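One point worth noting about this header change: because the new nested Configs struct uses default member initializers, the configuration fields already hold usable values before init() touches them. A tiny sketch follows; DemoConfigs mirrors the shape of the nested struct but is not the real class.

#include <cassert>
#include <cstddef>

constexpr std::size_t kDefaultMaxEventPoolSize = 1000;
constexpr std::size_t kDefaultGarbageCollectIntervalMs = 100;
constexpr bool kDefaultEnableCudaGraphSupport = true;

// Illustrative mirror of the nested Configs aggregate.
struct DemoConfigs {
  std::size_t max_event_pool_size{kDefaultMaxEventPoolSize};
  std::size_t garbage_collect_interval_ms{kDefaultGarbageCollectIntervalMs};
  bool enable_cuda_graph_support{kDefaultEnableCudaGraphSupport};
};

int main() {
  DemoConfigs c;  // defaults hold without any init-style assignment
  assert(c.max_event_pool_size == 1000);
  assert(c.garbage_collect_interval_ms == 100);
  assert(c.enable_cuda_graph_support);
  return 0;
}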

comms/torchcomms/ncclx/TorchCommNCCLXUtils.cpp

Lines changed: 3 additions & 3 deletions
@@ -184,7 +184,7 @@ void TorchCommNCCLX::timeoutWatchdog() noexcept {
     // Wake up either after some time or immediately if shutdown is requested
     timeout_cv_.wait_for(
         lock,
-        std::chrono::milliseconds(garbage_collect_interval_ms_),
+        std::chrono::milliseconds(configs_.garbage_collect_interval_ms_),
         [this]() { return shutdown_.load(); });
 
     // If we're shutting down, exit the loop
@@ -254,7 +254,7 @@ void TorchCommNCCLX::checkAndAbortIfTimedOutOrError() {
 }
 
 bool TorchCommNCCLX::getGraphCaptureMode() {
-  if (!enable_cuda_graph_support_) {
+  if (!configs_.enable_cuda_graph_support_) {
     return false;
   }
 
@@ -430,7 +430,7 @@ cudaEvent_t TorchCommNCCLX::getEvent() {
 void TorchCommNCCLX::returnEvent(cudaEvent_t event) {
   std::lock_guard<std::mutex> lock(event_pool_mutex_);
 
-  if (event_pool_.size() < max_event_pool_size_) {
+  if (event_pool_.size() < configs_.max_event_pool_size_) {
     event_pool_.push(event);
   } else {
     // Pool is full, destroy the event