
Commit 12ae3be

nikitaved authored and pytorchmergebot committed
Faster mul(sparse, sparse) with broadcasting in dense dims. (pytorch#85336)
This is a combo PR of pytorch#84929 and ~pytorch#83428.

Preliminary benchmarks (square matrices of shape (n, n)).

<details>
<summary>Script</summary>

```python
import torch
import math
from IPython import get_ipython
from itertools import product, repeat
import pickle
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(13)

problem_dims = (
    # n > nnz
    (10000, 100),
    (100000, 1000),
    (1000000, 10000),
    # n < nnz
    (10, 100),
    (10, 1000),
    (10, 10000),
    (100, 1000),
    (100, 10000),
    (1000, 10000),
    (1000, 100000),
    (1000, 1000000),
    #(1000000, 1000000000),
)

name = "PR"
device = "cuda"
results = []

for n, nnz in problem_dims:
    def gen_tensor(coalesce=False):
        shape = (n, n)
        nrows, ncols = shape
        rowidx = torch.randint(low=0, high=nrows, size=(nnz,), device=device)
        colidx = torch.randint(low=0, high=ncols, size=(nnz,), device=device)
        itemidx = torch.vstack((rowidx, colidx))
        xvalues = torch.randn(nnz, device=device)
        itemidx = torch.hstack((itemidx, itemidx))
        xvalues = torch.hstack((xvalues, xvalues))
        res = torch.sparse_coo_tensor(itemidx, xvalues, size=shape)
        if coalesce:
            return res.coalesce()
        else:
            return res

    for x_coalesce, y_coalesce in product(*repeat((True, False), 2)):
        x = gen_tensor(x_coalesce)
        y = gen_tensor(y_coalesce)

        smtp = "x * y"
        timer = Timer(smtp,
                      globals=globals(),
                      label="coo.mul",
                      description=f"{name}: mul, device: {device}",
                      sub_label=f"n={n}, nnz={nnz}, coalesce=({x_coalesce, y_coalesce})",
                      num_threads=torch.get_num_threads())
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{name}_{device}_mul.pickle", 'wb') as f:
    pickle.dump(results, f)
```
</details>

<details>
<summary>Gather results</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
    "PR",
    "master"
]

device = 'cuda'

timers = []
for name in files:
    with open("{}_{}_mul.pickle".format(name, device), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

<details>
<summary>CUDA</summary>

```
[------------------------------------------------- coo.mul -------------------------------------------------]
                                                      |  PR: mul, device: cuda  |  master: mul, device: cuda
24 threads: --------------------------------------------------------------------------------------------------
      n=10000, nnz=100, coalesce=((True, True))       |            95           |              91
      n=10000, nnz=100, coalesce=((True, False))      |            87           |             242
      n=10000, nnz=100, coalesce=((False, True))      |            87           |             226
      n=10000, nnz=100, coalesce=((False, False))     |           130           |             371
      n=100000, nnz=1000, coalesce=((True, True))     |           100           |             521
      n=100000, nnz=1000, coalesce=((True, False))    |            90           |             649
      n=100000, nnz=1000, coalesce=((False, True))    |           100           |             659
      n=100000, nnz=1000, coalesce=((False, False))   |           200           |             781
      n=1000000, nnz=10000, coalesce=((True, True))   |           100           |            4861
      n=1000000, nnz=10000, coalesce=((True, False))  |           100           |            5012
      n=1000000, nnz=10000, coalesce=((False, True))  |            98           |            5010
      n=1000000, nnz=10000, coalesce=((False, False)) |           384           |            5174
      n=10, nnz=100, coalesce=((True, True))          |           100           |              79
      n=10, nnz=100, coalesce=((True, False))         |           100           |             221
      n=10, nnz=100, coalesce=((False, True))         |           100           |             221
      n=10, nnz=100, coalesce=((False, False))        |           100           |             350
      n=10, nnz=1000, coalesce=((True, True))         |           100           |             100
      n=10, nnz=1000, coalesce=((True, False))        |           100           |             240
      n=10, nnz=1000, coalesce=((False, True))        |           100           |             254
      n=10, nnz=1000, coalesce=((False, False))       |           100           |             392
      n=10, nnz=10000, coalesce=((True, True))        |           100           |             110
      n=10, nnz=10000, coalesce=((True, False))       |           110           |             286
      n=10, nnz=10000, coalesce=((False, True))       |           110           |             286
      n=10, nnz=10000, coalesce=((False, False))      |           271           |             455
      n=100, nnz=1000, coalesce=((True, True))        |           110           |             851
      n=100, nnz=1000, coalesce=((True, False))       |           110           |            1000
      n=100, nnz=1000, coalesce=((False, True))       |           110           |             990
      n=100, nnz=1000, coalesce=((False, False))      |           140           |            1124
      n=100, nnz=10000, coalesce=((True, True))       |           110           |            5137
      n=100, nnz=10000, coalesce=((True, False))      |           110           |            5391
      n=100, nnz=10000, coalesce=((False, True))      |           100           |            5405
      n=100, nnz=10000, coalesce=((False, False))     |           249           |            5539
      n=1000, nnz=10000, coalesce=((True, True))      |           100           |            8598
      n=1000, nnz=10000, coalesce=((True, False))     |           100           |            8800
      n=1000, nnz=10000, coalesce=((False, True))     |           100           |            8782
      n=1000, nnz=10000, coalesce=((False, False))    |           255           |            8956
      n=1000, nnz=100000, coalesce=((True, True))     |           120           |           84500
      n=1000, nnz=100000, coalesce=((True, False))    |           200           |           88560
      n=1000, nnz=100000, coalesce=((False, True))    |           160           |           89000
      n=1000, nnz=100000, coalesce=((False, False))   |           373           |           89000
      n=1000, nnz=1000000, coalesce=((True, True))    |           312           |          606400
      n=1000, nnz=1000000, coalesce=((True, False))   |          1340           |          609200
      n=1000, nnz=1000000, coalesce=((False, True))   |          1340           |          609100
      n=1000, nnz=1000000, coalesce=((False, False))  |          4408           |          611400

Times are in microseconds (us).
```
</details>

<details>
<summary>CPU</summary>

```
[------------------------------------------------ coo.mul ------------------------------------------------]
                                                      |  PR: mul, device: cpu  |  master: mul, device: cpu
24 threads: ------------------------------------------------------------------------------------------------
      n=10000, nnz=100, coalesce=((True, True))       |            8           |             8
      n=10000, nnz=100, coalesce=((True, False))      |           32           |            34
      n=10000, nnz=100, coalesce=((False, True))      |           32           |            34
      n=10000, nnz=100, coalesce=((False, False))     |           41           |            56
      n=100000, nnz=1000, coalesce=((True, True))     |           24           |            24
      n=100000, nnz=1000, coalesce=((True, False))    |           90           |           100
      n=100000, nnz=1000, coalesce=((False, True))    |           87           |           100
      n=100000, nnz=1000, coalesce=((False, False))   |          231           |           255
      n=1000000, nnz=10000, coalesce=((True, True))   |          190           |           200
      n=1000000, nnz=10000, coalesce=((True, False))  |          908           |          2023
      n=1000000, nnz=10000, coalesce=((False, True))  |          800           |          2036
      n=1000000, nnz=10000, coalesce=((False, False)) |         3684           |          3989
      n=10, nnz=100, coalesce=((True, True))          |            8           |             7
      n=10, nnz=100, coalesce=((True, False))         |           34           |            30
      n=10, nnz=100, coalesce=((False, True))         |           33           |            30
      n=10, nnz=100, coalesce=((False, False))        |           44           |            50
      n=10, nnz=1000, coalesce=((True, True))         |            8           |             7
      n=10, nnz=1000, coalesce=((True, False))        |          100           |           100
      n=10, nnz=1000, coalesce=((False, True))        |          130           |           100
      n=10, nnz=1000, coalesce=((False, False))       |          746           |           210
      n=10, nnz=10000, coalesce=((True, True))        |            8           |             7
      n=10, nnz=10000, coalesce=((True, False))       |         1000           |          1500
      n=10, nnz=10000, coalesce=((False, True))       |         1000           |          1510
      n=10, nnz=10000, coalesce=((False, False))      |         3063           |          2457
      n=100, nnz=1000, coalesce=((True, True))        |           25           |            25
      n=100, nnz=1000, coalesce=((True, False))       |          180           |           130
      n=100, nnz=1000, coalesce=((False, True))       |          200           |           130
      n=100, nnz=1000, coalesce=((False, False))      |          271           |           255
      n=100, nnz=10000, coalesce=((True, True))       |          100           |           100
      n=100, nnz=10000, coalesce=((True, False))      |         2444           |          2290
      n=100, nnz=10000, coalesce=((False, True))      |         2455           |          2357
      n=100, nnz=10000, coalesce=((False, False))     |         5316           |          3783
      n=1000, nnz=10000, coalesce=((True, True))      |          204           |           211
      n=1000, nnz=10000, coalesce=((True, False))     |         2457           |          2480
      n=1000, nnz=10000, coalesce=((False, True))     |         2448           |          2539
      n=1000, nnz=10000, coalesce=((False, False))    |         3665           |          4801
      n=1000, nnz=100000, coalesce=((True, True))     |         2293           |          2374
      n=1000, nnz=100000, coalesce=((True, False))    |         9000           |         24620
      n=1000, nnz=100000, coalesce=((False, True))    |         8000           |         25080
      n=1000, nnz=100000, coalesce=((False, False))   |        26500           |         47650
      n=1000, nnz=1000000, coalesce=((True, True))    |        10000           |         13000
      n=1000, nnz=1000000, coalesce=((True, False))   |        80000           |        362200
      n=1000, nnz=1000000, coalesce=((False, True))   |        78050           |        392600
      n=1000, nnz=1000000, coalesce=((False, False))  |       312100           |        766900

Times are in microseconds (us).
```
</details>

Pull Request resolved: pytorch#85336
Approved by: https://github.com/cpuhrsch
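For readers outside the PR, here is a minimal usage sketch of the behavior named in the title: two COO tensors whose sparse shapes match but whose dense dimensions broadcast against each other. The shapes and the final consistency check are illustrative assumptions, not code from this PR.

```python
import torch

# Hypothetical example: same sparse shape (3, 3), dense dims of size 4 vs. 1,
# so the per-entry values broadcast in the dense dimension during x * y.
i = torch.tensor([[0, 1, 2], [2, 0, 1]])
x = torch.sparse_coo_tensor(i, torch.randn(3, 4), size=(3, 3, 4))
y = torch.sparse_coo_tensor(i, torch.randn(3, 1), size=(3, 3, 1))

z = x * y          # mul(sparse, sparse) with broadcasting in dense dims
print(z.shape)     # expected: torch.Size([3, 3, 4])

# Sanity check against the dense equivalent (assumes matching semantics).
torch.testing.assert_close(z.to_dense(), x.to_dense() * y.to_dense())
```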
1 parent 40d3e55 commit 12ae3be

10 files changed: +196 -173 lines
New file (+37 lines), the CUDA launcher and dispatch registration for the sparse-sparse multiplication stub:

```cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
#include <ATen/native/cuda/Loops.cuh>

namespace at {
namespace native {

namespace {

template <typename func_t>
struct CUDAKernelLauncher {
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    gpu_kernel(iter, f);
  }
};

struct MulOp {
  static Tensor apply(const Tensor& a, const Tensor& b) {
    return a.mul(b);
  }
};

void mul_sparse_sparse_out_cuda_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y) {
  _sparse_binary_op_intersection_kernel_out<CUDAKernelLauncher, MulOp>(
      result, x, y
  );
}

}

REGISTER_CUDA_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cuda_kernel);

}}
```

aten/src/ATen/native/sparse/Macros.h (+4, -1)

```diff
@@ -10,7 +10,10 @@
 #endif
 
 #if defined(_WIN32) || defined(_WIN64)
-#define RESTRICT __restrict
+// Temporarily disable __restrict on Windows,
+// as it turns out not all MSVC versions are aware of it.
+// #define RESTRICT __restrict
+#define RESTRICT
 #else
 #define RESTRICT __restrict__
 #endif
```
New file (+42 lines), the CPU counterpart registered for all SIMD dispatch variants:

```cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
#include <ATen/native/cpu/Loops.h>

namespace at {
namespace native {

namespace {

template <typename func_t>
struct CPUKernelLauncher {
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    cpu_kernel(iter, f);
  }
};


struct MulOp {
  static Tensor apply(const Tensor& a, const Tensor& b) {
    return a.mul(b);
  }
};

void mul_sparse_sparse_out_cpu_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y) {
  _sparse_binary_op_intersection_kernel_out<CPUKernelLauncher, MulOp>(
      result, x, y
  );
}

}

REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);

}}
```
New file (+16 lines), the dispatch stub declaration shared by the CPU and CUDA kernels:

```cpp
#pragma once

#include <ATen/native/DispatchStub.h>

namespace at {

class Tensor;

namespace native {

using mul_sparse_sparse_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y);
DECLARE_DISPATCH(mul_sparse_sparse_out_fn, mul_sparse_sparse_out_stub);

}

}
```

aten/src/ATen/native/sparse/SparseTensorMath.cpp (+30, -1)

```diff
@@ -7,6 +7,7 @@
 #include <c10/util/MaybeOwned.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
+#include <ATen/native/sparse/SparseStubs.h>
 #include <ATen/Parallel.h>
 #include <ATen/SparseTensorImpl.h>
 #include <ATen/ExpandUtils.h>
@@ -1087,6 +1088,13 @@ Tensor& _mul_sparse_sparse_zero_dim_out(const Tensor& zero_dim, const Tensor& ot
   return _mul_dense_sparse_out(scalar_val, other, r);
 }
 
+DEFINE_DISPATCH(mul_sparse_sparse_out_stub);
+
+Tensor& _mul_sparse_sparse_out(const Tensor& x, const Tensor& y, Tensor& res) {
+  mul_sparse_sparse_out_stub(res.device().type(), res, x, y);
+  return res;
+}
+
 SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, Tensor& r) {
   AT_ASSERT(!t_.is_cuda()); // dispatch argument
   TORCH_CHECK(!r.is_cuda(), "mul: expected 'out' to be CPU tensor, but got CUDA tensor");
@@ -1109,14 +1117,35 @@ SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, Tensor& r
     return _mul_sparse_sparse_zero_dim_out(t_, src_, r);
   }
 
-  TORCH_CHECK(t_.sizes().equals(src_.sizes()), "mul: expected 'self' and 'other' to have same sizes when both are sparse"
+  const auto is_equal_size_inputs = t_.sizes().equals(src_.sizes());
+
+  // mul(sparse, sparse) with inputs which broadcast only in dense dims
+  if (!is_equal_size_inputs) {
+    _mul_sparse_sparse_out(t_, src_, r);
+    return r;
+  }
+
+  TORCH_CHECK(is_equal_size_inputs, "mul: expected 'self' and 'other' to have same sizes when both are sparse"
       ", but ", t_.sizes(), " != ", src_.sizes());
 
+  // Short circuit when there is zero nnz
+  // Not strictly necessary, but there are tests checking whether
+  // resize in mul fails if run on tensors coming from .data/.detach.
   if (!t_._nnz() || !src_._nnz()) {
     r.resize_as_(t_);
     return r.zero_();
   }
 
+  // _mul_sparse_sparse_out is faster for large inputs
+  // and when either of the inputs is uncoalesced.
+  if (!t_.is_coalesced() || !src_.is_coalesced()) {
+    _mul_sparse_sparse_out(t_, src_, r);
+    return r;
+  }
+
+  // Otherwise _mul_sparse_sparse_out might be slower
+  // than the brute-force solution below.
+
   SparseTensor t = t_.coalesce();
   SparseTensor src = src_.coalesce();
```
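The routing above sends unequal-size or uncoalesced operands to the new `_mul_sparse_sparse_out` stub and keeps the coalesce-then-merge path only for equal-size, already-coalesced inputs. A small Python check of the uncoalesced case, written for this commentary rather than taken from the PR (the helper and its shapes are illustrative assumptions):

```python
import torch

torch.manual_seed(13)
n, nnz = 1000, 10_000

def random_coo(n, nnz):
    # Unique flat positions, so there are no duplicate entries, but the
    # freshly constructed tensor is still reported as uncoalesced.
    flat = torch.randperm(n * n)[:nnz]
    idx = torch.stack((torch.div(flat, n, rounding_mode='floor'), flat % n))
    return torch.sparse_coo_tensor(idx, torch.randn(nnz), (n, n))

x, y = random_coo(n, nnz), random_coo(n, nnz)
assert not x.is_coalesced() and not y.is_coalesced()

# Per the routing above, uncoalesced operands take the _mul_sparse_sparse_out
# path; the result should still agree with the dense elementwise product.
torch.testing.assert_close((x * y).to_dense(), x.to_dense() * y.to_dense())
```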

aten/src/ATen/native/sparse/SparseTensorMath.h (+1)

```diff
@@ -8,5 +8,6 @@ TORCH_API sparse::SparseTensor& mul_out_sparse_scalar(sparse::SparseTensor& r, c
 TORCH_API sparse::SparseTensor& mul_out_sparse_zerodim(sparse::SparseTensor& r, const sparse::SparseTensor& t, const Tensor& value);
 TORCH_API sparse::SparseTensor& _mul_dense_sparse_out(const Tensor& d, const Tensor& s, Tensor& res);
 TORCH_API sparse::SparseTensor& _mul_sparse_sparse_zero_dim_out(const Tensor& zero_dim, const Tensor& other, Tensor& res);
+TORCH_API sparse::SparseTensor& _mul_sparse_sparse_out(const Tensor& x, const Tensor& y, Tensor& res);
 
 }}
```

aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh (-111)

Removed from the hunk `@@ -192,117 +192,6 @@ __global__ void indexSparseUnionKernel(` (the now-unused intersection kernels and a commented-out coalesce kernel); the surrounding context lines are unchanged:

```cpp
template <typename Op, typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void valueSparseIntersectionKernel(
    Op op,
    TensorInfo<indexT, IndexType> r_indices,
    TensorInfo<indexT, IndexType> t_indices,
    TensorInfo<indexT, IndexType> s_indices,
    TensorInfo<Real, IndexType> r_values,
    TensorInfo<Real, IndexType> t_values,
    TensorInfo<Real, IndexType> s_values,
    const IndexType t_nnz, const IndexType s_nnz) {
  IndexType t_indskip = t_indices.strides[0];
  IndexType s_indskip = s_indices.strides[0];
  int64_t match, d;
  int64_t nDimI = r_indices.sizes[0];
  IndexType valueSize = r_values.strides[0];
  // reset valueSize if a dense dimension is zero:
  for (d=0; d<r_values.dims; d++) {
    if (r_values.sizes[d] == 0) {
      valueSize = 0;
      break;
    }
  }
  IndexType r_i = 0, t_i = 0, s_i = 0;
  while (t_i < t_nnz && s_i < s_nnz) {
    match = 1;
    for (d = 0; d < nDimI; d++) {
      if (t_indices.data[d * t_indskip + t_i] < s_indices.data[d * s_indskip + s_i]) {
        t_i++;
        match = 0;
        break;
      }
      if (t_indices.data[d * t_indskip + t_i] > s_indices.data[d * s_indskip + s_i]) {
        s_i++;
        match = 0;
        break;
      }
    }
    if (!match) continue;
    applyOp3(op, valueSize, r_values, r_i++, t_values, t_i++, s_values, s_i++);
  }
}

// TODO find a way to parallelize this...
template <typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void indexSparseIntersectionKernel(
    TensorInfo<indexT, IndexType> r_indices,
    TensorInfo<indexT, IndexType> t_indices,
    TensorInfo<indexT, IndexType> s_indices,
    const IndexType t_nnz, const IndexType s_nnz, IndexType *resultNnz) {
  IndexType r_indskip = r_indices.strides[0];
  IndexType t_indskip = t_indices.strides[0];
  IndexType s_indskip = s_indices.strides[0];
  int64_t match, d;
  int64_t nDimI = r_indices.sizes[0];
  IndexType r_i = 0, t_i = 0, s_i = 0;
  while (t_i < t_nnz && s_i < s_nnz) {
    match = 1;
    for (d = 0; d < nDimI; d++) {
      if (t_indices.data[d * t_indskip + t_i] < s_indices.data[d * s_indskip + s_i]) {
        t_i++;
        match = 0;
        break;
      }
      if (t_indices.data[d * t_indskip + t_i] > s_indices.data[d * s_indskip + s_i]) {
        s_i++;
        match = 0;
        break;
      }
    }
    if (!match) continue;
    for (d = 0; d < nDimI; d++) {
      r_indices.data[d * r_indskip + r_i] = t_indices.data[d * t_indskip + t_i];
    }
    r_i++; t_i++; s_i++;
  }
  *resultNnz = r_i;
}

// template <typename Dtype, typename Acctype>
// __global__ void coalesceValuesKernel_gridStrided(
//   long *segment_offsets, long *value_indices,
//   Dtype *values, Dtype *newValues,
//   long nnz, long newNnz, long stride) {
//
//   long chunksPerSeg = THCCeilDiv(stride, (long) blockDim.x);
//   long numChunks = newNnz * chunksPerSeg;
//   long chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
//   long chunkStride = gridDim.x * blockDim.y;
//
//   for (long chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
//     long featureDim = (chunk % chunksPerSeg) * blockDim.x + threadIdx.x;
//     if (featureDim < stride) {
//       auto valFeat = values + featureDim;
//       long seg = chunk / chunksPerSeg;
//       auto begin = segment_offsets[seg];
//       auto end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz;
//       Acctype valSum = static_cast<Acctype>::to(0);
//       for (long valIdx = begin; valIdx < end; valIdx++) {
//         const long valRow = value_indices[valIdx] * stride;
//         valSum += static_cast<Acctype>::to(valFeat[valRow]);
//       }
//       newValues[seg * stride + featureDim] = static_cast<Dtype>::to(valSum);
//     }
//   }
// }
```
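For orientation, the two deleted kernels implement a sequential two-pointer merge over the sorted indices of two coalesced operands; the index pass is even launched as `<<<1, 1>>>` in the SparseCUDATensorMath.cu hunk below and carries a "TODO find a way to parallelize this..." note. A rough Python sketch of that merge, with purely illustrative names and assuming flattened, sorted index lists (not code from the repository):

```python
def intersect_sorted(t_idx, t_val, s_idx, s_val):
    """Two-pointer intersection of two sorted COO index lists, multiplying
    values where the indices match (illustrative sketch only)."""
    out_idx, out_val = [], []
    t_i = s_i = 0
    while t_i < len(t_idx) and s_i < len(s_idx):
        if t_idx[t_i] < s_idx[s_i]:
            t_i += 1            # advance the side with the smaller index
        elif t_idx[t_i] > s_idx[s_i]:
            s_i += 1
        else:                   # matching index: emit the product
            out_idx.append(t_idx[t_i])
            out_val.append(t_val[t_i] * s_val[s_i])
            t_i += 1
            s_i += 1
    return out_idx, out_val
```

Each step depends on the previous comparison, so the loop does not split across threads easily, which is the motivation for replacing it with the stub-based intersection kernel added above.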

aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu (+7, -56)

```diff
@@ -482,66 +482,17 @@ SparseTensor& mul_out_sparse_cuda(const Tensor& t_, const Tensor& src_, SparseTe
   TORCH_CHECK(t_.is_cuda(), "mul: expected 'self' to be CUDA, but got CPU");
   TORCH_CHECK(src_.is_cuda(), "mul: expected 'other' to be CUDA, but got CPU");
   TORCH_CHECK(cuda::check_device({r_, t_, src_}));
-  TORCH_CHECK(t_.sizes().equals(src_.sizes()), "mul: expected 'self' and 'other' to have same size, but ", t_.sizes(), " != ", src_.sizes());
 
-  SparseTensor t = t_.coalesce();
-  SparseTensor src = src_.coalesce();
+  // mul(sparse, sparse)
 
-  if (src_._nnz() == 0 || t_._nnz() == 0) {
-    r_.resize_as_(src_);
+  // Short circuit when there is zero nnz.
+  // Not strictly necessary, but there are tests checking whether
+  // resize in mul fails if run on tensors coming from .data/.detach.
+  if (t_.sizes().equals(src_.sizes()) && (!t_._nnz() || !src_._nnz())) {
+    r_.resize_as_(t_);
     return r_.zero_();
   }
-
-  // saving those because they can be overwritten when doing in-place operations
-  int64_t t_nnz = t._nnz(), s_nnz = src._nnz();
-  int64_t max_nnz = std::min(t_nnz, s_nnz); // multiply by zero is zero, and can be dropped
-  int64_t sparse_dim = src.sparse_dim();
-  auto commonDtype = at::result_type(t, src);
-  TORCH_CHECK(canCast(commonDtype, r_.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r_.scalar_type());
-  Tensor t_indices_ = t._indices().contiguous();
-  Tensor t_values_ = t._values().to(commonDtype);
-  Tensor s_indices_ = src._indices().contiguous();
-  Tensor s_values_ = src._values().to(commonDtype);
-  Tensor r_indices_ = at::empty({sparse_dim, max_nnz}, t_indices_.options());
-  r_.resize_as_(src);
-
-  Tensor r_values_ = new_values_with_size_of(t_values_, max_nnz).zero_();
-
-  int64_t valueSize = std::max<int64_t>(1, t_values_.stride(0));
-  const dim3 block = dim3(std::min(static_cast<int64_t>(cuda::getApplyBlock().x), valueSize));
-  dim3 grid;
-  int curDevice = -1;
-  cudaGetDevice(&curDevice);
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
-  TORCH_CHECK(cuda::getApplyGrid(valueSize, grid, curDevice), "mul: Argument #0: tensor too large or too many dimensions");
-
-  Tensor resultNnz = at::empty({1}, CUDA(kLong));
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
-      at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "mul_out_sparse_cuda", [&] {
-        apply::valueSparseIntersectionKernel<<<grid, block, 0, stream>>>(
-            TensorMulOp<scalar_t>(),
-            I_INFO(r_indices_), I_INFO(t_indices_), I_INFO(s_indices_),
-            V_INFO(r_values_), V_INFO(t_values_), V_INFO(s_values_),
-            static_cast<uint64_t>(t_nnz), static_cast<uint64_t>(s_nnz));
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-
-        apply::indexSparseIntersectionKernel<uint64_t, scalar_t>
-          <<<1, 1, 0, stream>>>(
-            I_INFO(r_indices_), I_INFO(t_indices_), I_INFO(s_indices_),
-            // reinterpret_cast shenanigans, because we don't actually have
-            // unsigned tensors...
-            static_cast<uint64_t>(t_nnz), static_cast<uint64_t>(s_nnz), reinterpret_cast<uint64_t*>(resultNnz.data_ptr()));
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      });
-  r_values_ = r_values_.to(r_.scalar_type());
-  get_sparse_impl(r_)->set_indices_and_values_unsafe(r_indices_, r_values_);
-
-  // sync! (surely there is a more idiomatic way to do this...)
-  Tensor cpu_resultNnz = at::empty({1}, CPU(kLong));
-  cpu_resultNnz.copy_(resultNnz);
-  get_sparse_impl(r_)->set_nnz_and_narrow(cpu_resultNnz.accessor<int64_t, 1>()[0]);
-
-  return r_._coalesced_(true);
+  return _mul_sparse_sparse_out(t_, src_, r_);
 }
 
 // --------------------------------------------------------------------
```

build_variables.bzl (+1)

```diff
@@ -1412,6 +1412,7 @@ aten_native_source_non_codegen_list = [
     "aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp",
     "aten/src/ATen/native/sparse/SparseFactories.cpp",
     "aten/src/ATen/native/sparse/ValidateCompressedIndicesKernel.cpp",
+    "aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp",
     "aten/src/ATen/native/transformers/attention.cpp",
     "aten/src/ATen/native/transformers/transformer.cpp",
     "aten/src/ATen/native/xnnpack/Activation.cpp",
```
