Skip to content

Commit 8b015d5

Browse files
committed
enable with host compiler
1 parent 728c3f1 commit 8b015d5

File tree

11 files changed

+1193
-3
lines changed

11 files changed

+1193
-3
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ if(USE_XCCL)
5656
endif()
5757
endif()
5858

59+
# Switch for the CUTLASS-SYCL based kernels. Declared as an option (default ON,
# matching the previously hard-coded value) so that the guard below is
# meaningful and the fetch can be disabled with -DUSE_CUTLASS=OFF.
option(USE_CUTLASS "Build kernels that depend on CUTLASS-SYCL" ON)
if(USE_CUTLASS)
  # Downloads CUTLASS-SYCL and defines CUTLASS_SYCL_INCLUDE_DIRS and
  # CUTLASS_SYCL_COMPILE_DEFINITIONS.
  include(${TORCH_XPU_OPS_ROOT}/cmake/CUTLASS.cmake)
endif()
63+
5964
if(BUILD_TEST)
6065
add_subdirectory(${TORCH_XPU_OPS_ROOT}/test/sycl ${CMAKE_BINARY_DIR}/test_sycl)
6166
endif()

cmake/BuildFlags.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ endfunction()
2626
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
2727
# # -- Host flags (SYCL_CXX_FLAGS)
2828
if(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
29-
list(APPEND SYCL_HOST_FLAGS /std:c++17)
29+
list(APPEND SYCL_HOST_FLAGS /std:c++20)
3030
list(APPEND SYCL_HOST_FLAGS /MD)
3131
list(APPEND SYCL_HOST_FLAGS /EHsc) # exception handling
3232
# SYCL headers warnings
3333
list(APPEND SYCL_HOST_FLAGS /wd4996) # allow usage of deprecated functions
3434
list(APPEND SYCL_HOST_FLAGS /wd4018) # allow signed and unsigned comparison
3535
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
3636
list(APPEND SYCL_HOST_FLAGS -fPIC)
37-
list(APPEND SYCL_HOST_FLAGS -std=c++17)
37+
list(APPEND SYCL_HOST_FLAGS -std=c++20)
3838
list(APPEND SYCL_HOST_FLAGS -Wunused-variable)
3939
# SYCL headers warnings
4040
list(APPEND SYCL_HOST_FLAGS -Wno-deprecated-declarations)

cmake/CUTLASS.cmake

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Fetches the CUTLASS-SYCL sources at configure time and exports:
#   CUTLASS_SYCL_INCLUDE_DIRS        - header search paths
#                                      (<src>/include and <src>/tools/util/include)
#   CUTLASS_SYCL_COMPILE_DEFINITIONS - preprocessor definitions set alongside
#                                      those headers
# The sources are only populated (downloaded); this file does not call
# add_subdirectory() on them.

# Manual include guard: process this file at most once per variable scope.
if(__CUTLASS_INCLUDED)
  return()
endif()
set(__CUTLASS_INCLUDED TRUE)

include(FetchContent)

# Source location of CUTLASS-SYCL. The defaults reproduce the previously
# hard-coded values; both can be overridden on the command line, e.g. to use a
# mirror or to pin CUTLASS_SYCL_GIT_TAG to a specific commit. The default tag
# is a branch name, so its contents can change between configures.
set(CUTLASS_SYCL_GIT_REPOSITORY "https://github.com/intel/cutlass-sycl"
    CACHE STRING "Git repository to fetch CUTLASS-SYCL from")
set(CUTLASS_SYCL_GIT_TAG "sycl-develop"
    CACHE STRING "Branch, tag or commit of CUTLASS-SYCL to check out")

FetchContent_Declare(
  repo-cutlass-sycl
  GIT_REPOSITORY "${CUTLASS_SYCL_GIT_REPOSITORY}"
  GIT_TAG "${CUTLASS_SYCL_GIT_TAG}"
  GIT_SHALLOW OFF
)

# NOTE(review): the single-argument FetchContent_Populate() form is deprecated
# as of CMake 3.30 (policy CMP0169). It is kept here because
# FetchContent_MakeAvailable() would also add the project to the build.
FetchContent_GetProperties(repo-cutlass-sycl)
if(NOT repo-cutlass-sycl_POPULATED)
  FetchContent_Populate(repo-cutlass-sycl)
endif()

set(CUTLASS_SYCL_INCLUDE_DIRS
  "${repo-cutlass-sycl_SOURCE_DIR}/include"
  "${repo-cutlass-sycl_SOURCE_DIR}/tools/util/include")
set(CUTLASS_SYCL_COMPILE_DEFINITIONS CUTLASS_ENABLE_SYCL SYCL_INTEL_TARGET)

cmake/Modules/FindSYCL.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,7 @@ macro(SYCL_ADD_LIBRARY sycl_target)
469469
target_link_libraries(
470470
${sycl_target}
471471
${SYCL_LINK_LIBRARIES_KEYWORD}
472+
PRIVATE
472473
${SYCL_LIBRARY})
473474

474475
set_target_properties(${sycl_target}
@@ -528,6 +529,7 @@ macro(SYCL_ADD_EXECUTABLE sycl_target)
528529
target_link_libraries(
529530
${sycl_target}
530531
${SYCL_LINK_LIBRARIES_KEYWORD}
532+
PRIVATE
531533
${SYCL_LIBRARY})
532534

533535
set_target_properties(${sycl_target}

src/ATen/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,23 @@
33
file(GLOB xpu_h "xpu/*.h")
44
file(GLOB xpu_cpp "xpu/*.cpp")
55
file(GLOB xpu_mkl "native/xpu/mkl/*.cpp")
6-
file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
6+
file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" "native/cutlass/*.cpp")
77
file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/nested/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp")
8+
file(GLOB xpu_cutlass_sycl "native/cutlass/sycl/*.cpp")
89

910
list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
1011
if(USE_ONEMKL_XPU)
1112
list(APPEND ATen_XPU_MKL_SRCS ${xpu_mkl})
1213
endif()
1314
list(APPEND ATen_XPU_NATIVE_CPP_SRCS ${xpu_native_cpp})
1415
list(APPEND ATen_XPU_SYCL_SRCS ${xpu_sycl})
16+
list(APPEND ATen_XPU_CUTLASS_SYCL_SRCS ${xpu_cutlass_sycl})
1517

1618
set(ATen_XPU_CPP_SRCS ${ATen_XPU_CPP_SRCS} PARENT_SCOPE)
1719
set(ATen_XPU_MKL_SRCS ${ATen_XPU_MKL_SRCS} PARENT_SCOPE)
1820
set(ATen_XPU_NATIVE_CPP_SRCS ${ATen_XPU_NATIVE_CPP_SRCS} PARENT_SCOPE)
1921
set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE)
22+
set(ATen_XPU_CUTLASS_SYCL_SRCS ${ATen_XPU_CUTLASS_SYCL_SRCS} PARENT_SCOPE)
2023

2124
foreach(HEADER ${xpu_h})
2225
install(FILES ${HEADER} DESTINATION "${AT_INSTALL_INCLUDE_DIR}/ATen/xpu")
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include <ATen/core/Tensor.h>
2+
#include <ATen/native/transformers/attention.h>
3+
#include <ATen/native/transformers/sdp_utils_cpp.h>
4+
5+
#ifndef AT_PER_OPERATOR_HEADERS
6+
#include <ATen/Functions.h>
7+
#include <ATen/NativeFunctions.h>
8+
#else
9+
#include <ATen/ops/empty_like.h>
10+
#include <ATen/ops/linear.h>
11+
#include <ATen/ops/scaled_dot_product_attention.h>
12+
#endif
13+
14+
#include <ATen/native/cutlass/Attention.h>
15+
#include <ATen/native/cutlass/sycl/AttentionKernels.h>
16+
17+
#include <comm/SYCLContext.h>
18+
19+
namespace at {
20+
namespace native {
21+
namespace cutlass_sycl{
22+
23+
void sdpa_backward(
24+
int batch_size,
25+
int num_head_q,
26+
int num_head_kv,
27+
int seq_len_q,
28+
int seq_len_kv,
29+
int head_dim_qk,
30+
int head_dim_v,
31+
const Tensor& grad_out,
32+
const Tensor& query,
33+
const Tensor& key,
34+
const Tensor& value,
35+
const Tensor& out,
36+
const Tensor& logsumexp,
37+
std::optional<at::Tensor> attn_mask,
38+
bool is_causal,
39+
double scale,
40+
Tensor& grad_query,
41+
Tensor& grad_key,
42+
Tensor& grad_value) {
43+
44+
std::cout << "lfq: entering cutlass sdpa_backward" << std::endl;
45+
46+
auto ps = at::matmul(query, key.transpose(-2, -1));
47+
ps = ps / std::sqrt(scale);
48+
ps = at::softmax(ps, -1).to(query.dtype());
49+
auto dps = at::empty_like(ps);
50+
cutlass_sdpa_backward(batch_size, num_head_q, num_head_kv, seq_len_q, seq_len_kv,
51+
head_dim_qk, head_dim_v,
52+
grad_out.data_ptr(),
53+
query.data_ptr(),
54+
key.data_ptr(),
55+
value.data_ptr(),
56+
ps.data_ptr(),
57+
nullptr,
58+
grad_query.data_ptr(),
59+
grad_key.data_ptr(),
60+
grad_value.data_ptr(),
61+
dps.data_ptr());
62+
}
63+
} // cutlass_sycl
64+
} // namespace native
65+
} // namespace at
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
5+
namespace at {
6+
namespace native {
7+
namespace cutlass_sycl{
8+
9+
void sdpa_backward(
10+
int batch_size,
11+
int num_head_q,
12+
int num_head_kv,
13+
int seq_len_q,
14+
int seq_len_kv,
15+
int head_dim_qk,
16+
int head_dim_v,
17+
const Tensor& grad_out,
18+
const Tensor& query,
19+
const Tensor& key,
20+
const Tensor& value,
21+
const Tensor& out,
22+
const Tensor& logsumexp,
23+
std::optional<at::Tensor> attn_mask,
24+
bool is_causal,
25+
double scale,
26+
Tensor& grad_query,
27+
Tensor& grad_key,
28+
Tensor& grad_value);
29+
30+
} // namespace cutlass_sycl
31+
} // namespace native
32+
} // namespace at

0 commit comments

Comments
 (0)