Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCT: Derived memh #10332

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/uct/api/v2/uct_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ typedef struct {
typedef enum {
UCT_MD_MEM_REG_FIELD_FLAGS = UCS_BIT(0),
UCT_MD_MEM_REG_FIELD_DMABUF_FD = UCS_BIT(1),
UCT_MD_MEM_REG_FIELD_DMABUF_OFFSET = UCS_BIT(2)
UCT_MD_MEM_REG_FIELD_DMABUF_OFFSET = UCS_BIT(2),
UCT_MD_MEM_REG_FIELD_MEMH = UCS_BIT(3)
} uct_md_mem_reg_field_mask_t;


Expand Down Expand Up @@ -480,6 +481,18 @@ typedef struct uct_md_mem_reg_params {
* dmabuf region, then this field must be omitted or set to 0.
*/
size_t dmabuf_offset;

/**
* Represents a pointer to the existing memory handle.
* Used to register a derived memory handle: a shallow copy of existing UCT
* memory handle, which can be used to access the same memory region. When
* created, the derived memh inherits the original memh access flags and
* state. The lifetime of the derived memh is bound to the original memh,
* and the original memh cannot be destroyed until all its derived handles
* are destroyed. The derived memh cannot be used to register another
* derived memh.
*/
uct_mem_h memh;
} uct_md_mem_reg_params_t;


Expand Down
8 changes: 8 additions & 0 deletions src/uct/cuda/cuda_ipc/cuda_ipc_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,17 @@ static ucs_status_t
uct_cuda_ipc_mem_reg(uct_md_h md, void *address, size_t length,
const uct_md_mem_reg_params_t *params, uct_mem_h *memh_p)
{
uct_mem_h base = (params != NULL) ?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

params is not checked to be non-NULL in other places. I think we can remove this check here as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, initially I didn't have this check, added it to fix NULL pointer crash in CI

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you remember the failed test? Looks like an issue in caller function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember it, but I just see my commit named "Fix NPE" that I've made to address that failure: 53a4b08

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, found and fixed this NULL pointer error in gtest: b193dea

UCT_MD_MEM_REG_FIELD_VALUE(params, memh, FIELD_MEMH, NULL) :
NULL;
uct_cuda_ipc_memh_t *memh;
CUdevice cu_device;

if (ENABLE_PARAMS_CHECK && (base != NULL)) {
ucs_error("CUDA IPC does not support derived memory handles");
return UCS_ERR_UNSUPPORTED;
}

UCT_CUDA_IPC_GET_DEVICE(cu_device);

memh = ucs_malloc(sizeof(*memh), "uct_cuda_ipc_memh_t");
Expand Down
41 changes: 37 additions & 4 deletions src/uct/ib/base/ib_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -584,19 +584,35 @@ ucs_status_t uct_ib_mem_advise(uct_md_h uct_md, uct_mem_h memh, void *addr,
return UCS_OK;
}

ucs_status_t uct_ib_memh_alloc(uct_ib_md_t *md, size_t length,
unsigned mem_flags, size_t memh_base_size,
size_t mr_size, uct_ib_mem_t **memh_p)
static uct_ib_mem_t *
uct_ib_memh_alloc_internal(uct_ib_md_t *md, size_t memh_base_size,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we choose better names? Instead of the pair uct_ib_memh_alloc_internal + uct_ib_memh_alloc, use uct_ib_memh_alloc + uct_ib_memh_new/init.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uct_ib_memh_alloc is an existing function, so we don't want to change its name
I just extracted some common part from uct_ib_memh_alloc into _internal so that it can be reused by uct_ib_memh_clone.

Maybe we name it uct_ib_memh_alloc_common?

size_t mr_size, size_t *memh_size_p)
{
int num_mrs = md->relaxed_order ?
2 /* UCT_IB_MR_DEFAULT and UCT_IB_MR_STRICT_ORDER */ :
1 /* UCT_IB_MR_DEFAULT */;
uct_ib_mem_t *memh;

memh = ucs_calloc(1, memh_base_size + (mr_size * num_mrs), "ib_memh");
*memh_size_p = memh_base_size + (mr_size * num_mrs);
memh = ucs_calloc(1, *memh_size_p, "ib_memh");
if (memh == NULL) {
ucs_error("%s: failed to allocated memh struct",
uct_ib_device_name(&md->dev));
return NULL;
}

return memh;
}

ucs_status_t uct_ib_memh_alloc(uct_ib_md_t *md, size_t length,
unsigned mem_flags, size_t memh_base_size,
size_t mr_size, uct_ib_mem_t **memh_p)
{
uct_ib_mem_t *memh;
size_t memh_size;

memh = uct_ib_memh_alloc_internal(md, memh_base_size, mr_size, &memh_size);
if (memh == NULL) {
return UCS_ERR_NO_MEMORY;
}

Expand Down Expand Up @@ -626,6 +642,23 @@ ucs_status_t uct_ib_memh_alloc(uct_ib_md_t *md, size_t length,
return UCS_OK;
}

ucs_status_t uct_ib_memh_clone(uct_ib_md_t *md, const uct_ib_mem_t *src,
size_t memh_base_size, size_t mr_size,
uct_ib_mem_t **memh_p)
{
uct_ib_mem_t *memh;
size_t memh_size;

memh = uct_ib_memh_alloc_internal(md, memh_base_size, mr_size, &memh_size);
if (memh == NULL) {
return UCS_ERR_NO_MEMORY;
}

memcpy(memh, src, memh_size);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. clone seems wrong - if the derived memh is shallow, why do we need to fully copy the original memh?
  2. seems weird that we use calloc() and then override everything with memcpy

Copy link
Contributor Author

@iyastreb iyastreb Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Answering in the opposite order:
2. Right, here malloc would be enough but I'm reusing existing function uct_ib_memh_alloc_internal that does calloc and calculates the size, just to reuse the common code. I think the overhead is minimal.
We could split size calculation into a separate function etc, but I think it's not worth the effort

  1. Ok, maybe it's a terminology issue here
    My understanding is:
    Deep copy creates an independent instance of an object, that can be used apart from the original.
    Shallow copy creates an "alias" that depends on master copy and cannot be used separately.
    The latter is what's implemented in this PR, whatever we name it.

Derived memh is a shallow copy, because it makes a shallow copy of MRs - the most important part of original memh. And original object remains the only owner of the MRs state.
Of course there could be different implementations of derived memh. For example, we can imagine a shallow copy looking like that:

struct {
   uct_ib_mlx5_devx_mem_t *base;
   // derived specific fields
} uct_derived_memh;

Looks ok, right? Still you need to allocate this object.
What are the issues with this approach:

  • There are places where we assume that memh always has the uct_ib_mem_t base: uct_rc_mlx5_txqp_tag_inline_post: ((uct_ib_mem_t*)iov->memh)
    So we must add uct_ib_mem_t super field to our copy
  • The copy must handle a set of rkeys independently from the original memh, meaning that the following fields also need to be copied:
    struct mlx5dv_devx_obj      *atomic_dvmr;
    struct mlx5dv_devx_obj      *indirect_dvmr;
    uint32_t                    atomic_rkey;
    uint32_t                    indirect_rkey;

This is just to show you that we need to duplicate the significant part of the original memh anyway.

  • Then in each and every place we need to check whether the passed memh is derived or original (because their layout is different), adding a lot of boilerplate code and CPU overhead.
  • Then we need to modify quite some existing functions for key generation, because they cache keys in the original object.

To overcome all these issues I make a shallow copy of the original memh (== they have the same memory layout), and use it interchangeably in all the places. This allows me to minimize the amount of ifs in the code, basically we check for derived handle only during init and cleanup. This helps to avoid any refactoring in the existing functions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to brainstorm the approaches. I don't like having structs where only some fields are used; maybe we need a deeper refactoring of the IB memh structure.

*memh_p = memh;
return UCS_OK;
}

uint64_t uct_ib_memh_access_flags(uct_ib_mem_t *memh, int relaxed_order,
uint64_t access_flags)
{
Expand Down
6 changes: 6 additions & 0 deletions src/uct/ib/base/ib_md.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ enum {
#endif
UCT_IB_MEM_FLAG_GVA = UCS_BIT(5), /**< The memory handle is a
GVA region */
UCT_IB_MEM_FLAG_DERIVED = UCS_BIT(6), /**< The memory handle is a
derived memh */
};

enum {
Expand Down Expand Up @@ -432,4 +434,8 @@ ucs_status_t uct_ib_memh_alloc(uct_ib_md_t *md, size_t length,
unsigned mem_flags, size_t memh_base_size,
size_t mr_size, uct_ib_mem_t **memh_p);

ucs_status_t uct_ib_memh_clone(uct_ib_md_t *md, const uct_ib_mem_t *src,
size_t memh_base_size, size_t mr_size,
uct_ib_mem_t **memh_p);

#endif
59 changes: 59 additions & 0 deletions src/uct/ib/mlx5/dv/ib_mlx5dv_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,26 @@ uct_ib_mlx5_devx_memh_alloc(uct_ib_mlx5_md_t *md, size_t length,
return UCS_OK;
}

static ucs_status_t
uct_ib_mlx5_devx_memh_clone(uct_ib_mlx5_md_t *md,
const uct_ib_mlx5_devx_mem_t *src,
uct_ib_mlx5_devx_mem_t **memh_p)
{
size_t mr_size = src->super.flags & UCT_IB_MEM_IMPORTED ?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

( )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

0 : sizeof(src->mrs[0]);
uct_ib_mem_t *ib_memh;
ucs_status_t status;

status = uct_ib_memh_clone(&md->super, &src->super, sizeof(**memh_p),
mr_size, &ib_memh);
if (status != UCS_OK) {
return status;
}

*memh_p = ucs_derived_of(ib_memh, uct_ib_mlx5_devx_mem_t);
return UCS_OK;
}

static int
uct_ib_mlx5_devx_memh_has_ro(uct_ib_mlx5_md_t *md, uct_ib_mlx5_devx_mem_t *memh)
{
Expand Down Expand Up @@ -837,13 +857,42 @@ uct_ib_mlx5_devx_mem_reg_gva(uct_md_h uct_md, unsigned flags, uct_mem_h *memh_p)
return status;
}

/*
 * Create a derived memory handle from an existing DEVX memory handle.
 *
 * The new handle is produced by uct_ib_mlx5_devx_memh_clone() from @a base,
 * flagged with UCT_IB_MEM_FLAG_DERIVED, and has its atomic/indirect key
 * objects cleared and the matching rkeys set to UCT_IB_INVALID_MKEY.
 *
 * @param uct_md  Memory domain; must be a uct_ib_mlx5_md_t.
 * @param base    Original memory handle to derive from. It must not itself
 *                be a derived handle (checked by assertion only).
 * @param memh_p  Filled with the derived handle on success.
 *
 * @return UCS_OK on success, or the error returned by the clone operation.
 */
static ucs_status_t
uct_ib_mlx5_devx_derived_mem_reg(uct_md_h uct_md, uct_ib_mlx5_devx_mem_t *base,
                                 uct_mem_h *memh_p)
{
    uct_ib_mlx5_md_t *mlx5_md = ucs_derived_of(uct_md, uct_ib_mlx5_md_t);
    uct_ib_mlx5_devx_mem_t *derived;
    ucs_status_t rc;

    /* A derived memh cannot be the base of another derived memh */
    ucs_assertv(!(base->super.flags & UCT_IB_MEM_FLAG_DERIVED),
                "memh=%p is already a derived memh", base);

    rc = uct_ib_mlx5_devx_memh_clone(mlx5_md, base, &derived);
    if (rc == UCS_OK) {
        /* Drop the key state copied from the base handle, so the derived
         * handle starts without atomic/indirect keys of its own */
        derived->indirect_rkey = UCT_IB_INVALID_MKEY;
        derived->indirect_dvmr = NULL;
        derived->atomic_rkey   = UCT_IB_INVALID_MKEY;
        derived->atomic_dvmr   = NULL;
        derived->super.flags  |= UCT_IB_MEM_FLAG_DERIVED;

        *memh_p = derived;
        return UCS_OK;
    }

    ucs_error("%s: failed to clone memory handle: %s",
              uct_ib_mlx5_dev_name(mlx5_md), ucs_status_string(rc));
    return rc;
}

ucs_status_t
uct_ib_mlx5_devx_mem_reg(uct_md_h uct_md, void *address, size_t length,
const uct_md_mem_reg_params_t *params,
uct_mem_h *memh_p)
{
uct_ib_mlx5_md_t *md = ucs_derived_of(uct_md, uct_ib_mlx5_md_t);
unsigned flags = UCT_MD_MEM_REG_FIELD_VALUE(params, flags, FIELD_FLAGS, 0);
uct_mem_h base = UCT_MD_MEM_REG_FIELD_VALUE(params, memh, FIELD_MEMH, NULL);
uct_ib_mlx5_devx_mem_t *memh;
ucs_status_t status;
uint32_t dummy_mkey;
Expand All @@ -852,6 +901,10 @@ uct_ib_mlx5_devx_mem_reg(uct_md_h uct_md, void *address, size_t length,
return uct_ib_mlx5_devx_mem_reg_gva(uct_md, flags, memh_p);
}

if (base != NULL) {
return uct_ib_mlx5_devx_derived_mem_reg(uct_md, base, memh_p);
}

status = uct_ib_mlx5_devx_memh_alloc(md, length, flags,
sizeof(memh->mrs[0]), &memh);
if (status != UCS_OK) {
Expand Down Expand Up @@ -1509,6 +1562,11 @@ uct_ib_mlx5_devx_mem_dereg(uct_md_h uct_md,
return status;
}

/* Derived memh owns only indirect keys, but not the other state */
if (memh->super.flags & UCT_IB_MEM_FLAG_DERIVED) {
goto out;
}

if (memh->smkey_mr != NULL) {
ucs_trace("%s: destroy smkey_mr %p with key %x",
uct_ib_device_name(&md->super.dev), memh->smkey_mr,
Expand Down Expand Up @@ -1567,6 +1625,7 @@ uct_ib_mlx5_devx_mem_dereg(uct_md_h uct_md,
uct_invoke_completion(params->comp, UCS_OK);
}

out:
ucs_free(memh);
return UCS_OK;
}
Expand Down
56 changes: 50 additions & 6 deletions test/gtest/uct/ib/test_ib_md.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ class test_ib_md : public test_md
void test_mkey_pack_mt_internal(unsigned access_mask, bool invalidate);
void test_smkey_reg_atomic(void);

/**
 * Register a derived memory handle on top of @a base using
 * uct_md_mem_reg_v2() with only the MEMH field set in the parameters.
 * No address is passed and the length is SIZE_MAX.
 *
 * @param base  Existing memory handle to derive from.
 *
 * @return The newly registered derived memory handle.
 */
uct_mem_h reg_derived_mem(uct_mem_h base) const {
    uct_md_mem_reg_params_t reg_params;
    uct_mem_h derived;

    reg_params.memh       = base;
    reg_params.field_mask = UCT_MD_MEM_REG_FIELD_MEMH;

    ASSERT_UCS_OK(uct_md_mem_reg_v2(md(), NULL, SIZE_MAX, &reg_params,
                                    &derived));
    return derived;
}

private:
#ifdef HAVE_MLX5_DV
uint32_t m_mlx5_flags = 0;
Expand Down Expand Up @@ -272,12 +281,7 @@ void test_ib_md::test_mkey_pack_mt_internal(unsigned access_mask,
uct_ib_mem_t *ib_memh = (uct_ib_mem_t*)memh;
EXPECT_TRUE(ib_memh->flags & UCT_IB_MEM_MULTITHREADED);

std::vector<uint8_t> rkey(md_attr().rkey_packed_size);
uct_md_mkey_pack_params_t pack_params;
pack_params.field_mask = UCT_MD_MKEY_PACK_FIELD_FLAGS;
pack_params.flags = pack_flags;
ASSERT_UCS_OK(uct_md_mkey_pack_v2(md(), memh, buffer, size,
&pack_params, rkey.data()));
mkey_pack(memh, pack_flags, buffer, size);

uct_md_mem_dereg_params_t params;
params.field_mask = UCT_MD_MEM_DEREG_FIELD_MEMH |
Expand Down Expand Up @@ -378,6 +382,46 @@ UCS_TEST_P(test_ib_md, mt_fail, "IB_REG_MT_THRESH=128K", "IB_REG_MT_CHUNK=16K")
}
}

/*
 * Checks derived memory handles: they can be created before and after the base
 * memh packed its mkey, they pack rkeys that differ from the base and from
 * each other, repeated packing is stable, and the base memh remains usable
 * after the derived handles are deregistered.
 */
UCS_TEST_SKIP_COND_P(test_ib_md, derived_mem,
                     !check_invalidate_support(UCT_MD_MEM_ACCESS_RMA))
{
    bool is_atomic    = check_caps(UCT_MD_FLAG_INVALIDATE_AMO);
    unsigned flags    = UCT_MD_MKEY_PACK_FLAG_INVALIDATE_RMA |
                        (is_atomic ? UCT_MD_MKEY_PACK_FLAG_INVALIDATE_AMO : 0);
    unsigned md_flags = UCT_MD_MEM_ACCESS_RMA |
                        (is_atomic ? UCT_MD_MEM_ACCESS_REMOTE_ATOMIC : 0);
    std::vector<uint8_t> buffer(1024);
    uct_mem_h base;

    /* Registration failure must stop the test: every step below uses 'base',
     * which would otherwise be left uninitialized */
    ASSERT_UCS_OK(reg_mem(md_flags, buffer.data(), buffer.size(), &base));

    /* Test case 1: creating derived memh from memh before mkey_pack */
    uct_mem_h der1 = reg_derived_mem(base);

    /* Test case 2: creating derived memh from memh after mkey_pack */
    std::vector<uint8_t> base_rkey1 = mkey_pack(base, flags);
    uct_mem_h der2                  = reg_derived_mem(base);
    std::vector<uint8_t> der2_rkey1 = mkey_pack(der2, flags);
    EXPECT_NE(base_rkey1, der2_rkey1);

    /* Test case 3: subsequent mkey_pack calls return the same result */
    std::vector<uint8_t> der2_rkey2 = mkey_pack(der2, flags);
    EXPECT_EQ(der2_rkey1, der2_rkey2);

    /* Test case 4: multiple derived memhs do not share the same rkeys */
    std::vector<uint8_t> der1_rkey1 = mkey_pack(der1, flags);
    EXPECT_NE(der1_rkey1, der2_rkey1);

    /* Invalidation = destroying derived memh */
    EXPECT_UCS_OK(uct_md_mem_dereg(md(), der1));
    EXPECT_UCS_OK(uct_md_mem_dereg(md(), der2));

    /* Test case 5: base memh can still be used to pack mkeys */
    std::vector<uint8_t> base_rkey2 = mkey_pack(base, flags);
    EXPECT_EQ(base_rkey1, base_rkey2);

    EXPECT_UCS_OK(uct_md_mem_dereg(md(), base));
}

_UCT_MD_INSTANTIATE_TEST_CASE(test_ib_md, ib)

class test_ib_md_non_blocking : public test_md_non_blocking {
Expand Down
16 changes: 2 additions & 14 deletions test/gtest/uct/test_md.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,7 @@ void test_md::test_reg_mem(unsigned access_mask,
status = uct_md_mem_dereg_v2(md(), &params);
ASSERT_UCS_STATUS_EQ(UCS_ERR_INVALID_PARAM, status);

std::vector<uint8_t> rkey(md_attr().rkey_packed_size);
uct_md_mkey_pack_params_t pack_params;
pack_params.field_mask = UCT_MD_MKEY_PACK_FIELD_FLAGS;
pack_params.flags = invalidate_flag;
status = uct_md_mkey_pack_v2(md(), memh, ptr, size, &pack_params,
rkey.data());
EXPECT_UCS_OK(status);
mkey_pack(memh, invalidate_flag, ptr, size);

status = uct_md_mem_dereg_v2(md(), &params);
}
Expand Down Expand Up @@ -963,13 +957,7 @@ UCS_TEST_SKIP_COND_P(test_md, exported_mkey,
status = reg_mem(UCT_MD_MEM_ACCESS_ALL, address, size, &export_memh);
ASSERT_UCS_OK(status);

std::vector<uint8_t> mkey_buffer(md_attr().exported_mkey_packed_size);
uct_md_mkey_pack_params_t pack_params;
pack_params.field_mask = UCT_MD_MKEY_PACK_FIELD_FLAGS;
pack_params.flags = UCT_MD_MKEY_PACK_FLAG_EXPORT;
status = uct_md_mkey_pack_v2(md(), export_memh, address, size, &pack_params,
mkey_buffer.data());
ASSERT_UCS_OK(status);
mkey_pack(export_memh, UCT_MD_MKEY_PACK_FLAG_EXPORT, address, size);

uct_md_mem_dereg_params_t dereg_params;
dereg_params.field_mask = UCT_MD_MEM_DEREG_FIELD_MEMH;
Expand Down
15 changes: 15 additions & 0 deletions test/gtest/uct/test_md.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ class test_md : public testing::TestWithParam<test_md_param>,
return m_md_attr;
}

/**
 * Pack a memory key for @a memh with uct_md_mkey_pack_v2() and return the
 * packed bytes. A packing error is reported as a non-fatal test failure, and
 * the buffer is returned regardless.
 *
 * @param memh   Memory handle to pack.
 * @param flags  Pack flags passed through to the pack parameters. If
 *               UCT_MD_MKEY_PACK_FLAG_EXPORT is set, the buffer is sized by
 *               exported_mkey_packed_size, otherwise by rkey_packed_size.
 * @param ptr    Address passed to the pack call.
 * @param size   Length passed to the pack call.
 *
 * @return Zero-initialized buffer of the selected size, filled by the pack
 *         call.
 */
std::vector<uint8_t>
mkey_pack(uct_mem_h memh, unsigned flags = 0, void *ptr = NULL,
          size_t size = SIZE_MAX) const {
    size_t packed_size;
    if (flags & UCT_MD_MKEY_PACK_FLAG_EXPORT) {
        packed_size = md_attr().exported_mkey_packed_size;
    } else {
        packed_size = md_attr().rkey_packed_size;
    }

    uct_md_mkey_pack_params_t params;
    params.flags      = flags;
    params.field_mask = UCT_MD_MKEY_PACK_FIELD_FLAGS;

    std::vector<uint8_t> packed(packed_size, 0);
    EXPECT_UCS_OK(uct_md_mkey_pack_v2(md(), memh, ptr, size, &params,
                                      packed.data()));
    return packed;
}

typedef struct {
test_md *self;
uct_completion_t comp;
Expand Down
Loading