[FRONTEND] Better error messages for noinline functions (#1657)

```
at 10:18:def val_multiplier_noinline(val, i):
    return val * i

           ^
Function val_multiplier_noinline is marked noinline, but was called with non-scalar argument val:fp32[constexpr[128]]
```
This commit is contained in:
Keren Zhou
2023-05-11 15:46:25 -04:00
committed by GitHub
parent 35b27e1eed
commit 674f9bf7a6
2 changed files with 37 additions and 9 deletions

View File

@@ -2334,31 +2334,47 @@ def val_multiplier(val, i):
return val * i
@triton.jit(noinline=True)
def val_multiplier_noinline(val, i):
return val * i
@triton.jit
def vecmul_kernel(ptr, n_elements, rep):
def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr):
pid = tl.program_id(axis=0)
offsets = pid * 128 + tl.arange(0, 128)
mask = offsets < n_elements
vec = tl.load(ptr + offsets, mask=mask)
for i in range(1, rep):
vec = val_multiplier(vec, i)
if type == "inline":
vec = val_multiplier(vec, i)
else:
vec = val_multiplier_noinline(vec, i)
tl.store(ptr + offsets, vec, mask=mask)
def test_call():
@pytest.mark.parametrize("type", ["inline", "noinline"])
def test_call(type):
@triton.jit
def kernel(ptr, n_elements, num1, num2):
vecmul_kernel(ptr, n_elements, num1)
vecmul_kernel(ptr, n_elements, num2)
def kernel(ptr, n_elements, num1, num2, type: tl.constexpr):
vecmul_kernel(ptr, n_elements, num1, type)
vecmul_kernel(ptr, n_elements, num2, type)
size = 1024
rand_val = numpy_random((size,), dtype_str="float32")
rand_val_tri = to_triton(rand_val, device='cuda')
kernel[(size // 128,)](rand_val_tri, size, 3, 5)
err_msg = ""
try:
kernel[(size // 128,)](rand_val_tri, size, 3, 5, type)
except Exception as e:
err_msg = str(e)
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
if type == "noinline":
assert err_msg is not ""
else:
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
# -------------
# test if