Skip to content

Commit e17ed5c

Browse files
kollasbfacebook-github-bot
authored andcommitted
Hipify Pulsar for PyTorch3D
Summary: - Hipified Pytorch Pulsar - Created separate target for Pulsar tests and enabled RE testing - Pytorch3D full test suite requires additional work like fixing EGL dependencies on AMD Reviewed By: danzimm Differential Revision: D61339912 fbshipit-source-id: 0d10bc966e4de4a959f3834a386bad24e449dc1f
1 parent 8ed0c7a commit e17ed5c

23 files changed

+26
-18
lines changed

pytorch3d/csrc/ext.cpp

-6
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,11 @@
77
*/
88

99
// clang-format off
10-
#if !defined(USE_ROCM)
1110
#include "./pulsar/global.h" // Include before <torch/extension.h>.
12-
#endif
1311
#include <torch/extension.h>
1412
// clang-format on
15-
#if !defined(USE_ROCM)
1613
#include "./pulsar/pytorch/renderer.h"
1714
#include "./pulsar/pytorch/tensor_util.h"
18-
#endif
1915
#include "ball_query/ball_query.h"
2016
#include "blending/sigmoid_alpha_blend.h"
2117
#include "compositing/alpha_composite.h"
@@ -104,7 +100,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
104100

105101
// Pulsar.
106102
// Pulsar not enabled on AMD.
107-
#if !defined(USE_ROCM)
108103
#ifdef PULSAR_LOGGING_ENABLED
109104
c10::ShowLogInfoToStderr();
110105
#endif
@@ -189,5 +184,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
189184
m.attr("MAX_UINT") = py::int_(MAX_UINT);
190185
m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
191186
m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
192-
#endif
193187
}

pytorch3d/csrc/pulsar/global.h

+4
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
#pragma nv_diag_suppress 2951
3737
#pragma nv_diag_suppress 2967
3838
#else
39+
#if !defined(USE_ROCM)
3940
#pragma diag_suppress = attribute_not_allowed
4041
#pragma diag_suppress = 1866
4142
#pragma diag_suppress = 2941
4243
#pragma diag_suppress = 2951
4344
#pragma diag_suppress = 2967
45+
#endif //! USE_ROCM
4446
#endif
4547
#else // __CUDACC__
4648
#define INLINE inline
@@ -56,7 +58,9 @@
5658
#pragma clang diagnostic pop
5759
#ifdef WITH_CUDA
5860
#include <ATen/cuda/CUDAContext.h>
61+
#if !defined(USE_ROCM)
5962
#include <vector_functions.h>
63+
#endif //! USE_ROCM
6064
#else
6165
#ifndef cudaStream_t
6266
typedef void* cudaStream_t;
File renamed without changes.

pytorch3d/csrc/pulsar/cuda/commands.h pytorch3d/csrc/pulsar/gpu/commands.h

