Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
const bool enable_optimizer_offloading,
{%- endif %}
{%- if is_index_select %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const bool permute_output_dim_0_1
{%- else %}
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
Expand Down Expand Up @@ -436,7 +436,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
const bool enable_optimizer_offloading,
{%- endif %}
{%- if is_index_select %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const bool permute_output_dim_0_1
{%- else %}
{{ args.split_kernel_args_no_defaults |
Expand Down Expand Up @@ -606,7 +606,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
const int32_t max_D,
const int32_t max_vecs_per_thread,
{%- if is_index_select %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const bool permute_output_dim_0_1
{%- else %}
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
Expand Down Expand Up @@ -771,7 +771,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
const int32_t max_D,
const int32_t max_vecs_per_thread,
{%- if is_index_select %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const bool permute_output_dim_0_1
{%- else %}
{{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
const bool enable_optimizer_offloading,
{%- endif %}
{%- if is_index_select %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const bool permute_output_dim_0_1
{%- else %}
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
Expand Down Expand Up @@ -292,7 +292,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
const int32_t max_D,
const int32_t max_vecs_per_thread,
{%- if is_index_select %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const bool permute_output_dim_0_1
{%- else %}
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
Expand Down Expand Up @@ -966,7 +966,6 @@ Tensor {{ embedding_cuda_op }}(
{%- endif %}

DISPATCH_OPTIMAL_KERNEL(max_D, [&] {

auto long_run_ids = at::empty({indices.numel()}, sorted_linear_indices_run_lengths.options());
auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt));

Expand All @@ -982,7 +981,6 @@ Tensor {{ embedding_cuda_op }}(
at::empty({indices.numel()}, sorted_linear_indices_run_lengths.options());
}


auto num_really_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt));
auto grad_accum_counter = at::empty(
use_deterministic_algorithms ? 0 : (indices.numel() / max_segment_length_per_cta),
Expand Down Expand Up @@ -1292,7 +1290,7 @@ Tensor {{ embedding_cuda_op }}(
enable_optimizer_offloading,
{%- endif %}
{%- if is_index_select %}
grad_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
PTA_B(grad_offsets, int64_t, 1, 32),
permute_output_dim_0_1
{%- else %}
{{ args.split_kernel_arg_constructors | make_pta_acc_builder_format() | join(",\n ") }}
Expand Down
Loading