Skip to content

Commit 02dc017

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent feb1b87 commit 02dc017

1 file changed

Lines changed: 16 additions & 14 deletions

File tree

transformer_engine/common/util/vectorized_pointwise.h

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ __launch_bounds__(unary_kernel_threads) __global__
228228
loader.load(tid, size);
229229
#pragma unroll
230230
for (int i = 0; i < nvec; ++i) {
231-
const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment()));
231+
const size_t global_idx =
232+
(aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment()));
232233
if (global_idx >= size) continue;
233234

234235
ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
@@ -332,7 +333,8 @@ __launch_bounds__(unary_kernel_threads) __global__
332333
grad_loader.load(tid, size);
333334
#pragma unroll
334335
for (int i = 0; i < nvec; ++i) {
335-
const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment()));
336+
const size_t global_idx =
337+
(aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment()));
336338
if (global_idx >= size) continue;
337339

338340
ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
@@ -466,19 +468,19 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out
466468
switch (align) {
467469
case Alignment::SAME_ALIGNED:
468470
unary_kernel<nvec, true, fp32, Param, OP><<<grid, threads, 0, stream>>>(
469-
input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements,
470-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
471+
input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
472+
first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
471473
break;
472474
case Alignment::SAME_UNALIGNED:
473475
unary_kernel<nvec, false, fp32, Param, OP><<<grid, threads, 0, stream>>>(
474-
input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements,
475-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
476+
input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
477+
first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
476478
break;
477479
case Alignment::DIFFERENT: {
478480
// If the pointers are aligned differently we cannot vectorize
479481
unary_kernel<1, true, fp32, Param, OP><<<grid, threads, 0, stream>>>(
480-
input, noop, output, scale, amax, scale_inv, params, N, N,
481-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
482+
input, noop, output, scale, amax, scale_inv, params, N, N, offsets, first_dims,
483+
last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
482484
break;
483485
}
484486
}
@@ -508,19 +510,19 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp
508510
switch (align) {
509511
case Alignment::SAME_ALIGNED:
510512
unary_grad_kernel<nvec, true, fp32, Param, OP><<<grid, threads, 0, stream>>>(
511-
grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements,
512-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
513+
grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
514+
first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
513515
break;
514516
case Alignment::SAME_UNALIGNED:
515517
unary_grad_kernel<nvec, false, fp32, Param, OP><<<grid, threads, 0, stream>>>(
516-
grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements,
517-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
518+
grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
519+
first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
518520
break;
519521
case Alignment::DIFFERENT: {
520522
// If the pointers are aligned differently we cannot vectorize
521523
unary_grad_kernel<1, true, fp32, Param, OP><<<grid, threads, 0, stream>>>(
522-
grad, input, output, scale, amax, scale_inv, params, N, N,
523-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
524+
grad, input, output, scale, amax, scale_inv, params, N, N, offsets, first_dims,
525+
last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
524526
break;
525527
}
526528
}

0 commit comments

Comments
 (0)