Skip to content

Commit

Permalink
Reformat Python code with yapf. (triton-lang#2589)
Browse files Browse the repository at this point in the history
I've added an option to yapf to do what we want for long lines, see
google/yapf#1177.  We can now have a real Python
formatter, yay!

To make this PR, I ran my modified yapf over the repository, then looked
over the full diff.  Where yapf was mangling the param list of long
function decls/calls (mostly kernels), I manually added `#` to put
linebreaks where we want.  I fixed up other formatting too -- mostly
adding or removing a trailing comma from lists.

Overall, trailing `#` was sufficient to get formatting similar to our
current code.  I didn't have to disable yapf anywhere.

---------

Co-authored-by: Phil Tillet <[email protected]>
  • Loading branch information
jlebar and ptillet authored Nov 3, 2023
1 parent dced22c commit df08301
Show file tree
Hide file tree
Showing 85 changed files with 3,804 additions and 3,882 deletions.
22 changes: 8 additions & 14 deletions .github/workflows/torch-inductor/scripts/check_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from collections import namedtuple

# Create a named tuple for the output of the benchmark
BenchmarkOutput = namedtuple(
'BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup', 'latency'])
BenchmarkOutput = namedtuple('BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup', 'latency'])


def parse_output(file_path: str) -> dict:
Expand All @@ -19,13 +18,11 @@ def parse_output(file_path: str) -> dict:
batch_size = row[2]
speedup = float(row[3])
latency = float(row[4])
entries[name] = BenchmarkOutput(
dev, name, batch_size, speedup, latency)
entries[name] = BenchmarkOutput(dev, name, batch_size, speedup, latency)
return entries


def compare(baseline: dict, new: dict, threshold: float,
geomean_threshold: float) -> bool:
def compare(baseline: dict, new: dict, threshold: float, geomean_threshold: float) -> bool:
baseline_geomean = 1.0
new_geomean = 1.0
for key in new:
Expand All @@ -41,19 +38,16 @@ def compare(baseline: dict, new: dict, threshold: float,
continue

if new_latency < baseline_latency * (1 - threshold):
print(
f"New benchmark {key} is faster than baseline: {new_latency} vs {baseline_latency}")
print(f"New benchmark {key} is faster than baseline: {new_latency} vs {baseline_latency}")
elif new_latency > baseline_latency * (1 + threshold):
print(
f"New benchmark {key} is slower than baseline: {new_latency} vs {baseline_latency}")
print(f"New benchmark {key} is slower than baseline: {new_latency} vs {baseline_latency}")
else:
print(
f"New benchmark {key} is within threshold: {new_latency} vs {baseline_latency}")
print(f"New benchmark {key} is within threshold: {new_latency} vs {baseline_latency}")
baseline_geomean *= baseline[key].speedup
new_geomean *= new[key].speedup

baseline_geomean = baseline_geomean ** (1 / len(baseline))
new_geomean = new_geomean ** (1 / len(new))
baseline_geomean = baseline_geomean**(1 / len(baseline))
new_geomean = new_geomean**(1 / len(new))
print(f"Baseline geomean: {baseline_geomean}")
print(f"New geomean: {new_geomean}")
assert new_geomean >= baseline_geomean * (1 - geomean_threshold), \
Expand Down
4 changes: 0 additions & 4 deletions .isort.cfg

This file was deleted.

16 changes: 5 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,14 @@ repos:
^docs/conf.py$
)
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
- repo: https://github.com/jlebar/yapf
rev: bf301f5ef7777e137b97219842629ca78eb5ef2a
hooks:
- id: isort
exclude: '^python/triton/runtime/.*'
- id: yapf
args: ["-p", "-i"]
stages: [commit, push, manual]
exclude: "python/test/unit/language/test_line_info.py"

- repo: https://github.com/pre-commit/mirrors-autopep8
rev: v1.6.0
hooks:
- id: autopep8
exclude: '^python/triton/runtime/.*'
args: ["-i"]
stages: [commit, push, manual]
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v16.0.6
hooks:
Expand Down
17 changes: 7 additions & 10 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

# -- General configuration ------------------------------------------------


import os
import shutil
import sys
Expand Down Expand Up @@ -121,12 +120,9 @@ def documenter(app, obj, parent):
return old_documenter(app, obj, parent)

