Skip to content

Commit 80e0c3b

Browse files
authored
Merge pull request #512 from gordicaleksa/refactor_encoder_bwd_kernel
Remove redundant CPU computation in encoder bwd
2 parents cb5eff0 + 290c00a commit 80e0c3b

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

dev/cuda/encoder_backward.cu

+6-3
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,17 @@ int main(int argc, char **argv) {
163163
}
164164
printf("Using kernel %d\n", kernel_num);
165165

166-
// set up block sizes
166+
// first check the correctness of the kernel
167+
encoder_backward_cpu(dwte, dwpe, dout, inp, B, T, C);
168+
169+
// time the kernel at different block sizes
167170
int block_sizes[] = {32, 64, 128, 256, 512, 1024};
168171

169-
// first check the correctness of the kernel
170172
for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
171173
int block_size = block_sizes[j];
174+
cudaCheck(cudaMemset(d_dwte, 0, V * C * sizeof(float)));
175+
cudaCheck(cudaMemset(d_dwpe, 0, T * C * sizeof(float)));
172176
printf("Checking block size %d.\n", block_size);
173-
encoder_backward_cpu(dwte, dwpe, dout, inp, B, T, C);
174177
encoder_backward(kernel_num, d_dwte, d_dwpe, d_dout, d_inp, B, T, C, block_size);
175178
validate_result(d_dwte, dwte, "dwte", V * C, 1e-5f);
176179
validate_result(d_dwpe, dwpe, "dwpe", T * C, 1e-5f);

0 commit comments

Comments
 (0)