Skip to content

Commit 0e5f524

Browse files
committed
Add wrap_ptr function
1 parent e824b62 commit 0e5f524

File tree

3 files changed

+48
-6
lines changed

3 files changed

+48
-6
lines changed

include/kernel_float/memory.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,18 @@ KERNEL_FLOAT_INLINE vector_ptr<T, N, U, A>& operator+=(vector_ptr<T, N, U, A>& p
626626
return p = p + i;
627627
}
628628

629+
/**
630+
* Creates a `vector_ptr<T, N, U>` from a raw pointer `U*`.
631+
*
632+
* @tparam T The type of the elements as viewed by the user.
633+
* @tparam N The vector size in number of elements.
634+
* @tparam U The type of the elements pointed to by the raw pointer.
635+
*/
636+
template<typename T, size_t N = 1, typename U>
637+
KERNEL_FLOAT_INLINE vector_ptr<T, N, U> wrap_ptr(U* ptr) {
638+
return vector_ptr<T, N, U> {ptr};
639+
}
640+
629641
/**
630642
* Creates a `vector_ptr<T, N>` from a raw pointer `T*` by asserting a specific alignment `N`.
631643
*
@@ -658,10 +670,10 @@ using view_ptr = vector_ptr<T, 1, U, alignof(U)>;
658670

659671
#if defined(__cpp_deduction_guides)
660672
template<typename T>
661-
vector_ptr(T*) -> vector_ptr<T, 1, T>;
673+
vector_ptr(T*) -> vector_ptr<T, 1, T, alignof(T)>;
662674

663675
template<typename T>
664-
vector_ptr(const T*) -> vector_ptr<T, 1, const T>;
676+
vector_ptr(const T*) -> vector_ptr<T, 1, const T, alignof(T)>;
665677

666678
#if __cpp_deduction_guides >= 201907L
667679
template<typename T>

single_include/kernel_float.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2025-09-15 12:44:05.768243
20-
// git hash: 9b41485a27b669ea6f4aae118b4d251947608bf6
19+
// date: 2025-09-15 16:14:44.345265
20+
// git hash: e824b62e2e7d40e70322cae48a0b652fbec3803c
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -3202,6 +3202,18 @@ KERNEL_FLOAT_INLINE vector_ptr<T, N, U, A>& operator+=(vector_ptr<T, N, U, A>& p
32023202
return p = p + i;
32033203
}
32043204

3205+
/**
3206+
* Creates a `vector_ptr<T, N, U>` from a raw pointer `U*`.
3207+
*
3208+
* @tparam T The type of the elements as viewed by the user.
3209+
* @tparam N The vector size in number of elements.
3210+
* @tparam U The type of the elements pointed to by the raw pointer.
3211+
*/
3212+
template<typename T, size_t N = 1, typename U>
3213+
KERNEL_FLOAT_INLINE vector_ptr<T, N, U> wrap_ptr(U* ptr) {
3214+
return vector_ptr<T, N, U> {ptr};
3215+
}
3216+
32053217
/**
32063218
* Creates a `vector_ptr<T, N>` from a raw pointer `T*` by asserting a specific alignment `N`.
32073219
*
@@ -3234,10 +3246,10 @@ using view_ptr = vector_ptr<T, 1, U, alignof(U)>;
32343246

32353247
#if defined(__cpp_deduction_guides)
32363248
template<typename T>
3237-
vector_ptr(T*) -> vector_ptr<T, 1, T>;
3249+
vector_ptr(T*) -> vector_ptr<T, 1, T, alignof(T)>;
32383250

32393251
template<typename T>
3240-
vector_ptr(const T*) -> vector_ptr<T, 1, const T>;
3252+
vector_ptr(const T*) -> vector_ptr<T, 1, const T, alignof(T)>;
32413253

32423254
#if __cpp_deduction_guides >= 201907L
32433255
template<typename T>

tests/memory.cu

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,24 @@ struct vector_ptr_test {
233233
ASSERT_EQ(b3_ptr.get(), static_cast<const U*>(storage.data));
234234
ASSERT_EQ(b4_ptr.get(), static_cast<const U*>(storage.data));
235235
}
236+
237+
{
238+
U* ptr = nullptr;
239+
240+
auto a1 = kf::wrap_ptr<T>(ptr);
241+
ASSERT(std::is_same<decltype(a1), kf::vector_ptr<T, 1, U>>::value);
242+
243+
auto a2 = kf::wrap_ptr<T, 2>(ptr);
244+
ASSERT(std::is_same<decltype(a2), kf::vector_ptr<T, 2, U>>::value);
245+
246+
auto a3 = kf::assert_aligned(ptr);
247+
ASSERT(
248+
std::is_same<decltype(a3), kf::vector_ptr<U, 1, U, KERNEL_FLOAT_MAX_ALIGNMENT>>::
249+
value);
250+
251+
auto a4 = kf::assert_aligned<2>(ptr);
252+
ASSERT(std::is_same<decltype(a4), kf::vector_ptr<U, 2>>::value);
253+
}
236254
}
237255
};
238256

0 commit comments

Comments
 (0)