sphinx.ext.autosummary.get_documenter = documenter
sphinx.util.inspect.unwrap_all = forward_jit_fn(
sphinx.util.inspect.unwrap_all)
sphinx.util.inspect.signature = forward_jit_fn(
sphinx.util.inspect.signature)
sphinx.util.inspect.object_description = forward_jit_fn(
sphinx.util.inspect.object_description)
sphinx.util.inspect.unwrap_all = forward_jit_fn(sphinx.util.inspect.unwrap_all)
sphinx.util.inspect.signature = forward_jit_fn(sphinx.util.inspect.signature)
sphinx.util.inspect.object_description = forward_jit_fn(sphinx.util.inspect.object_description)


# Auto Doc
Expand All @@ -139,7 +135,8 @@ def documenter(app, obj, parent):
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
'sphinx_multiversion',
'myst_parser']
'myst_parser',
]
autosummary_generate = True

# versioning config
Expand Down Expand Up @@ -294,6 +291,6 @@ def documenter(app, obj, parent):
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'Triton', 'Triton Documentation', author,
'Triton', 'One line description of project.', 'Miscellaneous'),
(master_doc, 'Triton', 'Triton Documentation', author, 'Triton', 'One line description of project.',
'Miscellaneous'),
]
26 changes: 26 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Packages that must be available in the build environment before the
# project itself can be built.
[build-system]
requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1"]

# Settings for the yapf Python formatter.
[tool.yapf]
based_on_style = "pep8"
# Kept equal to [tool.ruff] line-length below.
column_limit = 120
# NOTE(review): presumably the knob introduced for google/yapf#1177, which is
# what lets a trailing `#` pin line breaks in long parameter lists -- confirm
# against the yapf revision pinned in .pre-commit-config.yaml.
disable_split_list_with_comment = true
each_dict_entry_on_separate_line=false
split_before_named_assigns = false
split_complex_comprehension = true

# Files yapf must never reformat.
[tool.yapfignore]
ignore_patterns = [
# This exclusion is also specified in .pre-commit-config.yaml.
# - We put it here because if you run yapf directly, we want it to skip the
# file.
# - We also put it in .pre-commit-config because yapf raises an error if
# pre-commit runs it but all of the files it might touch are ignored!
"python/test/unit/language/test_line_info.py"
]

# Settings for the ruff linter.
[tool.ruff]
line-length = 120

[tool.ruff.lint]
# Rule codes that are switched off:
#   E501 - line too long
#   E701 - multiple statements on one line (colon)
#   E731 - lambda assigned to a name
#   E741 - ambiguous variable name
ignore = ["E501", "E701", "E731", "E741"]
4 changes: 2 additions & 2 deletions python/examples/copy_strided.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

# triton kernel
@triton.jit
def kernel(X, stride_xm,
Z, stride_zn,
def kernel(X, stride_xm, #
Z, stride_zn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
Expand Down
2 changes: 1 addition & 1 deletion python/examples/empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):


X = torch.randn(1, device="cuda")
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)
pgm = kernel[(1, )](X, 1, 1, BLOCK=1024)
43 changes: 28 additions & 15 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class Package(NamedTuple):
lib_flag: str
syspath_var_name: str


# pybind11


Expand All @@ -63,6 +64,7 @@ def get_pybind11_package_info():
url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz"
return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")


# llvm


Expand Down Expand Up @@ -121,6 +123,7 @@ def get_thirdparty_packages(triton_cache_path):
thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib")
return thirdparty_cmake_args


# ---- package data ---


Expand Down Expand Up @@ -153,6 +156,7 @@ def download_and_copy(src_path, variable, version, url_func):
os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
shutil.copy(src_path, dst_path)


# ---- cmake extension ----


Expand All @@ -170,18 +174,21 @@ def get_cmake_dir():


class CMakeClean(clean):

def initialize_options(self):
clean.initialize_options(self)
self.build_temp = get_cmake_dir()


class CMakeBuildPy(build_py):

def run(self) -> None:
self.run_command('build_ext')
return super().run()


class CMakeExtension(Extension):

def __init__(self, name, path, sourcedir=""):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
Expand All @@ -204,7 +211,8 @@ def run(self):
try:
out = subprocess.check_output(["cmake", "--version"])
except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " + ", ".join(e.name for e in self.extensions))
raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions))

