Reformat Python code with yapf. (#2589)

I've add an option to yapf to do what we want for long lines, see
https://github.com/google/yapf/pull/1177.  We can now have a real Python
formatter, yay!

To make this PR, I ran my modified yapf over the repository, then looked
over the full diff.  Where yapf was mangling the param list of long
function decls/calls (mostly kernels), I manually added `#` to put
linebreaks where we want.  I fixed up other formatting too -- mostly
adding or removing a trailing comma from lists.

Overall, trailing `#` was sufficient to get formatting similar to our
current code.  I didn't have to disable yapf anywhere.

---------

Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
Justin Lebar
2023-11-02 20:44:17 -07:00
committed by GitHub
parent dced22c4b7
commit df08301e76
85 changed files with 3802 additions and 3880 deletions

View File

@@ -13,8 +13,7 @@ import torch
import triton
import triton.language as tl
from triton.common.backend import (BaseBackend, compute_core_version_key,
register_backend)
from triton.common.backend import (BaseBackend, compute_core_version_key, register_backend)
from triton.common.build import quiet
from triton.compiler.make_launcher import make_so_cache_key
from triton.runtime.cache import get_cache_manager
@@ -81,6 +80,7 @@ def build_for_backend(name, src, srcdir):
class ExtensionUtils:
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(ExtensionUtils, cls).__new__(cls)
@@ -110,6 +110,7 @@ class ExtensionUtils:
class ExtensionDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(ExtensionDriver, cls).__new__(cls)
@@ -256,13 +257,13 @@ def test_dummy_backend():
inp = torch.randn(10)
out = torch.randn(10)
kernel[(10,)](inp, out, 10, XBLOCK=16)
kernel[(10, )](inp, out, 10, XBLOCK=16)
spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
launch_counter = getattr(mod, "launch_counter")
for _ in range(100):
kernel[(10,)](inp, out, 10, XBLOCK=16)
kernel[(10, )](inp, out, 10, XBLOCK=16)
assert launch_counter() > 0

View File

@@ -4,9 +4,7 @@ import pytest
def pytest_addoption(parser):
parser.addoption(
"--backend", action="store", default="", help="Codegen backend"
)
parser.addoption("--backend", action="store", default="", help="Codegen backend")
@pytest.fixture

View File

@@ -24,10 +24,10 @@ def test_xpu_backend(cmdopt):
if has_ipex:
for _ in range(1000):
x = torch.randn((65536,), device="xpu", dtype=torch.float32)
y = torch.randn((65536,), device="xpu", dtype=torch.float32)
z = torch.zeros((65536,), device="xpu", dtype=torch.float32)
kernel[(65536,)](x, y, z, num_warps=32)
x = torch.randn((65536, ), device="xpu", dtype=torch.float32)
y = torch.randn((65536, ), device="xpu", dtype=torch.float32)
z = torch.zeros((65536, ), device="xpu", dtype=torch.float32)
kernel[(65536, )](x, y, z, num_warps=32)
assert torch.all(x + y == z)
else:
return