Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/utils/libfabric/libfabric_rail_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ nixlLibfabricRailManager::createDataRails(const std::vector<std::string> &fabric
fabric_devices[i], provider_name, static_cast<uint16_t>(i)));

// Initialize fabric device mapping
device_to_rail_map[fabric_devices[i]] = i;
device_to_rail_map[fabric_devices[i]].push_back(i);

NIXL_DEBUG << "Created data rail " << i << " (device: " << fabric_devices[i]
<< ", provider: " << provider_name << ")";
Expand Down Expand Up @@ -328,19 +328,22 @@ nixlLibfabricRailManager::selectRailsForMemory(void *mem_addr,
for (const std::string &device_name : gpu_nics) {
auto it = device_to_rail_map.find(device_name);
if (it != device_to_rail_map.end()) {
// Bounds check: ensure rail index is valid
if (it->second < data_rails_.size()) {
gpu_rails.push_back(it->second);
NIXL_DEBUG << "VRAM memory " << mem_addr << " on GPU " << gpu_id
<< " mapped to rail " << it->second << " (fabric device: " << device_name
<< ")";
} else {
NIXL_WARN << "Fabric device " << device_name << " maps to rail " << it->second
<< " but only " << data_rails_.size() << " rails available";
const std::vector<size_t> &rails = it->second;
for (size_t rail_id : rails) {
// Bounds check: ensure rail index is valid
if (rail_id < data_rails_.size()) {
gpu_rails.push_back(rail_id);
NIXL_DEBUG << "VRAM memory " << mem_addr << " on GPU " << gpu_id
<< " mapped to rail " << rail_id
<< " (fabric device: " << device_name << ")";
} else {
NIXL_WARN << "Fabric device " << device_name << " maps to rail " << rail_id
<< " but only " << data_rails_.size() << " rails available";
}
}
} else {
NIXL_WARN << "Fabric device " << device_name << " not found in rail mapping for GPU "
<< gpu_id;
NIXL_WARN << "Fabric device " << device_name
<< " not found in rail mapping for GPU " << gpu_id;
}
}

Expand Down Expand Up @@ -420,7 +423,8 @@ nixlLibfabricRailManager::registerMemory(void *buffer,

struct fid_mr *mr;
uint64_t key;
nixl_status_t status = data_rails_[rail_idx]->registerMemory(buffer, length, hmem_hint, gpu_id, &mr, &key);
nixl_status_t status =
data_rails_[rail_idx]->registerMemory(buffer, length, hmem_hint, gpu_id, &mr, &key);
if (status != NIXL_SUCCESS) {
NIXL_ERROR << "Failed to register memory on rail " << rail_idx;
// Cleanup already registered MRs
Expand Down
2 changes: 1 addition & 1 deletion src/utils/libfabric/libfabric_rail_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class nixlLibfabricRailManager {
std::unique_ptr<nixlLibfabricTopology> topology;

// Fabric device to rail mapping
std::unordered_map<std::string, size_t> device_to_rail_map;
std::unordered_map<std::string, std::vector<size_t>> device_to_rail_map;

// Active Rail Tracking System
std::unordered_set<size_t> active_rails_;
Expand Down
Loading