mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Enable ruff linter instead of flake8 (#2574)
[FRONTEND] Enable ruff linter instead of flake8.
This fixes a few issues automatically, and also flagged two issues to
fix manually in test_core.py: We had two duplicate function names! One
of these function bodies was a duplicate, so I deleted it. The other
function body was not a duplicate, so I gave it a new name.
AIUI all of these errors should have been picked up by flake8. I'm
confused why it wasn't working. Anyway this is working, and it's faster
than flake8, so it seems like an improvement in all dimensions.
This commit is contained in:
@@ -12,4 +12,4 @@ max_line_length = 88
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
ignore = ["E501", "E701", "E731"]
|
||||
ignore = ["E501", "E701", "E731", "E741"]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# flake8: noqa: F821,F841
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -10,7 +9,6 @@ import torch
|
||||
from numpy.random import RandomState
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
import triton.language as tl
|
||||
from triton.common.build import is_hip
|
||||
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
|
||||
@@ -443,44 +441,6 @@ def test_unsigned_name_mangling(device='cuda'):
|
||||
assert (expect[1] == to_numpy(actual[1])).all()
|
||||
|
||||
|
||||
def test_unsigned_name_mangling(device):
|
||||
# Test that uint32 and int32 are mangled differently by the compiler
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
|
||||
@triton.jit
|
||||
def kernel(O1, O2, X, Y, SIZE: tl.constexpr):
|
||||
off = tl.arange(0, SIZE)
|
||||
x = tl.load(X + off)
|
||||
y = tl.load(Y + off)
|
||||
out1 = tl.abs(x) # uint32 -> nop
|
||||
out2 = tl.abs(-y) # int32 -> should have an effect
|
||||
tl.store(O1 + off, out1)
|
||||
tl.store(O2 + off, out2)
|
||||
|
||||
dtype_x = 'uint32'
|
||||
dtype_y = 'int32'
|
||||
# inputs
|
||||
rs = RandomState(17)
|
||||
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
|
||||
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
|
||||
# reference result
|
||||
expect = (np.abs(x), np.abs(-y))
|
||||
# triton result
|
||||
x_tri = to_triton(x, device=device, dst_type=dtype_x)
|
||||
y_tri = to_triton(y, device=device, dst_type=dtype_y)
|
||||
actual = tuple(
|
||||
to_triton(np.empty_like(e), device=device)
|
||||
for e in expect
|
||||
)
|
||||
kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4)
|
||||
|
||||
# Bitwise op, so expect exact equality
|
||||
assert (expect[0] == to_numpy(actual[0])).all()
|
||||
assert (expect[1] == to_numpy(actual[1])).all()
|
||||
|
||||
|
||||
# ---------------
|
||||
# test bitwise ops
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
|
||||
@@ -1597,7 +1557,7 @@ def test_reduce1d(op, dtype_str, shape, num_ctas, device):
|
||||
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
if is_hip():
|
||||
pytest.skip(f"test_reduce1d not supported on HIP")
|
||||
pytest.skip("test_reduce1d not supported on HIP")
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -1698,7 +1658,7 @@ def test_reduce(op, dtype_str, shape, axis, num_ctas, device):
|
||||
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
if is_hip():
|
||||
pytest.skip(f"test_reduce2d not supported on HIP")
|
||||
pytest.skip("test_reduce2d not supported on HIP")
|
||||
# triton kernel
|
||||
|
||||
@triton.jit
|
||||
@@ -2145,11 +2105,11 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
|
||||
|
||||
op_str = ""
|
||||
if op == "sum":
|
||||
op_str = f"""
|
||||
op_str = """
|
||||
%13 = arith.addi %arg2, %arg3 : i32
|
||||
tt.reduce.return %13 : i32"""
|
||||
elif op == "max":
|
||||
op_str = f"""
|
||||
op_str = """
|
||||
%13 = arith.cmpi "sgt", %arg2, %arg3 : i32
|
||||
%14 = arith.select %13, %arg2, %arg3 : i32
|
||||
tt.reduce.return %14 : i32"""
|
||||
@@ -2245,7 +2205,7 @@ def test_generic_reduction(device):
|
||||
def test_permute(dtype_str, shape, perm, num_ctas, device):
|
||||
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
|
||||
if is_hip():
|
||||
pytest.skip(f"test_permute is not supported in HIP")
|
||||
pytest.skip("test_permute is not supported in HIP")
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -2483,7 +2443,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
end = ptx.find("cvt.rn.f16.f32")
|
||||
red_code = ptx[start:end]
|
||||
assert len(red_code) > 0
|
||||
import os
|
||||
|
||||
# skip this check on hopper because there are some functions whose name contain "shared" in ptx.
|
||||
# TODO: we should eliminate these unused functions in ptx code.
|
||||
@@ -2594,7 +2553,6 @@ def test_dot_mulbroadcastred(in_dtype, device):
|
||||
# as the loaded value is in rowmajor. But MMAv3 requires it's second
|
||||
# operand is in colmajor because transpose is not supported for MMAv3
|
||||
# with float32 input.
|
||||
import os
|
||||
if capability[0] >= 9:
|
||||
assert "triton_gpu.async_wait {num = 1 : i32}" in h.asm['ttgir']
|
||||
else:
|
||||
@@ -3098,7 +3056,7 @@ def test_constexpr_scalar_shape(device):
|
||||
|
||||
@triton.jit
|
||||
def static_assert_func():
|
||||
tl.static_assert(tl.constexpr(False), f"Assert is firing because the constexpr progation did not work properly")
|
||||
tl.static_assert(tl.constexpr(False), "Assert is firing because the constexpr progation did not work properly")
|
||||
|
||||
|
||||
def test_constexpr_propagation():
|
||||
@@ -3260,7 +3218,7 @@ def test_math_tensor(dtype_str, expr, lib_path, num_ctas, device):
|
||||
y_ref = x * pow(2, 2)
|
||||
elif expr == 'math.pow_dtype':
|
||||
x = np.abs(x)
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.math.pow(x, 0.5)'})
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.pow(x, 0.5)'})
|
||||
y_ref = np.power(x, 0.5)
|
||||
elif expr == 'math.pow':
|
||||
# numpy does not allow negative factors in power, so we use abs()
|
||||
@@ -3625,7 +3583,7 @@ def test_while(device):
|
||||
assert out_j[0] == bound[0]
|
||||
|
||||
|
||||
def test_while(device):
|
||||
def test_while2(device):
|
||||
@triton.jit
|
||||
def nested_while(data, countPtr):
|
||||
for i in range(10):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .autotuner import Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
|
||||
heuristics)
|
||||
from .driver import driver
|
||||
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user