[FRONTEND] Fix addptr code generation (#2122)

Pointer arithmetic now generates correct `addptr` code regardless of operand order: both `offset + ptr` and `ptr + offset` are handled.
This commit is contained in:
Keren Zhou
2023-08-17 00:22:08 -04:00
committed by GitHub
parent eb940823c4
commit 2d513dbf50
3 changed files with 36 additions and 8 deletions

View File

@@ -10,11 +10,11 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
// Converts two packed pairs of fp16 values ($1, $2) into four packed
// fp8e5m2 values ($0) using round-to-nearest: clear the lowest kept
// mantissa bit, add 0x0080 per half-word to round, then prmt extracts
// the high byte of each rounded half-word into the output word.
// NOTE: the diff rendering had duplicated the five instruction lines
// (old + new hunk copies); only one copy belongs in the literal.
const std::string Fp16_to_Fp8E5M2 =
    "{ \n"
    ".reg .b32 a<2>; \n"
    "and.b32 a0, $1, 0xfffefffe; \n"   // a0 &= 0xfffefffe
    "and.b32 a1, $2, 0xfffefffe; \n"   // (strip lowest bit)
    "add.u32 a0, a0, 0x00800080; \n"   // a0 += 0x00800080
    "add.u32 a1, a1, 0x00800080; \n"   // (round to nearest)
    "prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
    "}";
const std::string Fp8E5M2_to_Fp16 = "{ \n"
@@ -170,7 +170,8 @@ const std::string Fp8E4M3_to_Fp16 =
"add.u32 b0, b0, a0; \n" // b0 = b0 + a0
"add.u32 b1, b1, a1; \n" // (move sign to the left)
"mad.lo.u32 $0, b0, 128, 0x20002000; \n" // out0 = (b0 << 7) + 0x20002000
"mad.lo.u32 $1, b1, 128, 0x20002000; \n" // (shift into position and bias exponent)
"mad.lo.u32 $1, b1, 128, 0x20002000; \n" // (shift into position and bias
// exponent)
"}";
// Fp16 -> Fp8E4M3 (packed)
@@ -180,7 +181,8 @@ const std::string Fp16_to_Fp8E4M3 =
"and.b32 a0, $1, 0x7fff7fff; \n" // a0 = input0 & 0x7fff7fff
"and.b32 a1, $2, 0x7fff7fff; \n" // (strip sign)
"mad.lo.u32 a0, a0, 2, 0x40804080; \n" // shift exponent (<< 1),
"mad.lo.u32 a1, a1, 2, 0x40804080; \n" // correct bias (0x40004000), and round to nearest
"mad.lo.u32 a1, a1, 2, 0x40804080; \n" // correct bias (0x40004000),
// and round to nearest
"lop3.b32 b0, $1, 0x80008000, a0, 0xe2; \n" // b0 = 0x80008000 ? in0 : a0
"lop3.b32 b1, $2, 0x80008000, a1, 0xe2; \n" // (restore sign)
"prmt.b32 $0, b0, b1, 0x7531; \n" // output = b1b0
@@ -197,7 +199,8 @@ const std::string Fp8E4M3_to_Bf16 =
"mad.lo.u32 b0, b0, 15, a0; \n" // b0 = b0 * 15 + a0
"mad.lo.u32 b1, b1, 15, a1; \n" // (move sign to the left)
"mad.lo.u32 $0, b0, 16, 0x3c003c00; \n" // out0 = (b0 << 4) + 0x3c003c00
"mad.lo.u32 $1, b1, 16, 0x3c003c00; \n" // (shift into position and bias exponent)
"mad.lo.u32 $1, b1, 16, 0x3c003c00; \n" // (shift into position and bias
// exponent)
"}";
const std::string Bf16_to_Fp8E4M3 =

View File

@@ -352,6 +352,29 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
num_ctas=num_ctas)
@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]])
def test_addptr(dtype, order, device):
    """Regression test for addptr code generation (#2122).

    Both operand orders of pointer arithmetic must work:
    ORDER == 0 exercises ``ptr + offset`` and ORDER == 1 exercises
    ``offset + ptr``.  The kernel copies ``x`` into ``y`` element-wise,
    so after launch ``y_tri`` must equal the original ``x``.

    NOTE(review): indentation was lost in extraction; structure below is
    reconstructed from the statement syntax — confirm against upstream.
    """
    check_type_supported(dtype, device)

    @triton.jit
    def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr):
        offs = tl.arange(0, SIZE)
        if ORDER == 0:
            # pointer on the left of the addition
            tl.store(y + offs, tl.load(x + offs))
        else:
            # pointer on the right of the addition
            tl.store(offs + y, tl.load(offs + x))

    SIZE = 1024
    rs = RandomState(17)
    x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
    y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
    x_tri = to_triton(x, dst_type=dtype, device=device)
    y_tri = to_triton(y, dst_type=dtype, device=device)
    y = x  # kernel copies x into y, so the expected output is x
    kernel[1,](x_tri, y_tri, order, SIZE)
    np.testing.assert_allclose(y, to_numpy(y_tri))
@pytest.mark.parametrize("dtype_x, dtype_y",
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]

View File

@@ -133,6 +133,8 @@ def add(input: tl.tensor,
# ptr + offset
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
input, other = other, input
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
if input_scalar_ty.is_ptr():
return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
# float + float