Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .prospector.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ pep8:
pylint:
run: true
options:
max-args: 10
max-positional-arguments: 10
max-args: 15
max-positional-arguments: 15
disable:
- import-error
- import-outside-toplevel
Expand Down
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ build_cmake: clean
build_wheel:
export CMAKE_GENERATOR="Unix Makefiles" && pip wheel -v . --extra-index-url https://download.pytorch.org/whl/cpu

build_cpu:
python scripts/build_cpu.py

# 'make compile_abs' compiles 'kernel_abs.cpp' into 'libkernel_abs.so' without building the whole wheel package.
# This is useful for development and debugging of individual kernels.
Expand All @@ -47,3 +49,6 @@ test:

test_tri_inv:
pytest tests/test_tri_inv_*.py

test_cpu:
python scripts/test_cpu.py
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ make test

---

## CPU Simulation

A subset of the kernels supports CPU simulation. These kernels can be built with

```bash
make build_cpu
```
and tested with

```bash
make test_cpu
```

---

## Repository Structure

```
Expand Down
57 changes: 36 additions & 21 deletions csrc/host/pybind11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,36 @@ for the full License text.

#include "torch_abs.h"
#include "torch_batch_matrix_square.h"
#include "torch_chunk_cumsum.h"
#include "torch_csr_gather.h"
#include "torch_gdn_chunk_h.h"
#include "torch_gdn_chunk_o.h"
#include "torch_gdn_scaled_dot_kkt.h"
#include "torch_gdn_wy_fast.h"
#include "torch_scan_ul1.h"
#include "torch_simple_matmul.h"
#include "torch_swiglu.h"
#include "torch_tri_inv.h"
#include "torch_tri_inv_ns.h"
#include "torch_tri_inv_rec_unroll.h"
#include "torch_tri_inv_trick.h"
#ifndef __CPU_SIM
#include "torch_chunk_cumsum.h"
#include "torch_gdn_chunk_h.h"
#include "torch_gdn_chunk_o.h"
#include "torch_gdn_scaled_dot_kkt.h"
#include "torch_gdn_wy_fast.h"
#include "torch_tri_inv_rec_unroll.h"
#endif

using namespace pto_isa_ops;

// Not strictly required, but the CPU-simulation build and the regular build
// use different module names to avoid any potential name clash in PyTorch.
// This way, both modules can coexist in a single runtime.
#ifdef __CPU_SIM
#define MODULE_NAME pto_kernels_cpu
#else
#define MODULE_NAME pto_kernels_ops
#endif

/**
* @brief Pybind11 module.
*/
PYBIND11_MODULE(pto_kernels_ops, m) {
PYBIND11_MODULE(MODULE_NAME, m) {
m.doc() = "PTO-ISA Kernels";
m.def(
"get_aic_cores",
Expand All @@ -43,15 +53,31 @@ PYBIND11_MODULE(pto_kernels_ops, m) {
},
pybind11::arg("device_id") = 0);
m.def("pto_abs", &pto_isa_ops::run_abs);
m.def("pto_batch_matrix_square", &pto_isa_ops::run_batch_matrix_square);
m.def("pto_csr_gather", &pto_isa_ops::run_csr_gather);
m.def("pto_scan_ul1", &pto_isa_ops::run_scan_ul1);
m.def("pto_simple_matmul", &pto_isa_ops::run_simple_matmul);
m.def("pto_swiglu", &pto_isa_ops::run_swiglu, py::arg("x"),
py::arg("dim") = -1);
m.def("pto_tri_inv_trick", &pto_isa_ops::run_tri_inv_trick);
m.def("pto_tri_inv_ns", &pto_isa_ops::run_tri_inv_ns, py::arg("M"),
py::arg("num_iters") = 0, py::arg("scale_value") = 0.0f);
m.def("pto_tri_inv", &pto_isa_ops::run_tri_inv);
#ifndef __CPU_SIM
m.def("pto_tri_inv_rec_unroll", &pto_isa_ops::run_tri_inv_rec_unroll,
py::arg("M"), py::arg("cu_seqlens") = at::zeros({1}),
py::arg("is_bsnd_format") = false,
py::arg("dtype_out") = at::ScalarType::Half);
m.def("pto_tri_inv_ns", &pto_isa_ops::run_tri_inv_ns, py::arg("M"),
py::arg("num_iters") = 0, py::arg("scale_value") = 0.0f);
m.def("pto_tri_inv", &pto_isa_ops::run_tri_inv);
m.def("pto_chunk_h", &pto_isa_ops::run_gdn_chunk_h, py::arg("K"),
py::arg("W"), py::arg("U"), py::arg("G"),
py::arg("cu_seqlens") = at::zeros({1}), py::arg("batch_size"),
py::arg("seq_len"), py::arg("total_chunks"));
m.def("pto_chunk_cumsum", &pto_isa_ops::run_chunk_cumsum, py::arg("g"),
py::arg("batch_size"), py::arg("seq_len"),
py::arg("cu_seqlens") = at::zeros({1}));
m.def("pto_batch_matrix_square", &pto_isa_ops::run_batch_matrix_square);
m.def("pto_csr_gather", &pto_isa_ops::run_csr_gather);
m.def("pto_gdn_scaled_dot_kkt", &pto_isa_ops::run_gdn_scaled_dot_kkt,
py::arg("K"), py::arg("Beta"), py::arg("G"), py::arg("Msk"),
py::arg("batch_size"), py::arg("seq_len"),
Expand All @@ -60,20 +86,9 @@ PYBIND11_MODULE(pto_kernels_ops, m) {
py::arg("K"), py::arg("V"), py::arg("S"), py::arg("G"), py::arg("Msk"),
py::arg("batch_size"), py::arg("seq_len"),
py::arg("cu_seqlens") = at::zeros({1}));
m.def("pto_scan_ul1", &pto_isa_ops::run_scan_ul1);
m.def("pto_simple_matmul", &pto_isa_ops::run_simple_matmul);
m.def("pto_swiglu", &pto_isa_ops::run_swiglu, py::arg("x"),
py::arg("dim") = -1);
m.def("pto_tri_inv_trick", &pto_isa_ops::run_tri_inv_trick);
m.def("pto_tri_inv_rec_unroll", &pto_isa_ops::run_tri_inv_rec_unroll,
py::arg("M"), py::arg("cu_seqlens") = at::zeros({1}),
py::arg("is_bsnd_format") = false,
py::arg("dtype_out") = at::ScalarType::Half);
m.def("pto_tri_inv_ns", &pto_isa_ops::run_tri_inv_ns, py::arg("M"),
py::arg("num_iters") = 0, py::arg("scale_value") = 0.0f);
m.def("pto_tri_inv", &pto_isa_ops::run_tri_inv);
m.def("pto_gdn_wy_fast", &pto_isa_ops::run_gdn_wy_fast, py::arg("K"),
py::arg("V"), py::arg("Beta"), py::arg("G"), py::arg("A"),
py::arg("batch_size"), py::arg("seq_len"),
py::arg("cu_seqlens") = at::zeros({1}));
#endif
}
7 changes: 7 additions & 0 deletions csrc/host/torch_abs.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@ for the full License text.
#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef __CPU_SIM
#include "../kernel/kernel_abs.h"
#include "utils_cpu.h"
#else
#include "aclrtlaunch_vabs_fp16.h"
#include "aclrtlaunch_vabs_fp32.h"
#include "utils.h"
#endif

namespace pto_isa_ops {

Expand All @@ -40,8 +45,10 @@ at::Tensor run_abs(const at::Tensor& x) {
block_dim = total_tiles;
}

#ifndef __CPU_SIM
TORCH_CHECK(x.device().type() == DEVICE_TYPE,
"pto_abs: tensor must be on NPU, got ", x.device());
#endif
TORCH_CHECK(dtype == at::kHalf || dtype == at::kFloat,
"pto_abs: dtype must be fp16 or float32, got ", dtype);
if (dtype == at::kHalf) {
Expand Down
7 changes: 7 additions & 0 deletions csrc/host/torch_batch_matrix_square.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@ for the full License text.
#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef __CPU_SIM
#include "../kernel/kernel_batch_matrix_square.h"
#include "utils_cpu.h"
#else
#include "aclrtlaunch_batch_matrix_square_fp16.h"
#include "aclrtlaunch_batch_matrix_square_fp32.h"
#include "utils.h"
#endif

namespace pto_isa_ops {

Expand All @@ -29,8 +34,10 @@ at::Tensor run_batch_matrix_square(const at::Tensor& x) {
const auto dtype = x.options().dtype();
const auto dtype_out = at::kFloat;

#ifndef __CPU_SIM
TORCH_CHECK(device.type() == DEVICE_TYPE,
"batch_matrix_square: tensor must be on NPU, got ", device);
#endif
TORCH_CHECK(dtype == at::kHalf || dtype == at::kFloat,
"batch_matrix_square: dtype must be fp16 or float32, got ",
dtype);
Expand Down
7 changes: 7 additions & 0 deletions csrc/host/torch_csr_gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@ for the full License text.
#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef __CPU_SIM
#include "../kernel/kernel_csr_gather.h"
#include "utils_cpu.h"
#else
#include "aclrtlaunch_csr_gather_fp16.h"
#include "aclrtlaunch_csr_gather_fp32.h"
#include "utils.h"
#endif

namespace pto_isa_ops {

Expand Down Expand Up @@ -52,12 +57,14 @@ at::Tensor run_csr_gather(const at::Tensor& values, const at::Tensor& indices,
block_dim = total_tiles;
}

#ifndef __CPU_SIM
TORCH_CHECK(values.device().type() == DEVICE_TYPE,
"csr_gather: tensor must be on NPU, got ", values.device());
TORCH_CHECK(indices.device().type() == DEVICE_TYPE,
"csr_gather: tensor must be on NPU, got ", indices.device());
TORCH_CHECK(x.device().type() == DEVICE_TYPE,
"csr_gather: tensor must be on NPU, got ", x.device());
#endif
TORCH_CHECK(dtype == at::kHalf || dtype == at::kFloat,
"csr_gather: dtype must be fp16 or float32, got ", dtype);
if (dtype == at::kHalf) {
Expand Down
7 changes: 7 additions & 0 deletions csrc/host/torch_scan_ul1.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@ for the full License text.
#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef __CPU_SIM
#include "../kernel/kernel_scan_ul1.h"
#include "utils_cpu.h"
#else
#include "aclrtlaunch_scan_ul1_fp16.h"
#include "aclrtlaunch_scan_ul1_fp32.h"
#include "utils.h"
#endif

namespace pto_isa_ops {

Expand All @@ -28,8 +33,10 @@ at::Tensor run_scan_ul1(const at::Tensor& x) {
const auto dtype = x.options().dtype();
const auto dtype_out = at::kFloat;

#ifndef __CPU_SIM
TORCH_CHECK(device.type() == DEVICE_TYPE,
"scan_ul1: tensor must be on NPU, got ", device);
#endif
TORCH_CHECK(dtype == at::kHalf || dtype == at::kFloat,
"scan_ul1: dtype must be fp16 or float32, got ", dtype);

Expand Down
7 changes: 7 additions & 0 deletions csrc/host/torch_simple_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@ for the full License text.
#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef __CPU_SIM
#include "../kernel/kernel_simple_matmul.h"
#include "utils_cpu.h"
#else
#include "aclrtlaunch_simple_matmul_bf16.h"
#include "aclrtlaunch_simple_matmul_fp16.h"
#include "aclrtlaunch_simple_matmul_fp32.h"
#include "utils.h"
#endif

namespace pto_isa_ops {

Expand All @@ -30,10 +35,12 @@ at::Tensor run_simple_matmul(const at::Tensor& a, const at::Tensor& b) {
const auto dtype = a.options().dtype();
const auto dtype_out = at::kFloat;

#ifndef __CPU_SIM
TORCH_CHECK(device.type() == DEVICE_TYPE,
"simple_matmul: tensor must be on NPU, got ", device);
TORCH_CHECK(b.device().type() == DEVICE_TYPE,
"simple_matmul: tensor must be on NPU, got ", b.device());
#endif
TORCH_CHECK(
dtype == at::kHalf || dtype == at::kFloat || dtype == at::kBFloat16,
"simple_matmul: dtype must be fp16, bf16, or float32, got ", dtype);
Expand Down
7 changes: 7 additions & 0 deletions csrc/host/torch_swiglu.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@ for the full License text.

#include <limits>

#ifdef __CPU_SIM
#include "../kernel/kernel_swiglu.h"
#include "utils_cpu.h"
#else
#include "aclrtlaunch_swiglu_fp16.h"
#include "utils.h"
#endif

namespace pto_isa_ops {

Expand All @@ -35,8 +40,10 @@ at::Tensor run_swiglu(const at::Tensor& x, int64_t dim = -1) {
dim += x.dim();
}
TORCH_CHECK(dim == 1, "swiglu: currently supports only dim=-1");
#ifndef __CPU_SIM
TORCH_CHECK(x.device().type() == DEVICE_TYPE,
"swiglu: tensor must be on NPU, got ", x.device());
#endif
TORCH_CHECK(x.scalar_type() == at::kHalf, "swiglu: dtype must be fp16, got ",
x.scalar_type());
TORCH_CHECK(x.is_contiguous(), "swiglu: expects a contiguous input tensor");
Expand Down
7 changes: 7 additions & 0 deletions csrc/host/torch_tri_inv.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@ for the full License text.
#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef __CPU_SIM
#include "../kernel/kernel_tri_inv_col_sweep.h"
#include "utils_cpu.h"
#else
#include "aclrtlaunch_triv_inv_col_sweep_fp16.h"
#include "aclrtlaunch_triv_inv_col_sweep_fp32.h"
#include "utils.h"
#endif

namespace pto_isa_ops {

Expand All @@ -28,8 +33,10 @@ namespace pto_isa_ops {
at::Tensor run_tri_inv(const at::Tensor& x) {
const at::Device device = x.options().device();
const auto dtype = x.options().dtype();
#ifndef __CPU_SIM
TORCH_CHECK(device.type() == DEVICE_TYPE,
"tri_inv: tensor must be on NPU, got ", device);
#endif
TORCH_CHECK(x.dim() >= 2,
"tri_inv: input tensor must have at least 2 dimensions, got ",
x.dim());
Expand Down
7 changes: 7 additions & 0 deletions csrc/host/torch_tri_inv_ns.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@ for the full License text.

#include <cmath>

#ifdef __CPU_SIM
#include "../kernel/kernel_tri_inv_ns.h"
#include "utils_cpu.h"
#else
#include "aclrtlaunch_tri_inv_ns_fp16.h"
#include "utils.h"
#endif

namespace pto_isa_ops {

Expand Down Expand Up @@ -43,8 +48,10 @@ at::Tensor run_tri_inv_ns(const at::Tensor& M, uint32_t num_iters = 0,
const auto dtype = M.options().dtype();
const auto dtype_out = at::kFloat;

#ifndef __CPU_SIM
TORCH_CHECK(device.type() == DEVICE_TYPE,
"tri_inv_ns: tensor must be on NPU, got ", device);
#endif
TORCH_CHECK(dtype == at::kHalf, "tri_inv_ns: dtype must be fp16, got ",
dtype);
const uint32_t n = static_cast<uint32_t>(M.size(-1));
Expand Down
7 changes: 7 additions & 0 deletions csrc/host/torch_tri_inv_rec_unroll.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@ for the full License text.
#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef __CPU_SIM
#include "../kernel/kernel_tri_inv_rec_unroll.h"
#include "utils_cpu.h"
#else
#include "aclrtlaunch_tri_inv_rec_unroll_fp16fp16.h"
#include "aclrtlaunch_tri_inv_rec_unroll_fp16fp32.h"
#include "utils.h"
#endif

namespace pto_isa_ops {

Expand Down Expand Up @@ -41,8 +46,10 @@ at::Tensor run_tri_inv_rec_unroll(
const at::Device device = M.options().device();
const auto dtype = M.options().dtype();

#ifndef __CPU_SIM
TORCH_CHECK(device.type() == DEVICE_TYPE,
"tri_inv_ns: tensor must be on NPU, got ", device);
#endif
TORCH_CHECK(dtype_out == at::kHalf || dtype_out == at::kFloat,
"tri_inv_rec_unroll: dtype_out must be fp16 or float32, got ",
dtype_out);
Expand Down
Loading
Loading