mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Better error messages for noinline functions (#1657)
```
at 10:18:def val_multiplier_noinline(val, i):
return val * i
^
Function val_multiplier_noinline is marked noinline, but was called with non-scalar argument val:fp32[constexpr[128]]
```
This commit is contained in:
@@ -2334,31 +2334,47 @@ def val_multiplier(val, i):
|
||||
return val * i
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def val_multiplier_noinline(val, i):
|
||||
return val * i
|
||||
|
||||
|
||||
@triton.jit
|
||||
def vecmul_kernel(ptr, n_elements, rep):
|
||||
def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
offsets = pid * 128 + tl.arange(0, 128)
|
||||
mask = offsets < n_elements
|
||||
vec = tl.load(ptr + offsets, mask=mask)
|
||||
for i in range(1, rep):
|
||||
vec = val_multiplier(vec, i)
|
||||
if type == "inline":
|
||||
vec = val_multiplier(vec, i)
|
||||
else:
|
||||
vec = val_multiplier_noinline(vec, i)
|
||||
tl.store(ptr + offsets, vec, mask=mask)
|
||||
|
||||
|
||||
def test_call():
|
||||
@pytest.mark.parametrize("type", ["inline", "noinline"])
|
||||
def test_call(type):
|
||||
|
||||
@triton.jit
|
||||
def kernel(ptr, n_elements, num1, num2):
|
||||
vecmul_kernel(ptr, n_elements, num1)
|
||||
vecmul_kernel(ptr, n_elements, num2)
|
||||
def kernel(ptr, n_elements, num1, num2, type: tl.constexpr):
|
||||
vecmul_kernel(ptr, n_elements, num1, type)
|
||||
vecmul_kernel(ptr, n_elements, num2, type)
|
||||
|
||||
size = 1024
|
||||
rand_val = numpy_random((size,), dtype_str="float32")
|
||||
rand_val_tri = to_triton(rand_val, device='cuda')
|
||||
kernel[(size // 128,)](rand_val_tri, size, 3, 5)
|
||||
err_msg = ""
|
||||
try:
|
||||
kernel[(size // 128,)](rand_val_tri, size, 3, 5, type)
|
||||
except Exception as e:
|
||||
err_msg = str(e)
|
||||
|
||||
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
|
||||
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
|
||||
if type == "noinline":
|
||||
assert err_msg is not ""
|
||||
else:
|
||||
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
|
||||
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
|
||||
|
||||
# -------------
|
||||
# test if
|
||||
|
||||
Reference in New Issue
Block a user