match = re.search(r"version\s*(?P<major>\d+)\.(?P<minor>\d+)([\d.]+)?", out.decode())
cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor"))
Expand All @@ -231,8 +239,10 @@ def build_extension(self, ext):
# python directories
python_include_dir = sysconfig.get_path("platinclude")
cmake_args = [
"-G", "Ninja", # Ninja is much faster than make
"-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path
"-G",
"Ninja", # Ninja is much faster than make
"-DCMAKE_MAKE_PROGRAM=" +
ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path
"-DCMAKE_EXPORT_COMPILE_COMMANDS=ON",
"-DLLVM_ENABLE_WERROR=ON",
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
Expand Down Expand Up @@ -266,12 +276,14 @@ def build_extension(self, ext):
build_args += ['-j' + max_jobs]

if check_env_flag("TRITON_BUILD_WITH_CLANG_LLD"):
cmake_args += ["-DCMAKE_C_COMPILER=clang",
"-DCMAKE_CXX_COMPILER=clang++",
"-DCMAKE_LINKER=lld",
"-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"]
cmake_args += [
"-DCMAKE_C_COMPILER=clang",
"-DCMAKE_CXX_COMPILER=clang++",
"-DCMAKE_LINKER=lld",
"-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
]

# Note that asan doesn't work with binaries that use the GPU, so this is
# only useful for tools like triton-opt that don't run code on the GPU.
Expand Down Expand Up @@ -303,19 +315,22 @@ def build_extension(self, ext):
src_path="bin/ptxas",
variable="TRITON_PTXAS_PATH",
version="12.1.105",
url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2",
url_func=lambda arch, version:
f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/cuobjdump",
variable="TRITON_CUOBJDUMP_PATH",
version="12.1.111",
url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
url_func=lambda arch, version:
f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/nvdisasm",
variable="TRITON_NVDISASM_PATH",
version="12.1.105",
url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
url_func=lambda arch, version:
f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)

setup(
Expand All @@ -339,9 +354,7 @@ def build_extension(self, ext):
"triton/third_party",
"triton/tools",
],
install_requires=[
"filelock"
],
install_requires=["filelock"],
include_package_data=True,
ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean},
Expand Down
9 changes: 5 additions & 4 deletions python/test/backend/test_device_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

import triton
import triton.language as tl
from triton.common.backend import (BaseBackend, compute_core_version_key,
register_backend)
from triton.common.backend import (BaseBackend, compute_core_version_key, register_backend)
from triton.common.build import quiet
from triton.compiler.make_launcher import make_so_cache_key
from triton.runtime.cache import get_cache_manager
Expand Down Expand Up @@ -81,6 +80,7 @@ def build_for_backend(name, src, srcdir):


class ExtensionUtils:

def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(ExtensionUtils, cls).__new__(cls)
Expand Down Expand Up @@ -110,6 +110,7 @@ def __init__(self):


class ExtensionDriver(DriverBase):

def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(ExtensionDriver, cls).__new__(cls)
Expand Down Expand Up @@ -256,13 +257,13 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):

inp = torch.randn(10)
out = torch.randn(10)
kernel[(10,)](inp, out, 10, XBLOCK=16)
kernel[(10, )](inp, out, 10, XBLOCK=16)
spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
launch_counter = getattr(mod, "launch_counter")

for _ in range(100):
kernel[(10,)](inp, out, 10, XBLOCK=16)
kernel[(10, )](inp, out, 10, XBLOCK=16)

assert launch_counter() > 0
4 changes: 1 addition & 3 deletions python/test/backend/third_party_backends/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@


def pytest_addoption(parser):
parser.addoption(
"--backend", action="store", default="", help="Codegen backend"
)
parser.addoption("--backend", action="store", default="", help="Codegen backend")


@pytest.fixture
Expand Down
8 changes: 4 additions & 4 deletions python/test/backend/third_party_backends/test_xpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def kernel(x_ptr, y_ptr, out_ptr):

if has_ipex:
for _ in range(1000):
x = torch.randn((65536,), device="xpu", dtype=torch.float32)
y = torch.randn((65536,), device="xpu", dtype=torch.float32)
z = torch.zeros((65536,), device="xpu", dtype=torch.float32)
kernel[(65536,)](x, y, z, num_warps=32)
x = torch.randn((65536, ), device="xpu", dtype=torch.float32)
y = torch.randn((65536, ), device="xpu", dtype=torch.float32)
z = torch.zeros((65536, ), device="xpu", dtype=torch.float32)
kernel[(65536, )](x, y, z, num_warps=32)
assert torch.all(x + y == z)
else:
return
Loading

0 comments on commit df08301

Please sign in to comment.