mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Reformat Python code with yapf. (#2589)
I've add an option to yapf to do what we want for long lines, see https://github.com/google/yapf/pull/1177. We can now have a real Python formatter, yay! To make this PR, I ran my modified yapf over the repository, then looked over the full diff. Where yapf was mangling the param list of long function decls/calls (mostly kernels), I manually added `#` to put linebreaks where we want. I fixed up other formatting too -- mostly adding or removing a trailing comma from lists. Overall, trailing `#` was sufficient to get formatting similar to our current code. I didn't have to disable yapf anywhere. --------- Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -13,8 +13,7 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.common.backend import (BaseBackend, compute_core_version_key,
|
||||
register_backend)
|
||||
from triton.common.backend import (BaseBackend, compute_core_version_key, register_backend)
|
||||
from triton.common.build import quiet
|
||||
from triton.compiler.make_launcher import make_so_cache_key
|
||||
from triton.runtime.cache import get_cache_manager
|
||||
@@ -81,6 +80,7 @@ def build_for_backend(name, src, srcdir):
|
||||
|
||||
|
||||
class ExtensionUtils:
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
cls.instance = super(ExtensionUtils, cls).__new__(cls)
|
||||
@@ -110,6 +110,7 @@ class ExtensionUtils:
|
||||
|
||||
|
||||
class ExtensionDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
cls.instance = super(ExtensionDriver, cls).__new__(cls)
|
||||
@@ -256,13 +257,13 @@ def test_dummy_backend():
|
||||
|
||||
inp = torch.randn(10)
|
||||
out = torch.randn(10)
|
||||
kernel[(10,)](inp, out, 10, XBLOCK=16)
|
||||
kernel[(10, )](inp, out, 10, XBLOCK=16)
|
||||
spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
launch_counter = getattr(mod, "launch_counter")
|
||||
|
||||
for _ in range(100):
|
||||
kernel[(10,)](inp, out, 10, XBLOCK=16)
|
||||
kernel[(10, )](inp, out, 10, XBLOCK=16)
|
||||
|
||||
assert launch_counter() > 0
|
||||
|
||||
@@ -4,9 +4,7 @@ import pytest
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--backend", action="store", default="", help="Codegen backend"
|
||||
)
|
||||
parser.addoption("--backend", action="store", default="", help="Codegen backend")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -24,10 +24,10 @@ def test_xpu_backend(cmdopt):
|
||||
|
||||
if has_ipex:
|
||||
for _ in range(1000):
|
||||
x = torch.randn((65536,), device="xpu", dtype=torch.float32)
|
||||
y = torch.randn((65536,), device="xpu", dtype=torch.float32)
|
||||
z = torch.zeros((65536,), device="xpu", dtype=torch.float32)
|
||||
kernel[(65536,)](x, y, z, num_warps=32)
|
||||
x = torch.randn((65536, ), device="xpu", dtype=torch.float32)
|
||||
y = torch.randn((65536, ), device="xpu", dtype=torch.float32)
|
||||
z = torch.zeros((65536, ), device="xpu", dtype=torch.float32)
|
||||
kernel[(65536, )](x, y, z, num_warps=32)
|
||||
assert torch.all(x + y == z)
|
||||
else:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user