+8
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ getLastCudaError(const char* errorMessage, const char* file, const int line) {
5959
#define SHARED __shared__
6060
#define ACTIVEMASK() __activemask()
6161
#define BALLOT(mask, val) __ballot_sync((mask), val)
62+
63+
/* TODO (ROCM-6.2): None of the WARP_* are used anywhere and ROCM-6.2 natively
64+
* supports __shfl_*. Disabling until the move to ROCM-6.2.
65+
*/
66+
#if !defined(USE_ROCM)
6267
/**
6368
* Find the cumulative sum within a warp up to the current
6469
* thread lane, with each mask thread contributing base.
@@ -115,6 +120,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
115120
ret.z = WARP_SUM(group, mask, base.z);
116121
return ret;
117122
}
123+
#endif //! USE_ROCM
118124

119125
// Floating point.
120126
// #define FMUL(a, b) __fmul_rn((a), (b))
@@ -142,6 +148,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
142148
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
143149
#define I2F(a) __int2float_rn(a)
144150
#define FRCP(x) __frcp_rn(x)
151+
#if !defined(USE_ROCM)
145152
__device__ static float atomicMax(float* address, float val) {
146153
int* address_as_i = (int*)address;
147154
int old = *address_as_i, assumed;
@@ -166,6 +173,7 @@ __device__ static float atomicMin(float* address, float val) {
166173
} while (assumed != old);
167174
return __int_as_float(old);
168175
}
176+
#endif //! USE_ROCM
169177
#define DMAX(a, b) FMAX(a, b)
170178
#define DMIN(a, b) FMIN(a, b)
171179
#define DSQRT(a) sqrt(a)

pytorch3d/csrc/pulsar/include/camera.device.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "./commands.h"
1515

1616
namespace pulsar {
17-
IHD CamGradInfo::CamGradInfo() {
17+
IHD CamGradInfo::CamGradInfo(int x) {
1818
cam_pos = make_float3(0.f, 0.f, 0.f);
1919
pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
2020
pixel_dir_x = make_float3(0.f, 0.f, 0.f);

pytorch3d/csrc/pulsar/include/camera.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ inline bool operator==(const CamInfo& a, const CamInfo& b) {
6363
};
6464

6565
struct CamGradInfo {
66-
HOST DEVICE CamGradInfo();
66+
HOST DEVICE CamGradInfo(int = 0);
6767
float3 cam_pos;
6868
float3 pixel_0_0_center;
6969
float3 pixel_dir_x;

pytorch3d/csrc/pulsar/include/commands.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
// #pragma diag_suppress = 68
2525
#include <ATen/cuda/CUDAContext.h>
2626
// #pragma pop
27-
#include "../cuda/commands.h"
27+
#include "../gpu/commands.h"
2828
#else
2929
#pragma clang diagnostic push
3030
#pragma clang diagnostic ignored "-Weverything"

pytorch3d/csrc/pulsar/include/math.h

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ IHD float3 outer_product_sum(const float3& a) {
4646
}
4747

4848
// TODO: put intrinsics here.
49+
#if !defined(USE_ROCM)
4950
IHD float3 operator+(const float3& a, const float3& b) {
5051
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
5152
}
@@ -93,6 +94,7 @@ IHD float3 operator*(const float3& a, const float3& b) {
9394
IHD float3 operator*(const float& a, const float3& b) {
9495
return b * a;
9596
}
97+
#endif //! USE_ROCM
9698

9799
INLINE DEVICE float length(const float3& v) {
98100
// TODO: benchmark what's faster.

pytorch3d/csrc/pulsar/include/renderer.render.device.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -283,9 +283,15 @@ GLOBAL void render(
283283
(percent_allowed_difference > 0.f &&
284284
max_closest_possible_intersection > depth_threshold) ||
285285
tracker.get_n_hits() >= max_n_hits;
286+
#if defined(__CUDACC__) && defined(__HIP_PLATFORM_AMD__)
287+
unsigned long long warp_done = __ballot(done);
288+
int warp_done_bit_cnt = __popcll(warp_done);
289+
#else
286290
uint warp_done = thread_warp.ballot(done);
291+
int warp_done_bit_cnt = POPC(warp_done);
292+
#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
287293
if (thread_warp.thread_rank() == 0)
288-
ATOMICADD_B(&n_pixels_done, POPC(warp_done));
294+
ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
289295
// This sync is necessary to keep n_loaded until all threads are done with
290296
// painting.
291297
thread_block.sync();

pytorch3d/renderer/__init__.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,9 @@
7676
PointsRasterizationSettings,
7777
PointsRasterizer,
7878
PointsRenderer,
79+
PulsarPointsRenderer,
7980
rasterize_points,
8081
)
81-
82-
# Pulsar is not enabled on amd.
83-
if not torch.version.hip:
84-
from .points import PulsarPointsRenderer
85-
8682
from .splatter_blend import SplatterBlender
8783
from .utils import (
8884
convert_to_tensors_and_broadcast,

pytorch3d/renderer/points/__init__.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010

1111
from .compositor import AlphaCompositor, NormWeightedCompositor
1212

13-
# Pulsar not enabled on amd.
14-
if not torch.version.hip:
15-
from .pulsar.unified import PulsarPointsRenderer
13+
from .pulsar.unified import PulsarPointsRenderer
1614

1715
from .rasterize_points import rasterize_points
1816
from .rasterizer import PointsRasterizationSettings, PointsRasterizer

0 commit comments

Comments
 (0)