mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Mangle signed and unsigned integer types differently (#1340)
This is cherry-picked from #1305 If you call a `JITFunction` twice in the same kernel, first with `int32` then with `uint32`, the second call will treat the unsigned value as signed. This passes through MLIR without error because MLIR uses the same types for both, but different operation calls will be generated so you may silently get the wrong result.
This commit is contained in:
@@ -295,6 +295,43 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
|
||||
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
|
||||
|
||||
|
||||
def test_unsigned_name_mangling(device='cuda'):
|
||||
# Test that uint32 and int32 are mangled differently by the compiler
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
|
||||
@triton.jit
|
||||
def kernel(O1, O2, X, Y, SIZE: tl.constexpr):
|
||||
off = tl.arange(0, SIZE)
|
||||
x = tl.load(X + off)
|
||||
y = tl.load(Y + off)
|
||||
out1 = tl.abs(x) # uint32 -> nop
|
||||
out2 = tl.abs(-y) # int32 -> should have an effect
|
||||
tl.store(O1 + off, out1)
|
||||
tl.store(O2 + off, out2)
|
||||
|
||||
dtype_x = 'uint32'
|
||||
dtype_y = 'int32'
|
||||
# inputs
|
||||
rs = RandomState(17)
|
||||
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
|
||||
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
|
||||
# reference result
|
||||
expect = (np.abs(x), np.abs(-y))
|
||||
# triton result
|
||||
x_tri = to_triton(x, device=device, dst_type=dtype_x)
|
||||
y_tri = to_triton(y, device=device, dst_type=dtype_y)
|
||||
actual = tuple(
|
||||
to_triton(np.empty_like(e), device=device)
|
||||
for e in expect
|
||||
)
|
||||
kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4)
|
||||
|
||||
# Bitwise op, so expect exact equality
|
||||
assert (expect[0] == to_numpy(actual[0])).all()
|
||||
assert (expect[1] == to_numpy(actual[1])).all()
|
||||
|
||||
|
||||
# ---------------
|
||||
# test bitwise ops
|
||||
# ---------------
|
||||
|
||||
@@ -57,7 +57,9 @@ def mangle_ty(ty):
|
||||
if ty.is_ptr():
|
||||
return 'P' + mangle_ty(ty.element_ty)
|
||||
if ty.is_int():
|
||||
return 'i' + str(ty.int_bitwidth)
|
||||
SIGNED = triton.language.dtype.SIGNEDNESS.SIGNED
|
||||
prefix = 'i' if ty.int_signedness == SIGNED else 'u'
|
||||
return prefix + str(ty.int_bitwidth)
|
||||
if ty.is_fp8():
|
||||
return 'fp8'
|
||||
if ty.is_fp16():
|
||||
|
||||
Reference in New Issue
Block a user