+ #include <ATen/core/Tensor.h>
+ #include <ATen/native/transformers/attention.h>
+ #include <ATen/native/transformers/sdp_utils_cpp.h>
+
+ #ifndef AT_PER_OPERATOR_HEADERS
+ #include <ATen/Functions.h>
+ #include <ATen/NativeFunctions.h>
+ #else
+ #include <ATen/ops/empty_like.h>
+ #include <ATen/ops/linear.h>
+ #include <ATen/ops/matmul.h>
+ #include <ATen/ops/scaled_dot_product_attention.h>
+ #include <ATen/ops/softmax.h>
+ #endif
+
+ #include <ATen/native/cutlass/Attention.h>
+ #include <ATen/native/cutlass/sycl/AttentionKernels.h>
+
+ #include <comm/SYCLContext.h>
+
+ #include <cmath>
+ #include <iostream>
+
+ namespace at {
+ namespace native {
+ namespace cutlass_sycl {
+
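+ // Backward pass for scaled dot-product attention: recomputes the attention
+ // probabilities with ATen ops and hands the raw pointers to the CUTLASS SYCL
+ // kernel, which produces the query, key, and value gradients.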
+ void sdpa_backward(
+     int batch_size,
+     int num_head_q,
+     int num_head_kv,
+     int seq_len_q,
+     int seq_len_kv,
+     int head_dim_qk,
+     int head_dim_v,
+     const Tensor& grad_out,
+     const Tensor& query,
+     const Tensor& key,
+     const Tensor& value,
+     const Tensor& out,
+     const Tensor& logsumexp,
+     std::optional<at::Tensor> attn_mask,
+     bool is_causal,
+     double scale,
+     Tensor& grad_query,
+     Tensor& grad_key,
+     Tensor& grad_value) {
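+   // out, logsumexp, attn_mask, and is_causal are part of the interface but
+   // are not used by this implementation.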
+
+   std::cout << " lfq: entering cutlass sdpa_backward" << std::endl;
+
+   // Recompute the attention probabilities P = softmax(Q K^T / sqrt(scale)).
+   auto ps = at::matmul(query, key.transpose(-2, -1));
+   ps = ps / std::sqrt(scale);
+   ps = at::softmax(ps, -1).to(query.dtype());
+   // Buffer for the gradient of the attention probabilities.
+   auto dps = at::empty_like(ps);
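+   // Launch the CUTLASS SYCL backward kernel; grad_query, grad_key,
+   // grad_value, and dps are written through the raw data pointers.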
+   cutlass_sdpa_backward(
+       batch_size, num_head_q, num_head_kv, seq_len_q, seq_len_kv,
+       head_dim_qk, head_dim_v,
+       grad_out.data_ptr(),
+       query.data_ptr(),
+       key.data_ptr(),
+       value.data_ptr(),
+       ps.data_ptr(),
+       nullptr,
+       grad_query.data_ptr(),
+       grad_key.data_ptr(),
+       grad_value.data_ptr(),
+       dps.data_ptr());
+ }
+ } // namespace cutlass_sycl
+ } // namespace native
+ } // namespace at