mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Fix addptr code generation (#2122)
`offset + ptr` and `ptr + offset` both work now
This commit is contained in:
@@ -10,11 +10,11 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
const std::string Fp16_to_Fp8E5M2 =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
|
||||
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
|
||||
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
|
||||
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
|
||||
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
|
||||
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
|
||||
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
|
||||
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
|
||||
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
|
||||
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
|
||||
"}";
|
||||
|
||||
const std::string Fp8E5M2_to_Fp16 = "{ \n"
|
||||
@@ -170,7 +170,8 @@ const std::string Fp8E4M3_to_Fp16 =
|
||||
"add.u32 b0, b0, a0; \n" // b0 = b0 + a0
|
||||
"add.u32 b1, b1, a1; \n" // (move sign to the left)
|
||||
"mad.lo.u32 $0, b0, 128, 0x20002000; \n" // out0 = (b0 << 7) + 0x20002000
|
||||
"mad.lo.u32 $1, b1, 128, 0x20002000; \n" // (shift into position and bias exponent)
|
||||
"mad.lo.u32 $1, b1, 128, 0x20002000; \n" // (shift into position and bias
|
||||
// exponent)
|
||||
"}";
|
||||
|
||||
// Fp16 -> Fp8E4M3 (packed)
|
||||
@@ -180,7 +181,8 @@ const std::string Fp16_to_Fp8E4M3 =
|
||||
"and.b32 a0, $1, 0x7fff7fff; \n" // a0 = input0 & 0x7fff7fff
|
||||
"and.b32 a1, $2, 0x7fff7fff; \n" // (strip sign)
|
||||
"mad.lo.u32 a0, a0, 2, 0x40804080; \n" // shift exponent (<< 1),
|
||||
"mad.lo.u32 a1, a1, 2, 0x40804080; \n" // correct bias (0x40004000), and round to nearest
|
||||
"mad.lo.u32 a1, a1, 2, 0x40804080; \n" // correct bias (0x40004000),
|
||||
// and round to nearest
|
||||
"lop3.b32 b0, $1, 0x80008000, a0, 0xe2; \n" // b0 = 0x80008000 ? in0 : a0
|
||||
"lop3.b32 b1, $2, 0x80008000, a1, 0xe2; \n" // (restore sign)
|
||||
"prmt.b32 $0, b0, b1, 0x7531; \n" // output = b1b0
|
||||
@@ -197,7 +199,8 @@ const std::string Fp8E4M3_to_Bf16 =
|
||||
"mad.lo.u32 b0, b0, 15, a0; \n" // b0 = b0 * 15 + a0
|
||||
"mad.lo.u32 b1, b1, 15, a1; \n" // (move sign to the left)
|
||||
"mad.lo.u32 $0, b0, 16, 0x3c003c00; \n" // out0 = (b0 << 4) + 0x3c003c00
|
||||
"mad.lo.u32 $1, b1, 16, 0x3c003c00; \n" // (shift into position and bias exponent)
|
||||
"mad.lo.u32 $1, b1, 16, 0x3c003c00; \n" // (shift into position and bias
|
||||
// exponent)
|
||||
"}";
|
||||
|
||||
const std::string Bf16_to_Fp8E4M3 =
|
||||
|
||||
@@ -352,6 +352,29 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
|
||||
num_ctas=num_ctas)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]])
|
||||
def test_addptr(dtype, order, device):
|
||||
check_type_supported(dtype, device)
|
||||
|
||||
@triton.jit
|
||||
def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr):
|
||||
offs = tl.arange(0, SIZE)
|
||||
if ORDER == 0:
|
||||
tl.store(y + offs, tl.load(x + offs))
|
||||
else:
|
||||
tl.store(offs + y, tl.load(offs + x))
|
||||
|
||||
SIZE = 1024
|
||||
rs = RandomState(17)
|
||||
x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
|
||||
y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
|
||||
x_tri = to_triton(x, dst_type=dtype, device=device)
|
||||
y_tri = to_triton(y, dst_type=dtype, device=device)
|
||||
y = x
|
||||
kernel[1,](x_tri, y_tri, order, SIZE)
|
||||
np.testing.assert_allclose(y, to_numpy(y_tri))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y",
|
||||
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
|
||||
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
|
||||
|
||||
@@ -133,6 +133,8 @@ def add(input: tl.tensor,
|
||||
# ptr + offset
|
||||
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
|
||||
input, other = other, input
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
if input_scalar_ty.is_ptr():
|
||||
return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
|
||||
# float + float
|
||||
|
||||
Reference in New Issue
Block a user