diff --git a/src/ATen/native/xpu/FbgemmOps.cpp b/src/ATen/native/xpu/FbgemmOps.cpp new file mode 100644 index 000000000..3d3d9f889 --- /dev/null +++ b/src/ATen/native/xpu/FbgemmOps.cpp @@ -0,0 +1,507 @@ +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at { +namespace native { +namespace xpu { + +#define XPU_DEVICE_GUARD(TENSOR) \ + const OptionalDeviceGuard device_guard(device_of(TENSOR)); + +Tensor asynchronous_complete_cumsum_xpu(const Tensor& t_in) { + TORCH_CHECK(t_in.is_contiguous()); + TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong); + TORCH_CHECK(t_in.dim() == 1 || t_in.dim() == 2); + + if (t_in.dim() == 1) { + Tensor t_out = at::zeros({t_in.numel() + 1}, t_in.options()); + auto r_out = t_out.slice(0, 1); + at::cumsum_out(r_out, t_in, 0); + return t_out; + } + + Tensor t_out = at::zeros({t_in.size(0), t_in.size(1) + 1}, t_in.options()); + auto r_out = t_out.slice(1, 1); + at::cumsum_out(r_out, t_in, 1); + return t_out; +} + +Tensor dense_to_jagged_forward_xpu( + const Tensor& dense, + const std::vector& offsets, + std::optional total_L) { + TORCH_CHECK(dense.is_xpu(), "value must be a xpu tensor"); + for (auto& offset : offsets) { + TORCH_CHECK(offset.is_xpu(), "offset must be a xpu tensor"); + } + + const int num_jagged_dim = dense.dim() - 2; + TORCH_CHECK( + offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + offsets.size(), + " != num_jagged_dim, ", + num_jagged_dim); + + // D is the embedding dimension + auto D = dense.size(-1); + + // If total_L is not given then compute it + at::SymInt total_L_computed; + if (total_L.has_value()) { + total_L_computed = total_L.value(); + } else { + total_L_computed = (int64_t)offsets.back().max().item(); + } + auto values = at::empty_symint({total_L_computed, D}, dense.options()); + auto output = at::empty_like(values); // not used + + if (dense.numel() == 0 || values.numel() == 0) { + return output; + } + + XPU_DEVICE_GUARD(dense); + + dense_to_jagged_forward_xpu_kernel(values, offsets, dense, output); + + return output; +} + +Tensor jagged_to_padded_dense_forward_xpu( + const Tensor& values, + const std::vector& offsets, + c10::SymIntArrayRef max_lengths, + const double padding_value) { + size_t num_jagged_dim = offsets.size(); + TORCH_CHECK( + max_lengths.size() == num_jagged_dim, + "max_lengths.size(), ", + max_lengths.size(), + " != num_jagged_dim, ", + num_jagged_dim); + + TORCH_CHECK(values.is_xpu(), "value must be a xpu tensor"); + for (auto& offset : offsets) { + TORCH_CHECK(offset.is_xpu(), "offset must be a xpu tensor"); + } + + XPU_DEVICE_GUARD(values); + + const Tensor values_canonicalized = values.view( + {values.size(0), + std::accumulate( + values.sizes().begin() + 1, + values.sizes().end(), + 1, + std::multiplies())}); + at::SymDimVector padded_values_shape({at::SymInt(offsets[0].size(0) - 1)}); + padded_values_shape.insert( + padded_values_shape.end(), max_lengths.begin(), max_lengths.end()); + + // Canonicalize padded_values by unsqueeze the last dim if the inner dense + // dimension is 1 and folded. + const bool D_folded = values.dim() == 1; + if (!D_folded) { + padded_values_shape.push_back(values.size(-1)); + } + Tensor padded_values = + at::empty_symint(padded_values_shape, values.options()); + Tensor padded_values_view = + D_folded ? 
padded_values.unsqueeze(-1) : padded_values; + + num_jagged_dim = padded_values_view.dim() - 2; + TORCH_CHECK( + offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + offsets.size(), + " != num_jagged_dim ", + num_jagged_dim); + + if (padded_values_view.numel() == 0) { + return padded_values; + } + + jagged_to_padded_dense_forward_xpu_kernel( + values_canonicalized, + offsets, + padded_values_view, + padded_values_view, + padding_value); + + return padded_values; +} + +class DenseToJaggedOp : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const Tensor& dense, + const std::vector& offsets, + const std::optional& total_L) { + // uncomment when implement backward + + // dims of dense tensor: + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::dense_to_jagged_forward", "") + .typed& offsets, + std::optional total_L)>(); + auto output = op.call(dense, offsets, total_L); + + return {output}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + // TODO: backward kernel + return { + torch::autograd::Variable(), + torch::autograd::Variable(), // offsets + torch::autograd::Variable() // total_L + }; + } +}; + +// output = x + y where x is jagged, y is dense, and output is jagged +std::tuple> dense_to_jagged( + const Tensor& dense, + const std::vector& offsets, + std::optional total_L) { + return {DenseToJaggedOp::apply(dense, offsets, total_L)[0], offsets}; +} + +class JaggedToPaddedDenseOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const Tensor& values, + const std::vector& offsets, + const c10::SymIntArrayRef max_lengths, + const double padding_value) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "") + .typed& offsets, + at::ArrayRef max_lengths, + const double padding_value)>(); + Tensor padded_values = op.call(values, offsets, max_lengths, padding_value); + + return {padded_values}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + // TODO: backward kernel + return { + torch::autograd::Variable(), + torch::autograd::Variable(), // offsets + torch::autograd::Variable(), // max_lengths + torch::autograd::Variable(), // padding_value + }; + } +}; + +Tensor jagged_to_padded_dense( + const Tensor& values, + const std::vector& offsets, + const c10::SymIntArrayRef max_lengths, + const double padding_value = 0.0) { + return JaggedToPaddedDenseOp::apply( + values, offsets, max_lengths, padding_value)[0]; +} + +class JaggedDenseAddJaggedOutputOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const Tensor& x_values, + const std::vector& offsets, + const Tensor& dense) { + TORCH_CHECK(x_values.is_xpu(), "value must be a xpu tensor"); + for (auto& offset : offsets) { + TORCH_CHECK(offset.is_xpu(), "offset must be a xpu tensor"); + } + TORCH_CHECK(dense.is_xpu(), "dense must be a xpu tensor"); + + const int num_jagged_dim = dense.dim() - 2; + TORCH_CHECK( + offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + offsets.size(), + " != num_jagged_dim, ", + num_jagged_dim); + + auto output = 
at::empty_like(x_values); + if (dense.numel() == 0 || x_values.numel() == 0) { + return {output}; + } + + XPU_DEVICE_GUARD(dense); + jagged_dense_elementwise_add_jagged_output_fwd_xpu_kn( + x_values, offsets, dense, output); + + return {output}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + // TODO: backward kernel + return { + torch::autograd::Variable(), + torch::autograd::Variable(), // offsets + torch::autograd::Variable()}; + } +}; + +std::tuple> +jagged_dense_elementwise_add_jagged_output_xpu( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y) { + auto sum_values = + JaggedDenseAddJaggedOutputOp::apply(x_values, x_offsets, y)[0]; + + return {sum_values, x_offsets}; +} + +Tensor reorder_batched_ad_lengths_xpu( + const Tensor& cat_ad_lengths, + const Tensor& batch_offsets, + const int64_t num_ads_in_batch, + const bool broadcast_lengths, + const int64_t max_batch_size = 0) { + TORCH_CHECK_LE(max_batch_size, 0); + TENSORS_ON_SAME_XPU_IF_NOT_OPTIONAL(cat_ad_lengths, batch_offsets); + + XPU_DEVICE_GUARD(cat_ad_lengths); + + const int64_t B = batch_offsets.numel() - 1; + const int64_t T = broadcast_lengths + ? cat_ad_lengths.numel() / B + : cat_ad_lengths.numel() / num_ads_in_batch; + + Tensor reordered_cat_ad_lengths = broadcast_lengths + ? at::empty({T * num_ads_in_batch}, cat_ad_lengths.options()) + : at::empty_like(cat_ad_lengths); + + const int64_t grid_size = (B * T + 32 - 1) / 32; + TORCH_CHECK( + grid_size >= 0, + "grid_size must be positive, got ", + grid_size, + " where B =", + B, + " and T =", + T); + + reorder_batched_ad_lengths_xpu_kernel( + cat_ad_lengths, + batch_offsets, + reordered_cat_ad_lengths, + T, + broadcast_lengths, + grid_size); + + return reordered_cat_ad_lengths; +} + +Tensor reorder_batched_ad_indices_xpu( + const at::Tensor& cat_ad_offsets, + const at::Tensor& cat_ad_indices, + const at::Tensor& reordered_cat_ad_offsets, + const at::Tensor& batch_offsets, + const int64_t num_ads_in_batch, + const bool broadcast_indices = false, + const int64_t num_indices_after_broadcast = -1) { + TENSORS_ON_SAME_XPU_IF_NOT_OPTIONAL( + cat_ad_offsets, cat_ad_indices, reordered_cat_ad_offsets, batch_offsets); + + XPU_DEVICE_GUARD(cat_ad_offsets); + + const int64_t B = batch_offsets.numel() - 1; + const int64_t T = (reordered_cat_ad_offsets.numel() - 1) / num_ads_in_batch; + Tensor reordered_cat_ad_indices; + if (broadcast_indices) { + TORCH_CHECK_GE(num_indices_after_broadcast, 0); + reordered_cat_ad_indices = + at::empty({num_indices_after_broadcast}, cat_ad_indices.options()); + } else { + reordered_cat_ad_indices = at::empty_like(cat_ad_indices); + } + + reorder_batched_ad_indices_xpu_kernel( + cat_ad_offsets, + cat_ad_indices, + reordered_cat_ad_offsets, + batch_offsets, + reordered_cat_ad_indices, + num_ads_in_batch, + B, + T, + broadcast_indices); + + return reordered_cat_ad_indices; +} + +Tensor asynchronous_exclusive_cumsum_(const Tensor& t_in) { + torch_tensor_on_xpu_check(t_in); + XPU_DEVICE_GUARD(t_in); + + if (t_in.numel() == 0) { + return at::empty_like(t_in); + } + + TORCH_CHECK(t_in.is_contiguous()); + TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong); + // only handles up to INT_MAX elements. 
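+  // Illustrative semantics: for t_in = [2, 3, 1] the result is [0, 2, 5],
+  // i.e. the inclusive cumsum ([2, 5, 6]) rolled right by one with the first
+  // element zeroed below.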
+ TORCH_CHECK(t_in.numel() < std::numeric_limits::max()); + auto t_in_flatten = t_in.flatten(); + auto t_out = at::empty_like(t_in_flatten); + + cumsum_kernel(t_out, t_in_flatten, 0); + + // make it exclusive + t_out = t_out.roll(1, 0); + // set all first elemnts 0 + t_out[0] = 0; + return t_out.view_as(t_in); +} + +std::tuple> +permute_2D_sparse_data_xpu( + const at::Tensor& permute, + const at::Tensor& lengths, + const at::Tensor& indices, + const std::optional& weights, + const std::optional& permuted_lengths_sum) { + TENSORS_ON_SAME_XPU_IF_NOT_OPTIONAL(permute, lengths, indices, weights); + TORCH_CHECK(lengths.dim() == 2); + + XPU_DEVICE_GUARD(indices); + + const auto permute_contig = permute.contiguous(); + const auto lengths_contig = lengths.contiguous(); + const auto indices_contig = indices.contiguous(); + // the data to permute over can be less or more with or without + // repetitions + const auto T = permute.numel(); + const auto B = lengths.size(1); + + if (T == 0 || B == 0) { + // When T = 0 or B = 0, permutation will not be performed. Return the + // input tensors. + return { + lengths.clone(), + indices.clone(), + weights.has_value() ? std::make_optional(weights->clone()) + : std::nullopt}; + } + + Tensor permuted_lengths = at::empty({T, B}, lengths.options()); + Tensor permuted_indices; + Tensor permuted_weights; + + permute_2D_lengths_kernel_xpu( + T, B, lengths_contig, permute_contig, permuted_lengths); + + // convert lengths to offsets + const auto input_offsets = asynchronous_exclusive_cumsum_(lengths_contig); + const auto output_offsets = + asynchronous_complete_cumsum_xpu(permuted_lengths.flatten()); + int64_t permuted_indices_size = 0; + if (permuted_lengths_sum.has_value()) { + permuted_indices_size = permuted_lengths_sum.value(); + } else { + permuted_indices_size = output_offsets[-1].item(); + } + + permuted_indices = at::empty(permuted_indices_size, indices.options()); + + if (weights.has_value()) { + const Tensor weights_value = weights.value(); + int32_t weights_columns = 1; + if (weights_value.dense_dim() > 1) { + weights_columns = weights_value.size(1); + permuted_weights = at::empty( + {permuted_indices_size, weights_columns}, weights_value.options()); + } else { + permuted_weights = + at::empty(permuted_indices_size, weights_value.options()); + } + permute_2D_data_kernel_xpu( + permuted_indices_size, + T, + B, + indices_contig, + std::optional{weights_value}, + weights_columns, + permute_contig, + input_offsets, + output_offsets, + permuted_indices, + std::optional{permuted_weights}); + } else { + permute_2D_data_kernel_xpu( + permuted_indices_size, + T, + B, + indices_contig, + std::nullopt, + 0, + permute_contig, + input_offsets, + output_offsets, + permuted_indices, + std::nullopt); + } + + return {permuted_lengths, permuted_indices, permuted_weights}; +} + +} // namespace xpu +} // namespace native +} // namespace at + +namespace { + +TORCH_LIBRARY_IMPL(fbgemm, XPU, m) { + m.impl( + "asynchronous_complete_cumsum", + &at::native::xpu::asynchronous_complete_cumsum_xpu); + m.impl("dense_to_jagged", &at::native::xpu::dense_to_jagged); + m.impl( + "dense_to_jagged_forward", &at::native::xpu::dense_to_jagged_forward_xpu); + m.impl("jagged_to_padded_dense", &at::native::xpu::jagged_to_padded_dense); + m.impl( + "jagged_to_padded_dense_forward", + &at::native::xpu::jagged_to_padded_dense_forward_xpu); + m.impl( + "jagged_dense_elementwise_add_jagged_output", + &at::native::xpu::jagged_dense_elementwise_add_jagged_output_xpu); + m.impl( + 
"reorder_batched_ad_lengths", + &at::native::xpu::reorder_batched_ad_lengths_xpu); + m.impl( + "reorder_batched_ad_indices", + &at::native::xpu::reorder_batched_ad_indices_xpu); + m.impl( + "permute_2D_sparse_data", &at::native::xpu::permute_2D_sparse_data_xpu); +} + +} // namespace diff --git a/src/ATen/native/xpu/sycl/FbgemmKernels.cpp b/src/ATen/native/xpu/sycl/FbgemmKernels.cpp new file mode 100644 index 000000000..70ebba4b2 --- /dev/null +++ b/src/ATen/native/xpu/sycl/FbgemmKernels.cpp @@ -0,0 +1,1541 @@ +#include +#include +#include + +#include + +namespace syclext = sycl::ext::oneapi; +namespace syclexp = sycl::ext::oneapi::experimental; + +namespace at { +namespace native { +namespace xpu { + +template +struct StackArray { + T vals[kStackArrayMaxDims]; + size_t ndim; +}; + +template +class SimpleAddFunctor3 { + public: + T operator()(T x, T y, T /*unused*/) { + return x + y; + } +}; + +template +class SimpleRetSecondFunctor3 { + public: + T operator()(T /*unused*/, T y, T /*unused*/) { + return y; + } +}; + +template +class SimpleRetFirstFunctor2 { + public: + T operator()(T x, T /*unused*/) { + return x; + } +}; + +#define JAGGED_TENSOR_DISPATCH_DIMS() \ + AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [=] { \ + switch (num_jagged_dim) { \ + case 1: \ + INVOKE_KERNEL_WITH_DIM(1); \ + break; \ + case 2: \ + INVOKE_KERNEL_WITH_DIM(2); \ + break; \ + case 3: \ + INVOKE_KERNEL_WITH_DIM(3); \ + break; \ + case 4: \ + INVOKE_KERNEL_WITH_DIM(4); \ + break; \ + case 5: \ + INVOKE_KERNEL_WITH_DIM(5); \ + break; \ + default: \ + TORCH_CHECK( \ + false, "unsupported number of jagged dim ", num_jagged_dim); \ + } \ + }); + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>)) +void jagged_dense_elementwise_jagged_output_kernel_( + GenericPackedTensorAccessor + x_values, + StackArray x_offsets, + StackArray x_offsets_sizes, + GenericPackedTensorAccessor y_0, + GenericPackedTensorAccessor y_1, + GenericPackedTensorAccessor + output_values, + StackArray jagged_dims, + F f) { + auto output_values_acc = output_values; + const int outer_dense_size = y_0.size(0); + const int inner_dense_size = y_0.size(2); + const int nnz = x_values.size(0); + + auto item = syclext::this_work_item::get_nd_item<2>(); + const int offset_begin = + item.get_group(0) * item.get_local_range(1) + item.get_local_id(1); + const int offset_stride = item.get_global_range(0) * item.get_local_range(1); + for (int offset = offset_begin; offset < nnz; offset += offset_stride) { + int offset_temp = offset; + int jidx = 0; + bool truncated = false; + int dim_prod = 1; +#pragma unroll + for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) { + // Binary search the first that is bigger than offset + int count = x_offsets_sizes.vals[d] - 1; + int first = 1; + while (count > 0) { + int idx = first; + int step = count / 2; + idx += step; + if (x_offsets.vals[d][idx] <= offset_temp) { + first = ++idx; + count -= step + 1; + } else { + count = step; + } + } + + --first; + int coord = offset_temp - x_offsets.vals[d][first]; + if (coord >= jagged_dims.vals[d]) { + truncated = true; + break; + } + jidx += coord * dim_prod; + dim_prod *= jagged_dims.vals[d]; + offset_temp = first; + } + + if (offset_temp >= outer_dense_size) { + // This can happen when values have more elements than the last element of + // offset + truncated = true; + } + if (!truncated) { + const int oidx = offset_temp; + int iidx; + for (iidx = item.get_local_id(0); iidx * 2 + 1 < inner_dense_size; + iidx += item.get_local_range(0)) { + 
output_values_acc[offset][2 * iidx] = + f(x_values[offset][2 * iidx], + y_0[oidx][jidx][2 * iidx], + y_1[oidx][jidx][2 * iidx]); + output_values_acc[offset][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], + y_0[oidx][jidx][2 * iidx + 1], + y_1[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_values_acc[offset][2 * iidx] = + f(x_values[offset][2 * iidx], + y_0[oidx][jidx][2 * iidx], + y_1[oidx][jidx][2 * iidx]); + } + } else { + int iidx; + for (iidx = item.get_local_id(0); iidx * 2 + 1 < inner_dense_size; + iidx += item.get_local_range(0)) { + output_values_acc[offset][2 * iidx] = + f(x_values[offset][2 * iidx], 0, 0); + output_values_acc[offset][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], 0, 0); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_values_acc[offset][2 * iidx] = + f(x_values[offset][2 * iidx], 0, 0); + } + } + } +} + +template +void jagged_dense_elementwise_jagged_output_launch_( + GenericPackedTensorAccessor + x_values, // output + StackArray x_offsets, + StackArray x_offsets_sizes, + GenericPackedTensorAccessor + y_0, // not used + GenericPackedTensorAccessor y_1, + GenericPackedTensorAccessor + output_values, // not used + StackArray jagged_dims, + F f, + int64_t wg_0, + int64_t wg_1, + int64_t wg_num) { + sycl_kernel_submit>( + sycl::range<2>(wg_0 * wg_num, wg_1), + sycl::range<2>(wg_0, wg_1), + getCurrentSYCLQueue(), + 0, + x_values, + x_offsets, + x_offsets_sizes, + y_0, + y_1, + output_values, + jagged_dims, + f); +} + +template +bool walk_down_tensor_storage_tree_( + int& offset, + const int flattened_jagged_idx, + const StackArray& jagged_dims, + const StackArray& x_offsets) { + // compute coorindates + int jagged_coords[NUM_JAGGED_DIM]; + int j_temp = flattened_jagged_idx; +#pragma unroll + for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) { + const int jagged_size = jagged_dims.vals[d]; + jagged_coords[d] = j_temp % jagged_size; + j_temp /= jagged_size; + } + + // walk down the tree + bool is_zero = false; +#pragma unroll + for (int d = 0; d < NUM_JAGGED_DIM; ++d) { + const int begin = x_offsets.vals[d][offset]; + const int end = x_offsets.vals[d][offset + 1]; + if (jagged_coords[d] >= end - begin) { + is_zero = true; + break; + } + offset = begin + jagged_coords[d]; + } + return is_zero; +} + +inline std::tuple> +check_shape_and_partition_( + const Tensor& values, + const std::vector& offsets, + const Tensor& dense_tensor) { + const int32_t outer_dense_size = dense_tensor.size(0); + TORCH_CHECK( + outer_dense_size == offsets[0].numel() - 1, + "outer_dense_size, ", + outer_dense_size, + " != offsets[0].numel() - 1, ", + offsets[0].numel() - 1); + const int32_t inner_dense_size = dense_tensor.size(-1); + TORCH_CHECK( + inner_dense_size == values.size(-1), + "inner_dense_size, ", + inner_dense_size, + " != values.size(-1), ", + values.size(-1)); + const int32_t jagged_folded_size = + dense_tensor.numel() / (outer_dense_size * inner_dense_size); + + const int32_t sub_group_size = syclMaxSubGroupSize(); + const int64_t wg_size_0 = inner_dense_size >= sub_group_size / 2 + ? 
sub_group_size + : inner_dense_size; + const int64_t wg_size_1 = syclDeviceMaxWorkGroupSize() / sub_group_size; + const int64_t wg_num = + CeilDivUp(outer_dense_size * jagged_folded_size, (int32_t)wg_size_1); + + StackArray jagged_dims_tensor; + const int32_t num_jagged_dim = dense_tensor.dim() - 2; + TORCH_CHECK(num_jagged_dim <= kStackArrayMaxDims); + jagged_dims_tensor.ndim = num_jagged_dim; + std::memcpy( + &(jagged_dims_tensor.vals[0]), + dense_tensor.sizes().data() + 1, + num_jagged_dim * sizeof(int64_t)); + return {wg_size_0, wg_size_1, wg_num, jagged_dims_tensor}; +} + +template +void jagged_dense_elementwise_jagged_output_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output_values, + F f) { + // Canonicalize y to 3D, collapsing jagged dimensions. + const int num_jagged_dim = y.dim() - 2; + const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + int64_t wg_0, wg_1, wg_num; \ + StackArray jagged_dims_tensor; \ + std::tie(wg_0, wg_1, wg_num, jagged_dims_tensor) = \ + check_shape_and_partition_(x_values, x_offsets, y); \ + wg_num = CeilDivUp(x_values.size(0), wg_1); \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + StackArray x_offset_sizes; \ + x_offset_sizes.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + x_offset_sizes.vals[d] = x_offsets[d].numel(); \ + } \ + jagged_dense_elementwise_jagged_output_launch_( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + x_offset_sizes, \ + y_reshaped.packed_accessor32(), \ + y_reshaped.packed_accessor32(), \ + output_values.packed_accessor32(), \ + jagged_dims_tensor, \ + f, \ + wg_0, \ + wg_1, \ + wg_num); \ + } + + JAGGED_TENSOR_DISPATCH_DIMS(); +#undef INVOKE_KERNEL_WITH_DIM +} + +void dense_to_jagged_forward_xpu_kernel( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output_values) { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::BFloat16, + at::ScalarType::Half, + x_values.scalar_type(), + "dense_to_jagged_forward_xpu_kernel", + [&]() { + jagged_dense_elementwise_jagged_output_( + x_values, + x_offsets, + y, + output_values, + SimpleRetSecondFunctor3()); + }); +} + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>)) +void jagged_dense_elementwise_dense_output_kernel_( + GenericPackedTensorAccessor + x_values, + StackArray x_offsets, + GenericPackedTensorAccessor y, + GenericPackedTensorAccessor output, + StackArray jagged_dims, + const scalar_t padding_value, + F f) { + auto output_acc = output; + const int outer_dense_size = y.size(0); + const int jagged_folded_size = y.size(1); + const int inner_dense_size = y.size(2); + + auto item = syclext::this_work_item::get_nd_item<2>(); + const int outer_begin = + item.get_group(0) * item.get_local_range(1) + item.get_local_id(1); + const int outer_stride = item.get_global_range(0) * item.get_local_range(1); + for (int outer = outer_begin; outer < outer_dense_size * jagged_folded_size; + outer += outer_stride) { + const int oidx = outer / jagged_folded_size; + const int jidx = outer % jagged_folded_size; + + int offset = oidx; + const bool is_zero = walk_down_tensor_storage_tree_( + offset, jidx, jagged_dims, x_offsets); + + if (is_zero) { + int iidx; + for 
(iidx = item.get_local_id(0); iidx * 2 + 1 < inner_dense_size; + iidx += item.get_local_range(0)) { + output_acc[oidx][jidx][2 * iidx] = + f(padding_value, y[oidx][jidx][2 * iidx]); + output_acc[oidx][jidx][2 * iidx + 1] = + f(padding_value, y[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_acc[oidx][jidx][2 * iidx] = + f(padding_value, y[oidx][jidx][2 * iidx]); + } + } else { + int iidx; + for (iidx = item.get_local_id(0); iidx * 2 + 1 < inner_dense_size; + iidx += item.get_local_range(0)) { + output_acc[oidx][jidx][2 * iidx] = + f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]); + output_acc[oidx][jidx][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], y[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_acc[oidx][jidx][2 * iidx] = + f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]); + } + } + } +} + +template +void jagged_dense_elementwise_dense_output_launch_( + const GenericPackedTensorAccessor + x_values, + StackArray x_offsets, + const GenericPackedTensorAccessor y, + GenericPackedTensorAccessor output, + StackArray jagged_dims, + const scalar_t padding_value, + int64_t wg_0, + int64_t wg_1, + int64_t wg_num) { + sycl_kernel_submit>>( + sycl::range<2>(wg_0 * wg_num, wg_1), + sycl::range<2>(wg_0, wg_1), + getCurrentSYCLQueue(), + 0, + x_values, + x_offsets, + y, + output, + jagged_dims, + padding_value, + SimpleRetFirstFunctor2()); +} + +template +void jagged_dense_elementwise_dense_output_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output, + const scalar_t padding_value) { + int64_t wg_0, wg_1, wg_num; + StackArray jagged_dims_tensor; + std::tie(wg_0, wg_1, wg_num, jagged_dims_tensor) = + check_shape_and_partition_(x_values, x_offsets, y); + + // Canonicalize y and output to 3D, collapsing jagged dimensions. 
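+  // (e.g. a dense y of shape [B, N1, N2, D] is viewed as [B, N1 * N2, D],
+  //  and output is viewed with the same collapsed shape)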
+ const int num_jagged_dim = y.dim() - 2; + const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + Tensor output_reshaped = output.view(y_reshaped.sizes()); + +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + } \ + jagged_dense_elementwise_dense_output_launch_( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + y_reshaped.packed_accessor32(), \ + output_reshaped.packed_accessor32(), \ + jagged_dims_tensor, \ + padding_value, \ + wg_0, \ + wg_1, \ + wg_num); \ + } + + JAGGED_TENSOR_DISPATCH_DIMS(); +#undef INVOKE_KERNEL_WITH_DIM +} + +void jagged_to_padded_dense_forward_xpu_kernel( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output, + const double padding_value) { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::BFloat16, + at::ScalarType::Half, + x_values.scalar_type(), + "jagged_to_padded_dense_forward_xpu_kernel", + [&] { + jagged_dense_elementwise_dense_output_( + x_values, + x_offsets, + y, // not used + output, + static_cast(padding_value)); + }); +} + +// Check to see if the inputs to the op are amenable to the fast path +inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt( + const int& num_jagged_dim, + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y_0_reshaped, + const Tensor& y_1_reshaped, + const Tensor& output_values) { + bool matches = true; + matches &= (num_jagged_dim == 1); + + // Unit stride embedding dim + matches &= (x_values.stride(-1) == 1); + matches &= (output_values.stride(-1) == 1); + matches &= (y_0_reshaped.stride(-1) == 1); + matches &= (y_1_reshaped.stride(-1) == 1); + + // Each row is aligned to 128-bit + matches &= ((x_values.stride(-2) & 0x7) == 0); + matches &= ((output_values.stride(-2) & 0x7) == 0); + matches &= ((y_0_reshaped.stride(-2) & 0x7) == 0); + matches &= ((y_1_reshaped.stride(-2) & 0x7) == 0); + + // Base addresses aligned to 128-bit + matches &= ((reinterpret_cast(x_values.data_ptr()) & 0xF) == 0); + matches &= + ((reinterpret_cast(output_values.data_ptr()) % 0xF) == 0); + matches &= ((reinterpret_cast(y_0_reshaped.data_ptr()) % 0xF) == 0); + matches &= ((reinterpret_cast(y_1_reshaped.data_ptr()) % 0xF) == 0); + + // Rows and col fit into int32_t + matches &= (y_0_reshaped.size(0) < INT_MAX); + matches &= (y_0_reshaped.size(1) < INT_MAX); + + // maximum shared local memory size + int max_shared_bytes = syclLocalMemSize(); + // Use all shared memory, no L1 cache consideration + int max_shared_kb = max_shared_bytes >> 10; + int used_shared_kb = round_down(max_shared_kb, 16); + TORCH_CHECK(used_shared_kb > 0); + int used_shared_bytes = used_shared_kb << 10; + AT_DISPATCH_INDEX_TYPES( + x_offsets[0].scalar_type(), "check_shared_memory", [&] { + auto B = y_0_reshaped.size(0); + // the default shared memory on V100/A100/H100 is 48 KB from + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-8-x + if ((B + 1) * sizeof(index_t) >= used_shared_bytes) { + matches = false; + } + }); + return matches; +} + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<3>)) +void jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_( + const GenericPackedTensorAccessor< + index_t, 
+ 1, + at::RestrictPtrTraits, + int32_t> offsets, + GenericPackedTensorAccessor rows, + GenericPackedTensorAccessor cols, + int nnz, + int B) { + index_t* offsets_sh = + reinterpret_cast(syclexp::get_work_group_scratch_memory()); + auto item = syclext::this_work_item::get_nd_item<3>(); + + for (auto i = item.get_local_id(0); i < B + 1; i += item.get_local_range(0)) { + offsets_sh[i] = offsets[i]; + } + group_barrier(item.get_group()); + auto row = item.get_local_id(0) + item.get_group(0) * item.get_local_range(0); + if (row >= nnz) + return; + int first = -1; + int count = B - 1; + first = 1; + while (count > 0) { + int idx = first; + int step = count / 2; + idx += step; + if (offsets_sh[idx] <= row) { + first = ++idx; + count -= step + 1; + } else { + count = step; + } + } + --first; + + int dense_row = first; + int offset = offsets_sh[dense_row]; + int dense_col = row - offset; + rows[row] = dense_row; + cols[row] = dense_col; +} + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<3>)) +void jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_( + GenericPackedTensorAccessor + values, + const GenericPackedTensorAccessor< + c10::Half, + 2, + at::RestrictPtrTraits, + int32_t> x_values, + const GenericPackedTensorAccessor< + c10::Half, + 3, + at::RestrictPtrTraits, + int32_t> y0, + const GenericPackedTensorAccessor< + c10::Half, + 3, + at::RestrictPtrTraits, + int32_t> y1, + const GenericPackedTensorAccessor + rows, + const GenericPackedTensorAccessor + cols, + const int nnz, + const int E, + F f) { + auto item = syclext::this_work_item::get_nd_item<3>(); + auto values_row = + item.get_local_id(1) + item.get_group(1) * item.get_local_range(1); + if (values_row >= nnz) + return; + for (int real_row = values_row; real_row < nnz; + real_row += item.get_local_range(1) * item.get_group_range(1)) { + int dense_row = rows[real_row]; + int dense_col = cols[real_row]; + sycl::half* values_ptr = + reinterpret_cast(&values[real_row][0]); + const sycl::half* x_ptr = + reinterpret_cast(&x_values[real_row][0]); + const sycl::half* y0_ptr = + reinterpret_cast(&y0[dense_row][dense_col][0]); + const sycl::half* y1_ptr = + reinterpret_cast(&y1[dense_row][dense_col][0]); + if ((dense_col < y0.size(1)) && (dense_row < y0.size(0)) && + (dense_col < y1.size(1)) && (dense_row < y1.size(0)) && + (dense_col >= 0) && (dense_row >= 0)) { + for (auto tid = item.get_local_id(0); tid < E / 8; + tid += item.get_local_range(0)) { + VecType128 v_x, v_out, v_y0, v_y1; + v_x.data.mask = + (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f128(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (auto tid = item.get_local_id(0) + (E / 8) * 8; tid < E / 4; + tid += item.get_local_range(0)) { + VecType64 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f64(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (auto tid = item.get_local_id(0) + (E / 4) * 4; tid < E / 2; + tid += item.get_local_range(0)) { + VecType32 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f32(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } 
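+      // Scalar tail: any remaining halves not covered by the 128/64/32-bit
+      // vectorized loops above are processed one sycl::half at a time.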
+ for (auto tid = item.get_local_id(0) + (E / 2) * 2; tid < E; + tid += item.get_local_range(0)) { + sycl::half v_x, v_out, v_y0, v_y1; + v_x = static_cast(x_ptr[tid]); + v_y0 = static_cast(y0_ptr[tid]); + v_y1 = static_cast(y1_ptr[tid]); + fh(v_out, v_x, v_y0, v_y1, f); + values_ptr[tid] = v_out; + } + } else { + for (auto tid = item.get_local_id(0); tid < E / 8; + tid += item.get_local_range(0)) { + VecType128 v_x, v_out, v_y0, v_y1; + v_x.data.mask = + (reinterpret_cast(x_ptr))[tid]; + f128(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (auto tid = item.get_local_id(0) + (E / 8) * 8; tid < E / 4; + tid += item.get_local_range(0)) { + VecType64 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + f64(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (auto tid = item.get_local_id(0) + (E / 4) * 4; tid < E / 2; + tid += item.get_local_range(0)) { + VecType32 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + f32(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (auto tid = item.get_local_id(0) + (E / 2) * 2; tid < E; + tid += item.get_local_range(0)) { + sycl::half v_x, v_out, v_y0, v_y1; + v_x = static_cast(x_ptr[tid]); + fh(v_out, v_x, v_y0, v_y1, f); + values_ptr[tid] = v_out; + } + } + } +} + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>)) +void jagged_dense_dense_elementwise_jagged_output_kernel_( + const GenericPackedTensorAccessor< + scalar_t, + 2, + at::RestrictPtrTraits, + int32_t> x_values, + StackArray x_offsets, + StackArray x_offsets_sizes, + const GenericPackedTensorAccessor< + scalar_t, + 3, + at::RestrictPtrTraits, + int32_t> y_0, + const GenericPackedTensorAccessor< + scalar_t, + 3, + at::RestrictPtrTraits, + int32_t> y_1, + GenericPackedTensorAccessor + output_values, + StackArray jagged_dims, + F f) { + const int outer_dense_size = y_0.size(0); + const int inner_dense_size = y_0.size(2); + const int nnz = x_values.size(0); + + auto item = syclext::this_work_item::get_nd_item<2>(); + + const auto offset_begin = + item.get_group(0) * item.get_local_range(1) + item.get_local_id(1); + const auto offset_stride = item.get_group_range(0) * item.get_local_range(1); + for (int offset = offset_begin; offset < nnz; offset += offset_stride) { + int offset_temp = offset; + int jidx = 0; + bool truncated = false; + int dim_prod = 1; +#pragma unroll + for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) { + // Binary search the first that is bigger than offset + int count = x_offsets_sizes.vals[d] - 1; + int first = 1; + while (count > 0) { + int idx = first; + int step = count / 2; + idx += step; + if (x_offsets.vals[d][idx] <= offset_temp) { + first = ++idx; + count -= step + 1; + } else { + count = step; + } + } + + --first; + int coord = offset_temp - x_offsets.vals[d][first]; + if (coord >= jagged_dims.vals[d]) { + truncated = true; + break; + } + jidx += coord * dim_prod; + dim_prod *= jagged_dims.vals[d]; + offset_temp = first; + } + + if (offset_temp >= outer_dense_size) { + // This can happen when values have more elements than the last element of + // offset + truncated = true; + } + if (!truncated) { + const int oidx = offset_temp; + int iidx; + for (iidx = item.get_local_id(0); iidx * 2 + 1 < inner_dense_size; + iidx += item.get_local_range(0)) { + output_values[offset][2 * iidx] = + f(x_values[offset][2 * iidx], + y_0[oidx][jidx][2 * iidx], + y_1[oidx][jidx][2 * 
iidx]); + output_values[offset][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], + y_0[oidx][jidx][2 * iidx + 1], + y_1[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_values[offset][2 * iidx] = + f(x_values[offset][2 * iidx], + y_0[oidx][jidx][2 * iidx], + y_1[oidx][jidx][2 * iidx]); + } + } else { + int iidx; + for (iidx = item.get_local_id(0); iidx * 2 + 1 < inner_dense_size; + iidx += item.get_local_range(0)) { + output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0); + output_values[offset][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], 0, 0); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0); + } + } + } +} + +// defined for jagged_dense_elementwise_jagged_output_opt_ +// and jagged_dense_elementwise_jagged_output_ +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + int64_t wg_0, wg_1, wg_num; \ + StackArray jagged_dims_tensor; \ + std::tie(wg_0, wg_1, wg_num, jagged_dims_tensor) = \ + check_shape_and_partition_(x_values, x_offsets, y); \ + wg_num = CeilDivUp(x_values.size(0), wg_1); \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + StackArray x_offset_sizes; \ + x_offset_sizes.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + x_offset_sizes.vals[d] = x_offsets[d].numel(); \ + } \ + sycl_kernel_submit>( \ + sycl::range<2>(wg_0 * wg_num, wg_1), \ + sycl::range<2>(wg_0, wg_1), \ + getCurrentSYCLQueue(), \ + 0, \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + x_offset_sizes, \ + y_reshaped.packed_accessor32(), \ + y_reshaped.packed_accessor32(), \ + output_values.packed_accessor32(), \ + jagged_dims_tensor, \ + f); \ + } // namespace xpu + +template +void jagged_dense_elementwise_jagged_output_opt_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output_values, + F f) { + // Canonicalize y to 3D, collapsing jagged dimensions. + const int num_jagged_dim = y.dim() - 2; + const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + if (jagged_dense_dense_elementwise_jagged_output_matches_opt( + num_jagged_dim, + x_values, + x_offsets, + y_reshaped, + y_reshaped, + output_values)) { + AT_DISPATCH_INDEX_TYPES( + x_offsets[0].scalar_type(), "jagged_indices_fast_path", [=] { + auto nnz = output_values.size(0); + auto B = y_reshaped.size(0); + auto E = y_reshaped.size(2); + Tensor t_rows_after_bs = at::empty( + {nnz}, + at::TensorOptions().dtype(at::kInt).device( + at::kXPU, c10::xpu::current_device())); + Tensor t_cols_after_bs = at::empty( + {nnz}, + at::TensorOptions().dtype(at::kInt).device( + at::kXPU, c10::xpu::current_device())); + + // Binary search + size_t dynamic_smem_size = (B + 1) * sizeof(index_t); + auto max_shared_bytes = syclLocalMemSize(); + int max_shared_kb = max_shared_bytes >> 10; + int used_shared_kb = round_down(max_shared_kb, 16); + TORCH_CHECK(used_shared_kb > 0); + int used_shared_bytes = used_shared_kb << 10; + TORCH_CHECK(dynamic_smem_size <= used_shared_bytes); + + int max_wg_size = syclDeviceMaxWorkGroupSize(); + int wg_size = max_wg_size < 1024 ? 
max_wg_size : 1024; + + int nbr_of_wg = CeilDivUp(nnz, wg_size); + sycl_kernel_submit< + jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_< + index_t>>( + sycl::range<3>(nbr_of_wg * wg_size, 1, 1), + sycl::range<3>(wg_size, 1, 1), + getCurrentSYCLQueue(), + dynamic_smem_size, + x_offsets[0] + .packed_accessor32(), + t_rows_after_bs + .packed_accessor32(), + t_cols_after_bs + .packed_accessor32(), + nnz, + B); + + // Gather kernel + int dim_0_1 = 16; + int nbr_of_wg_g = CeilDivUp(nnz, dim_0_1); + if (nbr_of_wg_g > 65535) { + nbr_of_wg_g = round_down(65535, dim_0_1); + } + + sycl_kernel_submit< + jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_< + index_t, + F>>( + sycl::range<3>(dim_0_1 * 1, dim_0_1 * nbr_of_wg_g, 1), + sycl::range<3>(dim_0_1, dim_0_1, 1), + getCurrentSYCLQueue(), + 0, + output_values + .packed_accessor32(), + x_values.packed_accessor32(), + y_reshaped + .packed_accessor32(), + y_reshaped + .packed_accessor32(), + t_rows_after_bs + .packed_accessor32(), + t_cols_after_bs + .packed_accessor32(), + nnz, + E, + f); + }); // AT_DISPATCH + } else { + JAGGED_TENSOR_DISPATCH_DIMS(); + } +} + +void jagged_dense_elementwise_add_jagged_output_fwd_xpu_kn( + const Tensor& x_values, + const std::vector& offsets, + const Tensor& dense, + const Tensor& output_values) { + AT_DISPATCH_SWITCH( + x_values.scalar_type(), + "jagged_dense_elementwise_add_jagged_output_fwd_xpu_kn", + AT_DISPATCH_CASE( + at::ScalarType::Half, + [&] { + jagged_dense_elementwise_jagged_output_opt_( + x_values, + offsets, + dense, + output_values, + SimpleAddFunctor3()); // device lambda + } // lambda + ) // CASE + FBGEMM_DISPATCH_FLOAT_AND_BFLOAT16_CASE([&] { + jagged_dense_elementwise_jagged_output_( + x_values, + offsets, + dense, + output_values, + SimpleAddFunctor3()); // device lambda + } // lambda + ) // CASE_FLOATING_TYPES_AND + ); // SWITCH +} + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>)) +void reorder_batched_ad_lengths_kernel_( + // reorder lengths from (ragged) [B x T x #num_ads_b)] to + // [T][B][#num_ads_b], i.e. [T][sum(#num_ads_b)]. + const GenericPackedTensorAccessor + cat_ad_lengths, + const GenericPackedTensorAccessor< + int32_t, + 1, + at::RestrictPtrTraits, + int32_t> batch_offsets, + GenericPackedTensorAccessor + reordered_cat_ad_lengths, + const int32_t T, + const bool broadcast_lengths) { + const int32_t B = batch_offsets.size(0) - 1; + + const int32_t num_ads_in_batch = batch_offsets[B]; + // warp-per-segment. + auto item = syclext::this_work_item::get_nd_item<2>(); + const auto b_t = + item.get_group(0) * item.get_local_range(1) + item.get_local_id(1); + const int32_t b = b_t % B; + const int32_t t = b_t / B; + if (t >= T) { + return; + } + + const int32_t num_ads_b = batch_offsets[b + 1] - batch_offsets[b]; + const int32_t input_segment_start = + broadcast_lengths ? T * b + t : T * batch_offsets[b] + t * num_ads_b; + const int32_t output_segment_start = t * num_ads_in_batch + batch_offsets[b]; + + for (auto i = item.get_local_id(0); i < num_ads_b; + i += item.get_local_range(0)) { + reordered_cat_ad_lengths[output_segment_start + i] = broadcast_lengths + ? 
cat_ad_lengths[input_segment_start] + : cat_ad_lengths[input_segment_start + i]; + } +} + +void reorder_batched_ad_lengths_xpu_kernel( + const Tensor& cat_ad_lengths, + const Tensor& batch_offsets, + Tensor& reordered_cat_ad_lengths, + const int32_t T, + const bool broadcast_lengths, + const int32_t grid_size) { + FBGEMM_DISPATCH_ALL_TYPES( + cat_ad_lengths.scalar_type(), + "reorder_batched_ad_lengths_xpu_kernel", + [&] { + sycl_kernel_submit>( + sycl::range<2>(32 * grid_size, 32), + sycl::range<2>(32, 32), + getCurrentSYCLQueue(), + 0, + cat_ad_lengths + .packed_accessor32(), + batch_offsets + .packed_accessor32(), + reordered_cat_ad_lengths + .packed_accessor32(), + T, + broadcast_lengths); + }); +} + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>)) +void narrow_broadcast_indices_kernel_( + const GenericPackedTensorAccessor< + index_t, + 1, + at::RestrictPtrTraits, + int32_t> cat_ad_offsets, + const GenericPackedTensorAccessor + cat_ad_indices, + GenericPackedTensorAccessor + reordered_cat_ad_indices, + const int num_ads_in_batch, + const int reordered_cat_ad_batches, + const int subGroupSize) { + auto item = syclext::this_work_item::get_nd_item<1>(); + const auto lane_id = item.get_local_id(0) % subGroupSize; + const auto warp_id = + (item.get_group(0) * item.get_local_range(0) + item.get_local_id(0)) / + subGroupSize; + const auto table_idx = warp_id / num_ads_in_batch; + const auto ads_idx = warp_id % num_ads_in_batch; + const auto start_offset = cat_ad_offsets[table_idx]; + const auto end_offset = cat_ad_offsets[table_idx + 1]; + const auto num_ads = end_offset - start_offset; + if (warp_id < reordered_cat_ad_batches) { + for (auto i = lane_id; i < num_ads; i += subGroupSize) { + reordered_cat_ad_indices + [start_offset * num_ads_in_batch + ads_idx * num_ads + i] = + cat_ad_indices[start_offset + i]; + } + } +} + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>)) +void narrow_batched_broadcast_indices_kernel_( + const GenericPackedTensorAccessor< + index_t, + 1, + at::RestrictPtrTraits, + int32_t> cat_ad_offsets, + const GenericPackedTensorAccessor + cat_ad_indices, + const GenericPackedTensorAccessor< + index_t, + 1, + at::RestrictPtrTraits, + int32_t> reordered_cat_ad_offsets, + GenericPackedTensorAccessor + reordered_cat_ad_indices, + const GenericPackedTensorAccessor< + int32_t, + 1, + at::RestrictPtrTraits, + int32_t> batch_offsets, + const int32_t T, + const int subGroupSize) { + const auto B = batch_offsets.size(0) - 1; + const auto num_ads_in_batch = static_cast(batch_offsets[B]); + // calculate table_id and batch_id for this warp + auto item = syclext::this_work_item::get_nd_item<1>(); + const auto warp_id = + (item.get_group(0) * item.get_local_range(0) + item.get_local_id(0)) / + static_cast(subGroupSize); + const auto table_id = warp_id / num_ads_in_batch; + const auto warp_id_in_table = warp_id % num_ads_in_batch; + // warps in a table equally splited for each B + const auto num_warp_in_batch = num_ads_in_batch / B; + const auto batch_id = warp_id_in_table / num_warp_in_batch; + if (table_id >= T || batch_id >= B) { + return; + } + + // all table_id and batch_id for this warp is the same + const auto num_ads_b = batch_offsets[batch_id + 1] - batch_offsets[batch_id]; + const auto output_segment_offset_start = + table_id * num_ads_in_batch + batch_offsets[batch_id]; + const auto output_segment_start = + reordered_cat_ad_offsets[output_segment_offset_start]; + const auto input_segment_offset_start = T * batch_id + 
table_id; + const auto input_segment_offset_end = input_segment_offset_start + 1; + const auto input_segment_start = cat_ad_offsets[input_segment_offset_start]; + const auto input_segment_end = cat_ad_offsets[input_segment_offset_end]; + const auto num_elements = input_segment_end - input_segment_start; + + const auto warp_id_in_batch = warp_id_in_table % num_warp_in_batch; + const auto lane_id_in_warp = item.get_local_id(0) % subGroupSize; + for (auto i = warp_id_in_batch; i < num_ads_b; i += num_warp_in_batch) { + for (auto j = lane_id_in_warp; j < num_elements; j += subGroupSize) { + reordered_cat_ad_indices[output_segment_start + i * num_elements + j] = + cat_ad_indices[input_segment_start + j]; + } + } +} + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>)) +void reorder_batched_ad_indices_kernel_( + // reorder indices from (ragged) [B x T x #num_ads_b x length_{b, t, a})] + // to [T][B][#num_ads_b][length_{b, t, a}], i.e. [sum(length_{b, t, a})], + // laid out as [T][B][A][L] (if all lengths were equal). + + // if broadcast_indices is enabled, all the indices will be copies of the + // first batch of the cat_ad_indices, this is useful for request-only + // broadcast + const GenericPackedTensorAccessor< + index_t, + 1, + at::RestrictPtrTraits, + int32_t> cat_ad_offsets, + const GenericPackedTensorAccessor + cat_ad_indices, + const GenericPackedTensorAccessor< + index_t, + 1, + at::RestrictPtrTraits, + int32_t> reordered_cat_ad_offsets, + GenericPackedTensorAccessor + reordered_cat_ad_indices, + const GenericPackedTensorAccessor< + int32_t, + 1, + at::RestrictPtrTraits, + int32_t> batch_offsets, + const int32_t T, + const bool broadcast_indices) { + const int32_t B = batch_offsets.size(0) - 1; + const int32_t num_ads_in_batch = batch_offsets[B]; + // warp-per-segment. + auto item = syclext::this_work_item::get_nd_item<2>(); + const auto b_t = + item.get_group(0) * item.get_local_range(1) + item.get_local_id(1); + const int32_t b = b_t % B; + const int32_t t = b_t / B; + if (t >= T) { + return; + } + + const auto num_ads_b = batch_offsets[b + 1] - batch_offsets[b]; + const auto output_segment_offset_start = + t * num_ads_in_batch + batch_offsets[b]; + const auto output_segment_start = + reordered_cat_ad_offsets[output_segment_offset_start]; + const int32_t input_segment_offset_start = + broadcast_indices ? T * b + t : T * batch_offsets[b] + t * num_ads_b; + const int32_t input_segment_offset_end = broadcast_indices + ? 
input_segment_offset_start + 1 + : input_segment_offset_start + num_ads_b; + const auto input_segment_start = cat_ad_offsets[input_segment_offset_start]; + const auto input_segment_end = cat_ad_offsets[input_segment_offset_end]; + const auto num_elements = input_segment_end - input_segment_start; + + if (broadcast_indices) { + for (auto i = item.get_local_id(0); i < num_ads_b * num_elements; + i += item.get_local_range(0)) { + reordered_cat_ad_indices[output_segment_start + i] = + cat_ad_indices[input_segment_start + i % num_elements]; + } + } else { + // Idea: we want to copy the entire segment of size sum_a(length_{b, t, a}) + // from starting point (given by cat_ad_offsets[b, t]) + // to end point (given by reordered_cat_ad_indices[t][b]) + for (auto i = item.get_local_id(0); + i < input_segment_end - input_segment_start; + i += item.get_local_range(0)) { + reordered_cat_ad_indices[output_segment_start + i] = + cat_ad_indices[input_segment_start + i]; + } + } +} + +void reorder_batched_ad_indices_xpu_kernel( + const at::Tensor& cat_ad_offsets, + const at::Tensor& cat_ad_indices, + const at::Tensor& reordered_cat_ad_offsets, + const at::Tensor& batch_offsets, + at::Tensor& reordered_cat_ad_indices, + const int64_t num_ads_in_batch, + const int64_t B, + const int64_t T, + const bool broadcast_indices) { + const int subGroupSize = syclMaxSubGroupSize(); + if (broadcast_indices && T <= 320 && B < 64) { + TORCH_CHECK(num_ads_in_batch * T == reordered_cat_ad_offsets.numel() - 1); + if (B == 1) { + // for B = 1 broadcast case + constexpr auto NUM_WARPS = 16; + const int workGroupSize = NUM_WARPS * subGroupSize; + const int global_dim = + xpu_calc_xblock_count( + reordered_cat_ad_offsets.numel() - 1, NUM_WARPS) * + workGroupSize; + FBGEMM_DISPATCH_ALL_TYPES( + cat_ad_indices.scalar_type(), + "narrow_broadcast_indices_kernel_1", + [&] { + AT_DISPATCH_INDEX_TYPES( + cat_ad_offsets.scalar_type(), + "narrow_broadcast_indices_kernel_2", + [&] { + sycl_kernel_submit< + narrow_broadcast_indices_kernel_>( + sycl::range<1>(global_dim), + sycl::range<1>(workGroupSize), + getCurrentSYCLQueue(), + 0, + cat_ad_offsets.packed_accessor32< + index_t, + 1, + at::RestrictPtrTraits>(), + cat_ad_indices.packed_accessor32< + scalar_t, + 1, + at::RestrictPtrTraits>(), + reordered_cat_ad_indices.packed_accessor32< + scalar_t, + 1, + at::RestrictPtrTraits>(), + num_ads_in_batch, + reordered_cat_ad_offsets.numel() - 1, + subGroupSize); + }); + }); + return; + } else { + // for B > 1 and B < 64 broadcast case + constexpr auto NUM_WARPS = 16; + const int workGroupSize = NUM_WARPS * subGroupSize; + const int global_dim = + xpu_calc_xblock_count(T * num_ads_in_batch, NUM_WARPS) * + workGroupSize; + FBGEMM_DISPATCH_ALL_TYPES( + cat_ad_indices.scalar_type(), + "narrow_batched_broadcast_indices_kernel_1", + [&] { + AT_DISPATCH_INDEX_TYPES( + cat_ad_offsets.scalar_type(), + "narrow_batched_broadcast_indices_kernel_2", + [&] { + sycl_kernel_submit>( + sycl::range<1>(global_dim), + sycl::range<1>(workGroupSize), + getCurrentSYCLQueue(), + 0, + cat_ad_offsets.packed_accessor32< + index_t, + 1, + at::RestrictPtrTraits>(), + cat_ad_indices.packed_accessor32< + scalar_t, + 1, + at::RestrictPtrTraits>(), + reordered_cat_ad_offsets.packed_accessor32< + index_t, + 1, + at::RestrictPtrTraits>(), + reordered_cat_ad_indices.packed_accessor32< + scalar_t, + 1, + at::RestrictPtrTraits>(), + batch_offsets.packed_accessor32< + int32_t, + 1, + at::RestrictPtrTraits>(), + T, + subGroupSize); + }); + }); + return; + } + } + 
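+  // General path: one (b, t) segment per work-group row, copying each ad's
+  // indices from the concatenated [B][T] layout into [T][B][#ads] order.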
FBGEMM_DISPATCH_ALL_TYPES( + cat_ad_indices.scalar_type(), + "reorder_batched_ad_indices_xpu_kernel_1", + [&] { + AT_DISPATCH_INDEX_TYPES( + cat_ad_offsets.scalar_type(), + "reorder_batched_ad_indices_xpu_kernel_2", + [&] { + constexpr auto NUM_WARPS = 32; + const int maxWorkGroupSize = syclDeviceMaxWorkGroupSize(); + auto maxWarpSize = maxWorkGroupSize / NUM_WARPS; + const int gloal_dim_y = + maxWarpSize < subGroupSize ? maxWarpSize : subGroupSize; + const int global_dim_x = + xpu_calc_xblock_count(B * T, NUM_WARPS) * NUM_WARPS; + sycl_kernel_submit< + reorder_batched_ad_indices_kernel_>( + sycl::range<2>(global_dim_x, gloal_dim_y), + sycl::range<2>(NUM_WARPS, gloal_dim_y), + getCurrentSYCLQueue(), + 0, + cat_ad_offsets + .packed_accessor32(), + cat_ad_indices + .packed_accessor32(), + reordered_cat_ad_offsets + .packed_accessor32(), + reordered_cat_ad_indices + .packed_accessor32(), + batch_offsets + .packed_accessor32(), + T, + broadcast_indices); + }); + }); +} + +// Kernel for permuting the lengths. Used for permutation of sparse features. +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>)) +void permute_2D_lengths_kernel_( + int32_t T, + int32_t B, + const index_t* __restrict__ lengths, + const int32_t* __restrict__ permute, + index_t* __restrict__ permuted_lengths) { + auto item = syclext::this_work_item::get_nd_item<1>(); + XPU_KERNEL_LOOP(item, b_t, B * T) { + int32_t b = b_t % B; + int32_t t = b_t / B; + permuted_lengths[b_t] = lengths[permute[t] * B + b]; + } +} + +void permute_2D_lengths_kernel_xpu( + int32_t T, + int32_t B, + const at::Tensor& lengths_contig, + const at::Tensor& permute_contig, + at::Tensor& permuted_lengths) { + constexpr int32_t threads_1 = 256; + const auto blocks_1 = xpu_calc_xblock_count(B * T, threads_1); + AT_DISPATCH_INDEX_TYPES( + lengths_contig.scalar_type(), "permute_2D_lengths_kernel", [&] { + sycl_kernel_submit>( + sycl::range<1>(blocks_1 * threads_1), + sycl::range<1>(threads_1), + getCurrentSYCLQueue(), + 0, + T, + B, + lengths_contig.data_ptr(), + permute_contig.data_ptr(), + permuted_lengths.data_ptr()); + }); +} + +template < + bool has_weight, + typename offsets_t, + typename indices_t, + typename weights_t> +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>)) +void permute_2D_data_kernel_( + int32_t len, + int32_t T, + int32_t B, + const indices_t* __restrict__ indices, + const weights_t* __restrict__ weights, + const int32_t weights_columns, + const int32_t* __restrict__ permute, + const offsets_t* __restrict__ input_offsets, + const offsets_t* __restrict__ output_offsets, + indices_t* __restrict__ permuted_indices, + weights_t* __restrict__ permuted_weights) { + auto item = syclext::this_work_item::get_nd_item<2>(); + auto b_t_start = + item.get_group(0) * item.get_local_range(1) + item.get_local_id(1); + const auto stride = item.get_group_range(0) * item.get_local_range(1); + for (int b_t = b_t_start; b_t < B * T; b_t += stride) { + int32_t b = b_t % B; + int32_t t = b_t / B; + offsets_t output_start = output_offsets[b_t]; + offsets_t segment_length; + if (b_t == B * T - 1) { + segment_length = len - output_offsets[b_t]; + } else { + segment_length = output_offsets[b_t + 1] - output_offsets[b_t]; + } + offsets_t input_start = input_offsets[permute[t] * B + b]; + for (auto i = item.get_local_id(0); i < segment_length; + i += item.get_local_range(0)) { + permuted_indices[output_start + i] = indices[input_start + i]; + if (has_weight) { + for (auto w_col = 0; w_col < weights_columns; ++w_col) { + 
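+          // Copy the full weight row (weights_columns values) for this
+          // element alongside its permuted index.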
permuted_weights[(output_start + i) * weights_columns + w_col] = + weights[(input_start + i) * weights_columns + w_col]; + } + } + } + } +} + +void permute_2D_data_kernel_xpu( + int32_t permuted_indices_size, + int32_t T, + int32_t B, + const Tensor& indices_contig, + const std::optional& weights, + const int32_t weights_columns, + const Tensor& permute_contig, + const Tensor& input_offsets, + const Tensor& output_offsets, + Tensor& permuted_indices, + const std::optional& permuted_weights) { + constexpr int32_t BT_blocks = 32; + const auto blocks_2 = xpu_calc_xblock_count(B * T, BT_blocks); + AT_DISPATCH_INDEX_TYPES( + input_offsets.scalar_type(), "permute_2D_data_kernel_1", [&] { + using offsets_t = index_t; + FBGEMM_DISPATCH_ALL_TYPES( + indices_contig.scalar_type(), "permute_2D_data_kernel_2", [&] { + using indices_t = scalar_t; + if (weights.has_value()) { + const auto weights_value_contig = weights.value().contiguous(); + FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE( + weights_value_contig.scalar_type(), + "permute_2D_data_kernel_3", + [&] { + using weights_t = scalar_t; + sycl_kernel_submit>( + sycl::range<2>(blocks_2 * 32, BT_blocks), + sycl::range<2>(32, BT_blocks), + getCurrentSYCLQueue(), + 0, + permuted_indices_size, + T, + B, + indices_contig.data_ptr(), + weights_value_contig.data_ptr(), + weights_columns, + permute_contig.data_ptr(), + input_offsets.data_ptr(), + output_offsets.data_ptr(), + permuted_indices.data_ptr(), + permuted_weights.value().data_ptr()); + }); // for each weights_t + } else { + sycl_kernel_submit>( // false float type here as wa since + // std::nullptr_t cannot be + // supported in free function + // kernel + sycl::range<2>(blocks_2 * 32, BT_blocks), + sycl::range<2>(32, BT_blocks), + getCurrentSYCLQueue(), + 0, + permuted_indices_size, + T, + B, + indices_contig.data_ptr(), + nullptr, + 0, + permute_contig.data_ptr(), + input_offsets.data_ptr(), + output_offsets.data_ptr(), + permuted_indices.data_ptr(), + nullptr); + } + }); // for each indices_t + }); // for each offsets_t +} + +#undef INVOKE_KERNEL_WITH_DIM + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/ATen/native/xpu/sycl/FbgemmKernels.h b/src/ATen/native/xpu/sycl/FbgemmKernels.h new file mode 100644 index 000000000..497c6a857 --- /dev/null +++ b/src/ATen/native/xpu/sycl/FbgemmKernels.h @@ -0,0 +1,358 @@ +#pragma once + +#include + +#include + +#include + +namespace syclext = sycl::ext::oneapi; +namespace syclexp = sycl::ext::oneapi::experimental; + +namespace at { + +template +struct RestrictPtrTraits { + typedef T* __restrict__ PtrType; +}; + +namespace native { +namespace xpu { + +#define FBGEMM_DISPATCH_FLOAT_AND_BFLOAT16_CASE(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define FBGEMM_DISPATCH_FLOATING_TYPES_CASE(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define FBGEMM_DISPATCH_INTEGRAL_TYPES_CASE(...) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define FBGEMM_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + FBGEMM_DISPATCH_FLOATING_TYPES_CASE(__VA_ARGS__) \ + FBGEMM_DISPATCH_INTEGRAL_TYPES_CASE(__VA_ARGS__)) + +#define FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + FBGEMM_DISPATCH_FLOATING_TYPES_CASE(__VA_ARGS__) \ + FBGEMM_DISPATCH_INTEGRAL_TYPES_CASE(__VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__)) + +uint32_t xpu_calc_xblock_count_base(int num_items, int threads_per_block) { + // The number of threads can be as high as 2048 on some newer architectures, + // but this is not portable, so check against the queried device limit. + TORCH_CHECK( + threads_per_block <= syclDeviceMaxWorkGroupSize(), + "Number of threads must not exceed the device max work-group size!"); + constexpr uint64_t max_blocks = 2147483647; + const auto u_num_items = static_cast(num_items); + const auto u_threads = static_cast(threads_per_block); + // Overflow safe variant of (a + b - 1) / b + const uint64_t blocks = + u_num_items / u_threads + (u_num_items % u_threads != 0); + return static_cast(std::min(blocks, max_blocks)); +} + +// See: xpu_calc_xblock_count_base +uint32_t xpu_calc_xblock_count(int num_items, int threads_per_block) { + TORCH_CHECK( + num_items >= 0, + "When calculating block counts, the number of items must be non-negative!"); + return xpu_calc_xblock_count_base(num_items, threads_per_block); +} + +constexpr size_t kStackArrayMaxDims = 5; + +template +inline auto CeilDivUp(T a, V b) { + return (a + b - 1) / b; +} + +template +inline auto round_down(T a, V b) { + return a / b * b; +} + +inline bool torch_tensor_undefined(const at::Tensor& ten) { + return ten.defined(); +} + +inline bool torch_tensor_undefined(const std::optional& ten) { + return !ten.has_value() || torch_tensor_undefined(ten.value()); +} + +inline bool torch_tensor_on_xpu_check(const at::Tensor& ten) { + return ten.is_xpu(); +} + +inline bool torch_tensor_on_xpu_check(const std::optional& ten) { + return !ten.has_value() || torch_tensor_on_xpu_check(ten.value()); +} + +inline std::optional get_device_index_from_tensor( + const at::Tensor& ten) { + return {ten.device().index()}; +} + +inline std::optional get_device_index_from_tensor( + const std::optional& ten) { + if (ten) { + return {ten->device().index()}; + } else { + return {}; + } +} + +inline std::string torch_tensor_device_name(const at::Tensor& ten) { + return c10::DeviceTypeName(ten.device().type()); +} + +inline std::string torch_tensor_device_name( + const std::optional& ten) { + if (ten.has_value()) { + return torch_tensor_device_name(ten.value()); + } else { + return "N/A"; + } +} + +template +std::string tensor_on_same_xpu_if_not_optional_check( + const std::string& var_names_str, + const Tensors&... tensors) { + std::optional xpu_index; + bool on_same_xpu = true; + + // Collect the XPU index of the first non-empty optional tensor and make sure + // that all tensors are on this same index. + ( + [&](const auto& tensor) { + if (!torch_tensor_undefined(tensor)) { + return; + } + if (!torch_tensor_on_xpu_check(tensor)) { + on_same_xpu = false; + return; + } + const auto my_xpu_index = get_device_index_from_tensor(tensor); + if (my_xpu_index) { + if (!xpu_index) { + xpu_index = my_xpu_index; + } else if (*xpu_index != my_xpu_index) { + on_same_xpu = false; + } + } + }(tensors), + ...); + + if (on_same_xpu) { + return ""; + } + + std::vector var_names; + { + std::string temp; + for (const auto& x : var_names_str) { + if (x == ',') { + var_names.push_back(temp); + temp = ""; + } else { + temp.push_back(x); + } + } + var_names.push_back(temp); + } + + // Not all of the tensors are on a GPU, or they are not on the same GPU; generate a message.
+ std::string msg = "Not all tensors were on the same GPU: "; + size_t current_idx = 0; + ( + [&](const auto& tensor) { + if (current_idx > 0) { + msg.append(", "); + } + msg.append( + var_names.at(current_idx++) + "(" + + torch_tensor_device_name(tensor)); + const auto xpu_device_index = get_device_index_from_tensor(tensor); + if (xpu_device_index) { + msg.append(":" + std::to_string(*xpu_device_index)); + } + msg.append(")"); + }(tensors), + ...); + + return msg; +} + +#define TENSORS_ON_SAME_XPU_IF_NOT_OPTIONAL(...) \ + do { \ + const auto tensors_on_same_xpu = \ + tensor_on_same_xpu_if_not_optional_check(#__VA_ARGS__, __VA_ARGS__); \ + TORCH_CHECK(tensors_on_same_xpu.empty(), tensors_on_same_xpu); \ + } while (false) + +struct VecType128 { + typedef sycl::float4 TType; // Transaction Type + typedef struct __attribute__((aligned(16))) { + sycl::half a, b, c, d, w, x, y, z; + } half8; + + union Data { + half8 val; + TType mask; + Data() { + mask = sycl::float4(0.0f, 0.0f, 0.0f, 0.0f); + } + } data; +}; + +struct VecType64 { + typedef sycl::vec TType; // Transaction Type + typedef struct __attribute__((aligned(8))) { + sycl::half a, b, c, d; + } half4; + + union Data { + half4 val; + TType mask; + Data() { + mask = sycl::vec(0.0f, 0.0f); + } + } data; +}; + +struct VecType32 { + typedef float TType; // Transaction Type + + union Data { + sycl::vec val; + TType mask; + Data() { + mask = 0.0f; + } + } data; +}; + +template +void f128( + VecType128& v_out, + const VecType128& x, + const VecType128& y0, + const VecType128& y1, + F f) { + v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a); + v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b); + v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c); + v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d); + v_out.data.val.w = f(x.data.val.w, y0.data.val.w, y1.data.val.w); + v_out.data.val.x = f(x.data.val.x, y0.data.val.x, y1.data.val.x); + v_out.data.val.y = f(x.data.val.y, y0.data.val.y, y1.data.val.y); + v_out.data.val.z = f(x.data.val.z, y0.data.val.z, y1.data.val.z); +} + +template +void f64( + VecType64& v_out, + const VecType64& x, + const VecType64& y0, + const VecType64& y1, + F f) { + v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a); + v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b); + v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c); + v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d); +} + +template +void f32( + VecType32& v_out, + const VecType32& x, + const VecType32& y0, + const VecType32& y1, + F f) { + v_out.data.val = sycl::vec( + f(x.data.val.x(), y0.data.val.x(), y1.data.val.x()), + f(x.data.val.y(), y0.data.val.y(), y1.data.val.y())); +} + +template +void fh( + sycl::half& v_out, + const sycl::half& x, + const sycl::half& y0, + const sycl::half& y1, + F f) { + v_out = f(x, y0, y1); +} + +void dense_to_jagged_forward_xpu_kernel( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output_values); + +void jagged_to_padded_dense_forward_xpu_kernel( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output, + const double padding_value = 0.0); + +void jagged_dense_elementwise_add_jagged_output_fwd_xpu_kn( + const Tensor& x_values, + const std::vector& offsets, + const Tensor& dense, + const Tensor& output_values); + +void reorder_batched_ad_lengths_xpu_kernel( + const Tensor& cat_ad_lengths, + const Tensor& batch_offsets, + Tensor& 
reordered_cat_ad_lengths, + const int32_t T, + const bool broadcast_lengths, + const int32_t grid_size); + +void reorder_batched_ad_indices_xpu_kernel( + const at::Tensor& cat_ad_offsets, + const at::Tensor& cat_ad_indices, + const at::Tensor& reordered_cat_ad_offsets, + const at::Tensor& batch_offsets, + at::Tensor& reordered_cat_ad_indices, + const int64_t num_ads_in_batch, + const int64_t B, + const int64_t T, + const bool broadcast_indices = false); + +void permute_2D_lengths_kernel_xpu( + int32_t T, + int32_t B, + const at::Tensor& lengths_contig, + const at::Tensor& permute_contig, + at::Tensor& permuted_lengths); + +void permute_2D_data_kernel_xpu( + int32_t permuted_indices_size, + int32_t T, + int32_t B, + const Tensor& indices_contig, + const std::optional& weights, + const int32_t weights_columns, + const Tensor& permute_contig, + const Tensor& input_offsets, + const Tensor& output_offsets, + Tensor& permuted_indices, + const std::optional& permuted_weights); + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/test/xpu/test_fbgemm_ops_xpu.py b/test/xpu/test_fbgemm_ops_xpu.py new file mode 100644 index 000000000..a8822cd98 --- /dev/null +++ b/test/xpu/test_fbgemm_ops_xpu.py @@ -0,0 +1,1153 @@ +# Owner(s): ["module: intel"] +import itertools +import random +from itertools import accumulate +from typing import List, Optional, Tuple, Type + +import hypothesis.strategies as st +import numpy as np +import numpy.typing as npt +import torch +from hypothesis import assume, given, settings, Verbosity +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_utils import run_tests, TestCase + +try: + from xpu_test_utils import XPUPatchForImport +except Exception as e: + from .xpu_test_utils import XPUPatchForImport + +# define fbgemm ops schemas here since we cannot register them in torch-xpu-ops. +# otherwise, it will fail fbgemm lib due to duplicate schema registration. +# for user, they can import fbgemm_gpu first before accessing fbgemm ops on xpu. +lib = torch.library.Library("fbgemm", "DEF") + +lib.define("asynchronous_complete_cumsum(Tensor t_in) -> Tensor") +lib.define( + "dense_to_jagged(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> (Tensor, Tensor[])" +) +lib.define( + "dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor" +) +lib.define( + "jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max, float padding_value=0.0) -> Tensor" +) +lib.define( + "jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max, float padding_value=0.0) -> Tensor" +) +lib.define( + "jagged_dense_elementwise_add_jagged_output(Tensor values, Tensor[] offsets, Tensor y) -> (Tensor, Tensor[])" +) +lib.define( + "reorder_batched_ad_lengths(Tensor cat_ad_lengths, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_lengths, int max_batch_size=0) -> Tensor" +) +lib.define( + "reorder_batched_ad_indices(Tensor cat_ad_offsets, Tensor cat_ad_indices, Tensor reordered_cat_ad_offsets, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_indices, int num_indices_after_broadcast) -> Tensor" +) +lib.define( + "permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor indices, Tensor? weights=None, int? 
permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)" +) + + +def generate_jagged_tensor( + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + fold_inner_dense: bool = False, + # dynamo to mark the input as dynamic shape to make sure symbolic + # shape is generated + mark_dynamic: bool = False, +) -> Tuple[torch.Tensor, List[torch.LongTensor], npt.NDArray]: + max_lengths = np.random.randint(low=1, high=10, size=(num_jagged_dim,)) + x_offsets: List[torch.LongTensor] = [] + num_lengths = outer_dense_size + for d in range(num_jagged_dim): + # Sometimes length[i] exceeds max_L, meaning jagged->dense will + # truncate rather than pad + lengths = torch.randint( + # PT2 specializes 0/1 dims as non-symbolic shapes. So we need + # to make them non-0/1 for testing. In real cases they'll likely + # not be 0/1 anyway (if so, they'll be recompiled) + low=0 if not mark_dynamic else 1, + high=max_lengths[d] * 2, + # pyre-fixme[6]: For 3rd param expected `Union[List[int], Size, + # typing.Tuple[int, ...]]` but got `Tuple[Union[bool, float, int]]`. + size=(num_lengths,), + device=device, + ) + offset = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + if mark_dynamic: + torch._dynamo.mark_dynamic(offset, 0) + x_offsets.append(offset) + num_lengths = x_offsets[-1][-1].item() + + x_values = torch.rand( + # pyre-fixme[6]: For 1st param expected `Union[List[int], Size, + # typing.Tuple[int, ...]]` but got `Tensor`. + x_offsets[-1][-1] * inner_dense_size, + dtype=dtype, + device=device, + ) + if inner_dense_size != 1 or not fold_inner_dense: + # pyre-fixme[6]: For 1st param expected `int` but got `Union[bool, float, int]`. + x_values = x_values.reshape(x_offsets[-1][-1].item(), inner_dense_size) + + if mark_dynamic: + for i in range(inner_dense_size): + torch._dynamo.mark_dynamic(x_values, i) + + return x_values, x_offsets, max_lengths + + +def to_padded_dense( + values: torch.Tensor, + offsets: List[torch.LongTensor], + max_lengths: npt.NDArray, + padding_value: float = 0, +) -> torch.Tensor: + outer_dense_size = len(offsets[0]) - 1 + # canonicalize by unsqueezing the last dim if the inner dense dimension + # is 1 and folded. + inner_dense_size = 1 if values.ndim == 1 else values.size(-1) + dense = torch.empty( + (outer_dense_size,) + tuple(max_lengths) + (inner_dense_size,), + dtype=values.dtype, + device=values.device, + ) + for i in range(outer_dense_size): + for jagged_coord in itertools.product( + *(list(range(max_l)) for max_l in max_lengths) + ): + cur_offset = i + is_zero = False + for d in range(len(max_lengths)): + # pyre-fixme[6]: For 1st argument expected `Union[None, _NestedSe... + begin = offsets[d][cur_offset].item() + # pyre-fixme[6]: For 1st argument expected `Union[None, _NestedSe... + end = offsets[d][cur_offset + 1].item() + # pyre-fixme[6]: For 1st param expected `int` but got + # `Union[bool, float, int]`. + if jagged_coord[d] >= end - begin: + is_zero = True + break + cur_offset = begin + jagged_coord[d] + dense[(i,) + jagged_coord] = ( + padding_value + if is_zero + # pyre-fixme[6]: For 1st argument expected `Union[None, _NestedSe...
+ else values[cur_offset] + ) + return dense.squeeze(-1) if values.ndim == 1 else dense + + +def permute_indices_ref_( + lengths: torch.Tensor, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + permute: torch.LongTensor, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + T = lengths.size(0) + B = lengths.size(1) + if T == 0 or B == 0: + return lengths, indices, weights + + permuted_lengths = torch.index_select(lengths.view(T, -1), 0, permute) + original_segment_lengths = lengths.view(T, -1).sum(dim=1, dtype=torch.int32) + original_segment_start = [0] + list(accumulate(original_segment_lengths.view(-1))) + + permuted_indices = [] + permuted_weights = [] + for i in range(permute.size(0)): + start = original_segment_start[permute[i]] + end = start + original_segment_lengths[permute[i]] + permuted_indices.append(indices[start:end]) + if weights is not None: + permuted_weights.append(weights[start:end]) + + permuted_indices = torch.cat(permuted_indices, dim=0).flatten() + + if weights is None: + permuted_weights = None + else: + permuted_weights = torch.cat(permuted_weights, dim=0).flatten() + + return permuted_lengths, permuted_indices, permuted_weights + + +with XPUPatchForImport(False): + + class CumSumTest(TestCase): + @given( + n=st.integers(min_value=0, max_value=10), + index_types=st.sampled_from( + [ + (torch.int64, np.int64), + (torch.int32, np.int32), + (torch.float32, np.float32), + ] + ), + device=st.just(torch.device("xpu")), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_cumsum( + self, + n: int, + index_types: Tuple[Type[object], Type[object]], + device: torch.device, + ) -> None: + (pt_index_dtype, np_index_dtype) = index_types + + # The CPU variants of asynchronous_*_cumsum support floats, since some + # downstream tests appear to be relying on this behavior. As such, the + # test is disabled for GPU + float test cases. 
+ if device == torch.device("xpu") and pt_index_dtype is torch.float32: + return + + # pyre-ignore-errors[16] + x = ( + torch.randint(low=0, high=100, size=(n,)) + .type(pt_index_dtype) + .to(device) + ) + zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x) + + torch.testing.assert_close( + torch.from_numpy( + (np.cumsum([0] + x.cpu().numpy().tolist())).astype(np_index_dtype) + ), + zc.cpu(), + ) + + class DenseToJaggedTest(TestCase): + def _test_dense_to_jagged( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + precompute_total_L: bool, + ) -> None: + # Generate multi-dim jagged tensor + values_2d, offsets, max_lengths = generate_jagged_tensor( + num_jagged_dim, outer_dense_size, inner_dense_size, dtype, device + ) + # values_2d = values_2d.clone().detach().requires_grad_(True) + + # jagged -> dense + dense = torch.ops.fbgemm.jagged_to_padded_dense( + values_2d, offsets, max_lengths + ) + + # dense -> jagged (op which is being tested) + if precompute_total_L: + total_L = values_2d.size(0) + jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( + dense, offsets, total_L + ) + jagged_values_f = torch.ops.fbgemm.dense_to_jagged_forward( + dense, offsets, total_L + ) + torch.testing.assert_close(jagged_values, jagged_values_f) + else: + jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( + dense, offsets + ) + jagged_values_f = torch.ops.fbgemm.dense_to_jagged_forward( + dense, offsets + ) + torch.testing.assert_close(jagged_values, jagged_values_f) + + # jagged -> dense + dense2 = torch.ops.fbgemm.jagged_to_padded_dense( + jagged_values, jagged_offsets, max_lengths + ) + + # verify forward + torch.testing.assert_close(dense, dense2) + + # verify backward + + @given( + num_jagged_dim=st.integers(1, 5), + outer_dense_size=st.integers(0, 5), + inner_dense_size=st.integers(0, 5), + # num_jagged_dim=st.integers(4, 5), + # outer_dense_size=st.integers(4, 5), + # inner_dense_size=st.integers(4, 5), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device=st.just(torch.device("xpu")), + precompute_total_L=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_dense_to_jagged( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + precompute_total_L: bool, + ) -> None: + self._test_dense_to_jagged( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + dtype, + device, + precompute_total_L, + ) + + @given( + num_jagged_dim=st.just(1), + outer_dense_size=st.integers(0, 6000), + inner_dense_size=st.sampled_from([8, 16, 23, 24, 48, 50, 64, 72, 96, 192]), + dtype=st.just(torch.half), + device=st.just(torch.device("xpu")), + precompute_total_L=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_dense_to_jagged_opt( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + precompute_total_L: bool, + ) -> None: + self._test_dense_to_jagged( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + dtype, + device, + precompute_total_L, + ) + + # (8000+1) * 8 (size of the element of LongTensor/int64_t offsets) + # = ~62.5KB > 48KB default shared memory on V100/A100. 
+ @given( + num_jagged_dim=st.just(1), + outer_dense_size=st.just(8000), + inner_dense_size=st.just(16), + dtype=st.just(torch.half), + device=st.just(torch.device("xpu")), + precompute_total_L=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_dense_to_jagged_opt_large_batch( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + precompute_total_L: bool, + ) -> None: + self._test_dense_to_jagged( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + dtype, + device, + precompute_total_L, + ) + + @given( + num_jagged_dim=st.integers(1, 5), + # TODO: size = 0/1 will be incorrectly specialized + outer_dense_size=st.integers(2, 5), + inner_dense_size=st.integers(2, 5), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device=st.just(torch.device("xpu")), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_dense_to_jagged_dynamic_shape( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + # Start a fresh compile for each parameter of the test case + torch._dynamo.reset() + + values_2d, offsets, max_lengths = generate_jagged_tensor( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + dtype, + device, + mark_dynamic=True, + ) + + def jagged_to_dense( + values: torch.Tensor, + offsets: List[torch.LongTensor], + max_lengths: List[int], + ) -> torch.Tensor: + return torch.ops.fbgemm.jagged_to_padded_dense( + values, offsets, max_lengths + ) + + # jagged -> dense + dense = jagged_to_dense(values_2d, offsets, max_lengths.tolist()) + + # dense -> jagged, it is required to pre-compute totalL + total_L = values_2d.size(0) + dense = dense.clone().detach().to(device) + + torch._dynamo.mark_dynamic(dense, 0) + torch._dynamo.mark_dynamic(dense, -1) + + def dense_to_jagged_withL( + dense: torch.Tensor, offsets: List[torch.LongTensor], total_L: List[int] + ) -> Tuple[torch.Tensor, torch.Tensor]: + jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( + dense, offsets, total_L + ) + jagged_values_f = torch.ops.fbgemm.dense_to_jagged_forward( + dense, offsets, total_L + ) + torch.testing.assert_close(jagged_values, jagged_values_f) + return jagged_values, jagged_offsets + + def dense_to_jagged_noL( + dense: torch.Tensor, offsets: List[torch.LongTensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( + dense, offsets + ) + jagged_values_f = torch.ops.fbgemm.dense_to_jagged_forward( + dense, offsets + ) + torch.testing.assert_close(jagged_values, jagged_values_f) + return jagged_values, jagged_offsets + + jagged_values, jagged_offsets = dense_to_jagged_noL(dense, offsets) + jagged_values, jagged_offsets = dense_to_jagged_withL( + dense, offsets, total_L + ) + + jagged_values.to(device) + # jagged -> dense + dense2 = torch.ops.fbgemm.jagged_to_padded_dense( + jagged_values, jagged_offsets, max_lengths + ) + + # verify forward + assert dense.size() == dense2.size() + + class JaggedToPaddedDenseTest(TestCase): + @given( + num_jagged_dim=st.integers(1, 5), + outer_dense_size=st.integers(0, 5), + inner_dense_size=st.integers(0, 5), + fold_inner_dense=st.booleans(), + padding_value=st.sampled_from([0, -1e-8]), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device_type=st.just("xpu"), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, 
deadline=None) + def test_jagged_to_padded_dense( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + fold_inner_dense: bool, + padding_value: float, + dtype: torch.dtype, + device_type: str, + ) -> None: + # CPU doesn't support bfloat16 + assume(device_type != "cpu" or dtype != torch.bfloat16) + assume(not fold_inner_dense or inner_dense_size == 1) + + # Testing with a basic crafted example. + # dense representation is + # [[[[0, 1], [ 0, 0], [0, 0]], + # [[2, 3], [ 4, 5], [6, 7]], + # [[0, 0], [ 0, 0], [0, 0]], + # [[0, 0], [ 0, 0], [0, 0]]], + # [[[0, 0], [ 0, 0], [0, 0]], + # [[0, 0], [ 0, 0], [0, 0]], + # [[0, 0], [ 0, 0], [0, 0]], + # [[0, 0], [ 0, 0], [0, 0]]], + # [[[8, 9], [10, 11], [0, 0]], + # [[0, 0], [ 0, 0], [0, 0]], + # [[0, 0], [ 0, 0], [0, 0]], + # [[0, 0], [ 0, 0], [0, 0]]]], + # inner_dense_size = 2 + # x_offsets = [ + # torch.LongTensor([0, 2, 2, 3]), # lengths torch.Tensor([2, 0, 1]), + # torch.LongTensor([0, 1, 4, 6]), # lengths torch.Tensor([1, 3, 2]), + # ] + # outer_dense_size = len(x_offsets[0]) - 1 + # max_lengths = [4, 3] + + device = torch.device(device_type) + + x_values, x_offsets, max_lengths = generate_jagged_tensor( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + torch.float, + device, + fold_inner_dense, + ) + + output_ref = to_padded_dense( + x_values, x_offsets, max_lengths, padding_value=padding_value + ) + output = torch.ops.fbgemm.jagged_to_padded_dense( + x_values, + x_offsets, + max_lengths, + padding_value=padding_value, + ) + + output_f = torch.ops.fbgemm.jagged_to_padded_dense_forward( + x_values, + x_offsets, + max_lengths, + padding_value=padding_value, + ) + + torch.testing.assert_close(output, output_ref) + torch.testing.assert_close(output_f, output_ref) + + class ElementwiseBinaryTest(TestCase): + def _test_jagged_elementwise_binary( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + operation: str, + dtype: torch.dtype, + device: torch.device, + ) -> None: + x_values, x_offsets, max_lengths = generate_jagged_tensor( + num_jagged_dim, outer_dense_size, inner_dense_size, dtype, device + ) + y = torch.rand( + outer_dense_size * np.prod(max_lengths) * inner_dense_size, + dtype=dtype, + device=device, + ).reshape((outer_dense_size,) + tuple(max_lengths) + (inner_dense_size,)) + + x_padded = to_padded_dense(x_values, x_offsets, max_lengths) + + assert operation == "add_jagged_output" + # create a jagged tensor and then densify + y = to_padded_dense( + torch.rand( + ( + max(outer_dense_size * np.prod(max_lengths), x_values.size(0)), + inner_dense_size, + ), + dtype=dtype, + device=device, + ), + x_offsets, + max_lengths, + ) + output_ref = x_padded + y + ( + output, + output_offsets, + ) = torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output( + x_values, x_offsets, y + ) + output = to_padded_dense(output, output_offsets, max_lengths) + + torch.testing.assert_close(output, output_ref) + + @given( + num_jagged_dim=st.integers(1, 4), + outer_dense_size=st.integers(0, 4), + inner_dense_size=st.integers(0, 4), + operation=st.just("add_jagged_output"), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device=st.just(torch.device("xpu")), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_jagged_elementwise_binary( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + operation: str, + dtype: torch.dtype, + device: torch.device, + ) -> None: + self._test_jagged_elementwise_binary( + 
num_jagged_dim, + outer_dense_size, + inner_dense_size, + operation, + dtype, + device, + ) + + @given( + num_jagged_dim=st.just(1), + outer_dense_size=st.integers(0, 8), + inner_dense_size=st.sampled_from([16, 64, 96, 192]), + operation=st.just("add_jagged_output"), + dtype=st.just(torch.half), + device=st.just(torch.device("xpu")), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_jagged_elementwise_binary_opt( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + operation: str, + dtype: torch.dtype, + device: torch.device, + ) -> None: + self._test_jagged_elementwise_binary( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + operation, + dtype, + device, + ) + + @given( + num_jagged_dim=st.integers(1, 5), + outer_dense_size=st.integers(2, 5), + inner_dense_size=st.integers(2, 5), + operation=st.just("add_jagged_output"), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device=st.just(torch.device("xpu")), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_jagged_elementwise_binary_dynamic_shape( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + operation: str, + dtype: torch.dtype, + device: torch.device, + ) -> None: + # Start a fresh compile for each parameter of the test case + torch._dynamo.reset() + + x_values, x_offsets, max_lengths = generate_jagged_tensor( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + dtype, + device, + mark_dynamic=True, + ) + y = torch.rand( + outer_dense_size * np.prod(max_lengths) * inner_dense_size, + dtype=dtype, + device=device, + ).reshape((outer_dense_size,) + tuple(max_lengths) + (inner_dense_size,)) + + x_padded = to_padded_dense(x_values, x_offsets, max_lengths) + + def jagged_dense_elementwise_add( + x_values: torch.Tensor, + x_offsets: List[torch.LongTensor], + y: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.fbgemm.jagged_dense_elementwise_add( + x_values, x_offsets, y + ) + + def jagged_dense_elementwise_add_jagged_output( + x_values: torch.Tensor, + x_offsets: List[torch.LongTensor], + y: torch.Tensor, + ) -> Tuple[torch.Tensor, List[torch.LongTensor]]: + return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output( + x_values, x_offsets, y + ) + + def jagged_dense_elementwise_mul( + x_values: torch.Tensor, + x_offsets: List[torch.LongTensor], + y: torch.Tensor, + ) -> Tuple[torch.Tensor, List[torch.LongTensor]]: + return torch.ops.fbgemm.jagged_dense_elementwise_mul( + x_values, x_offsets, y + ) + + assert operation == "add_jagged_output" + # create a jagged tensor and then densify + y = to_padded_dense( + torch.rand( + ( + max(outer_dense_size * np.prod(max_lengths), x_values.size(0)), + inner_dense_size, + ), + dtype=dtype, + device=device, + ), + x_offsets, + max_lengths, + ) + output_ref = x_padded + y + ( + output, + output_offsets, + ) = jagged_dense_elementwise_add_jagged_output(x_values, x_offsets, y) + output = to_padded_dense(output, output_offsets, max_lengths) + + assert output.size() == output_ref.size() + + class ReorderBatchedTest(TestCase): + @given( + B=st.integers(min_value=1, max_value=20), + T=st.integers(min_value=1, max_value=20), + L=st.integers(min_value=2, max_value=20), + A=st.integers(min_value=1, max_value=20), + Dtype=st.sampled_from([torch.int32, torch.float, torch.int64]), + broadcast_lengths=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_reorder_batched_ad_lengths( + 
self, + B: int, + T: int, + L: int, + A: int, + Dtype: torch.dtype, + broadcast_lengths: bool, + ) -> None: + if broadcast_lengths: + cat_ad_lengths = ( + torch.cat( + [torch.tensor([L for _ in range(T)]) for _ in range(B)], 0 + ) + .xpu() + .to(Dtype) + ) + cat_ad_lengths_broadcasted = cat_ad_lengths.tile([A]) + else: + cat_ad_lengths = ( + torch.cat( + [torch.tensor([L for _ in range(T * A)]) for _ in range(B)], 0 + ) + .xpu() + .to(Dtype) + ) + cat_ad_lengths_broadcasted = cat_ad_lengths + batch_offsets = torch.tensor([A * b for b in range(B + 1)]).int().xpu() + num_ads_in_batch = B * A + reordered_batched_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths( + cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_lengths + ) + torch.testing.assert_close( + cat_ad_lengths_broadcasted, reordered_batched_ad_lengths + ) + + @given( + B=st.integers(min_value=1, max_value=20), + T=st.integers(min_value=1, max_value=20), + L=st.integers(min_value=2, max_value=20), + A=st.integers(min_value=1, max_value=20), + Dtype=st.sampled_from( + [torch.int32, torch.float, torch.int64, torch.bfloat16] + ), + Itype=st.sampled_from([torch.int32, torch.int64]), + broadcast_indices=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_reorder_batched_ad_indices( + self, + B: int, + T: int, + L: int, + A: int, + Dtype: torch.dtype, + Itype: torch.dtype, + broadcast_indices: bool, + ) -> None: + if broadcast_indices: + cat_ad_indices = ( + torch.randint( + low=0, + high=100, + size=(B * T * L,), + ) + .int() + .xpu() + .to(Dtype) + ) + cat_ad_lengths = ( + torch.cat( + [torch.tensor([L for _ in range(T)]) for _ in range(B)], + 0, + ) + .int() + .xpu() + ) + cat_ad_lengths_broadcasted = cat_ad_lengths.tile([A]) + else: + cat_ad_indices = ( + torch.randint( + low=0, + high=100, + size=(B * T * A * L,), + ) + .int() + .xpu() + .to(Dtype) + ) + cat_ad_lengths = ( + torch.cat( + [torch.tensor([L for _ in range(T * A)]) for _ in range(B)], + 0, + ) + .int() + .xpu() + ) + cat_ad_lengths_broadcasted = cat_ad_lengths + batch_offsets = torch.tensor([A * b for b in range(B + 1)]).int().xpu() + num_ads_in_batch = B * A + reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths( + cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices + ) + torch.testing.assert_close( + cat_ad_lengths_broadcasted, reordered_cat_ad_lengths + ) + + cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + cat_ad_lengths + ).to(Itype) + reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + reordered_cat_ad_lengths + ).to(Itype) + reordered_cat_ad_indices = torch.ops.fbgemm.reorder_batched_ad_indices( + cat_ad_offsets, + cat_ad_indices, + reordered_cat_ad_offsets, + batch_offsets, + num_ads_in_batch, + broadcast_indices, + B * T * A * L, + ) + + torch.testing.assert_close( + reordered_cat_ad_indices.view(T, B, A, L).permute(1, 0, 2, 3), + ( + cat_ad_indices.view(B, T, 1, L).tile([1, 1, A, 1]) + if broadcast_indices + else cat_ad_indices.view(B, T, A, L) + ), + ) + + class Permute2DSparseFeaturesTest(TestCase): + @given( + B=st.integers(min_value=0, max_value=20), + T=st.integers(min_value=0, max_value=20), + L=st.integers(min_value=2, max_value=20), + long_index=st.booleans(), + has_weight=st.booleans(), + W=st.integers(min_value=4, max_value=8), + ) + def test_permute_indices( + self, + B: int, + T: int, + L: int, + long_index: bool, + has_weight: bool, + W: int, + ) -> None: + index_dtype = torch.int64 if long_index else 
torch.int32 + length_splits: Optional[List[torch.Tensor]] = None + lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype) + + # pyre-fixme[6]: For 1st param expected `Union[List[int], Size, + # typing.Tuple[int, ...]]` but got `Union[bool, float, int]`. + weights = torch.rand(lengths.sum().item()).float() if has_weight else None + indices = torch.randint( + low=1, + high=int(1e5), + # pyre-fixme[6]: Expected `Union[int, typing.Tuple[int, ...]]` for 3rd + # param but got `Tuple[typing.Union[float, int]]`. + size=(lengths.sum().item(),), + ).type(index_dtype) + + permute_list = list(range(T)) + random.shuffle(permute_list) + + permute = torch.IntTensor(permute_list) + ( + permuted_lengths_ref, + permuted_indices_ref, + permuted_weights_ref, + # pyre-fixme[6]: For 4th param expected `LongTensor` but got `Tensor`. + ) = permute_indices_ref_(lengths, indices, weights, permute.long()) + ( + permuted_lengths_xpu, + permuted_indices_xpu, + permuted_weights_xpu, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + permute.xpu(), + lengths.xpu(), + indices.xpu(), + weights.xpu() if has_weight else None, + None, + ) + if has_weight: + torch.testing.assert_close( + permuted_weights_xpu.cpu(), permuted_weights_ref + ) + else: + assert permuted_weights_xpu is None and permuted_weights_ref is None + + torch.testing.assert_close( + permuted_indices_xpu.cpu(), permuted_indices_ref + ) + torch.testing.assert_close( + permuted_lengths_xpu.cpu(), permuted_lengths_ref + ) + + @given( + B=st.integers(min_value=2, max_value=20), + T=st.integers(min_value=2, max_value=20), + L=st.integers(min_value=2, max_value=20), + long_index=st.booleans(), + ) + def test_permute_indices_non_contiguous( + self, + B: int, + T: int, + L: int, + long_index: bool, + ) -> None: + index_dtype = torch.int64 if long_index else torch.int32 + lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype) + + indices = torch.randint( + low=1, + high=int(1e5), + # pyre-fixme[6]: Expected `Union[int, typing.Tuple[int, ...]]` for 3rd + # param but got `Tuple[typing.Union[float, int]]`. + size=(lengths.sum().item(),), + ).type(index_dtype) + + permute_list = list(range(T)) + random.shuffle(permute_list) + permute = torch.IntTensor(permute_list) + + def create_non_contiguous(x: torch.Tensor) -> torch.Tensor: + # Create a diluted tensor with 2x elements, and then take every other element + # with the value from the original tensor. For example, if x = [1, 2, 3, 4], + # then the diluted tensor is [1, 0, 2, 0, 3, 0, 4, 0]. + diluted = x.new_zeros(x.numel() * 2).flatten() + diluted[::2] = x.flatten() + # Returns the sliced tensor, which is non-contiguous. + return diluted[::2].view(x.shape) + + ( + permuted_lengths_ref, + permuted_indices_ref, + permuted_weights_ref, + # pyre-fixme[6]: For 4th param expected `LongTensor` but got `Tensor`.
+ ) = permute_indices_ref_(lengths, indices, None, permute.long()) + + permute_xpu = create_non_contiguous(permute.xpu()) + lengths_xpu = create_non_contiguous(lengths.xpu()) + indices_xpu = create_non_contiguous(indices.xpu()) + self.assertFalse(permute_xpu.is_contiguous()) + self.assertFalse(lengths_xpu.is_contiguous()) + self.assertFalse(indices_xpu.is_contiguous()) + + ( + permuted_lengths_xpu, + permuted_indices_xpu, + permuted_weights_xpu, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + permute_xpu, + lengths_xpu, + indices_xpu, + None, + None, + ) + torch.testing.assert_close(permuted_indices_xpu.cpu(), permuted_indices_ref) + torch.testing.assert_close(permuted_lengths_xpu.cpu(), permuted_lengths_ref) + self.assertIsNone(permuted_weights_xpu) + + def test_permute_indices_scripted_with_none_weights( + self, + ) -> None: + index_dtype = torch.int32 + lengths = torch.randint(low=1, high=2, size=(1, 1)).type(index_dtype) + weights = None + indices = torch.randint( + low=1, + high=int(1e5), + # pyre-fixme[6]: Expected `Union[int, typing.Tuple[int, ...]]` for 3rd + # param but got `Tuple[typing.Union[float, int]]`. + size=(lengths.sum().item(),), + ).type(index_dtype) + permute_list = list(range(1)) + random.shuffle(permute_list) + + permute = torch.IntTensor(permute_list) + + ( + permuted_lengths_xpu, + permuted_indices_xpu, + permuted_weights_xpu, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + permute.xpu(), lengths.xpu(), indices.xpu(), None, None + ) + ( + permuted_lengths_ref, + permuted_indices_ref, + permuted_weights_ref, + # pyre-fixme[6]: For 4th param expected `LongTensor` but got `Tensor`. + ) = permute_indices_ref_(lengths, indices, weights, permute.long()) + self.assertTrue( + torch.equal(permuted_indices_xpu.cpu(), permuted_indices_ref) + ) + self.assertTrue( + torch.equal(permuted_lengths_xpu.cpu(), permuted_lengths_ref) + ) + self.assertEqual(permuted_weights_xpu, None) + self.assertEqual(permuted_weights_ref, None) + + @given( + B=st.integers(min_value=1, max_value=20), + T=st.integers(min_value=1, max_value=20), + L=st.integers(min_value=2, max_value=20), + long_index=st.booleans(), + has_weight=st.booleans(), + ) + def test_permute_indices_with_repeats( + self, B: int, T: int, L: int, long_index: bool, has_weight: bool + ) -> None: + index_dtype = torch.int64 if long_index else torch.int32 + lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype) + # pyre-fixme[6]: For 1st param expected `Union[List[int], Size, + # typing.Tuple[int, ...]]` but got `Union[bool, float, int]`. + weights = torch.rand(lengths.sum().item()).float() if has_weight else None + indices = torch.randint( + low=1, + high=int(1e5), + # pyre-fixme[6]: Expected `Union[int, typing.Tuple[int, ...]]` for 3rd + # param but got `Tuple[typing.Union[float, int]]`. + size=(lengths.sum().item(),), + ).type(index_dtype) + permute_list = list(range(T)) + + num_repeats = random.randint(0, T) + for _ in range(num_repeats): + permute_list.append(random.randint(0, T - 1)) + + random.shuffle(permute_list) + permute = torch.IntTensor(permute_list) + + ( + permuted_lengths_ref, + permuted_indices_ref, + permuted_weights_ref, + # pyre-fixme[6]: For 4th param expected `LongTensor` but got `Tensor`. + ) = permute_indices_ref_(lengths, indices, weights, permute.long()) + + ( + permuted_lengths_xpu, + permuted_indices_xpu, + permuted_weights_xpu, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + permute.xpu(), + lengths.xpu(), + indices.xpu(), + # pyre-fixme[16]: `Optional` has no attribute `cuda`. 
+ weights.xpu() if has_weight else None, + ) + torch.testing.assert_close(permuted_indices_xpu.cpu(), permuted_indices_ref) + torch.testing.assert_close(permuted_lengths_xpu.cpu(), permuted_lengths_ref) + if has_weight: + torch.testing.assert_close( + permuted_weights_xpu.cpu(), permuted_weights_ref + ) + else: + assert permuted_weights_xpu is None + + def test_permute_2D_sparse_data(self) -> None: + lengths = torch.tensor( + [[0, 0, 1], [0, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 1]], + dtype=torch.int32, + device="xpu", + ) + indices = torch.tensor( + [500, 1000, 1999], + dtype=torch.int32, + device="xpu", + ) + permute = torch.tensor( + [0, 3, 1, 4, 2, 5], + dtype=torch.int32, + device="xpu", + ) + weights = torch.rand((3, 64), device="xpu") + ( + lengths_actual, + values_actual, + weights_actual, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + permute, lengths, indices, weights, indices.numel() + ) + self.assertTrue( + torch.equal( + lengths_actual, + torch.tensor( + [ + [0, 0, 1], + [0, 0, 0], + [0, 1, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 1], + ], + dtype=torch.int32, + device="xpu", + ), + ) + ) + self.assertTrue(torch.equal(values_actual, indices)) + self.assertTrue(torch.equal(weights_actual, weights)) + + +instantiate_device_type_tests(CumSumTest, globals(), only_for="xpu", allow_xpu=True) + +instantiate_device_type_tests( + DenseToJaggedTest, globals(), only_for="xpu", allow_xpu=True +) + +instantiate_device_type_tests( + JaggedToPaddedDenseTest, globals(), only_for="xpu", allow_xpu=True +) + +instantiate_device_type_tests( + ElementwiseBinaryTest, globals(), only_for="xpu", allow_xpu=True +) + +instantiate_device_type_tests( + ReorderBatchedTest, globals(), only_for="xpu", allow_xpu=True +) + +instantiate_device_type_tests( + Permute2DSparseFeaturesTest, globals(), only_for="xpu", allow_xpu=True +) + +if __name__ == "__main__": + run_tests()
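
A minimal usage sketch of the ops exercised by these tests, for orientation only. It assumes a PyTorch build with XPU support plus this patch, and that fbgemm_gpu is installed so the fbgemm::* schemas are already registered (as noted in the comment at the top of test_fbgemm_ops_xpu.py); the tensor values below are illustrative, not from the patch.

import fbgemm_gpu  # registers the fbgemm::* operator schemas
import torch

lengths = torch.tensor([2, 0, 3], dtype=torch.int64, device="xpu")
# Complete cumulative sum with a leading zero: tensor([0, 2, 2, 5])
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)

values = torch.randn(int(offsets[-1].item()), 4, device="xpu")
# Pad the 3 jagged rows to length 3 each, filling gaps with 0.0 -> shape (3, 3, 4)
dense = torch.ops.fbgemm.jagged_to_padded_dense(values, [offsets], [3], 0.0)
# Round-trip back to the jagged layout
jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged(dense, [offsets])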