[FRONTEND] Mangle signed and unsigned integer types differently (#1340)

This is cherry-picked from #1305

If you call a `JITFunction` twice in the same kernel, first with `int32`
and then with `uint32`, the second call will treat the unsigned value as
signed. This passes through MLIR without error because MLIR uses the
same types for both, but the operations generated for signed and unsigned
values differ, so you may silently get the wrong result.
This commit is contained in:
peterbell10
2023-03-15 05:29:18 +00:00
committed by GitHub
parent ad81447ad0
commit 01b177afe7
2 changed files with 40 additions and 1 deletions

View File

@@ -295,6 +295,43 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
def test_unsigned_name_mangling(device='cuda'):
    """Check that the compiler mangles uint32 and int32 differently.

    One kernel applies ``tl.abs`` to a uint32 tensor and to the negation of an
    int32 tensor; both outputs are then compared element-wise with NumPy.
    """
    SIZE = 128

    # Kernel under test; launched below on a single-program grid.
    @triton.jit
    def kernel(O1, O2, X, Y, SIZE: tl.constexpr):
        offsets = tl.arange(0, SIZE)
        x_vals = tl.load(X + offsets)
        y_vals = tl.load(Y + offsets)
        # uint32 input: abs is expected to be a no-op.
        tl.store(O1 + offsets, tl.abs(x_vals))
        # int32 input: abs of the negated value should have an effect.
        tl.store(O2 + offsets, tl.abs(-y_vals))

    unsigned_dtype = 'uint32'
    signed_dtype = 'int32'

    # Random inputs drawn from one seeded generator.
    rng = RandomState(17)
    x = numpy_random(SIZE, dtype_str=unsigned_dtype, rs=rng)
    y = numpy_random(SIZE, dtype_str=signed_dtype, rs=rng)

    # NumPy reference results.
    expect_unsigned = np.abs(x)
    expect_signed = np.abs(-y)

    # Device-side inputs and output buffers shaped like the references.
    x_tri = to_triton(x, device=device, dst_type=unsigned_dtype)
    y_tri = to_triton(y, device=device, dst_type=signed_dtype)
    out_unsigned = to_triton(np.empty_like(expect_unsigned), device=device)
    out_signed = to_triton(np.empty_like(expect_signed), device=device)

    kernel[(1, )](out_unsigned, out_signed, x_tri, y_tri, SIZE=SIZE, num_warps=4)

    # Integer results, so require exact equality.
    assert (expect_unsigned == to_numpy(out_unsigned)).all()
    assert (expect_signed == to_numpy(out_signed)).all()
# ---------------
# test bitwise ops
# ---------------

View File

@@ -57,7 +57,9 @@ def mangle_ty(ty):
if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty)
if ty.is_int():
return 'i' + str(ty.int_bitwidth)
SIGNED = triton.language.dtype.SIGNEDNESS.SIGNED
prefix = 'i' if ty.int_signedness == SIGNED else 'u'
return prefix + str(ty.int_bitwidth)
if ty.is_fp8():
return 'fp8'
if ty.is_fp16():