mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fixing bug in elementwise conversion (#2517)
This commit is contained in:
@@ -64,16 +64,23 @@ static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
} else {
|
||||
ret = "{ \n"
|
||||
".reg .b32 a; \n"
|
||||
".reg .f16 a<2>; \n"
|
||||
".reg .b16 b<2>; \n"
|
||||
"cvt.rn.f16x2.e5m2x2 a, $1; \n"
|
||||
"mov.b32 {a0, a1}, a; \n"
|
||||
"cvt.bf16.f16 b0, a0; \n"
|
||||
"cvt.bf16.f16 b1, a1; \n"
|
||||
"mov.b32 $0, {b0, b1}; \n"
|
||||
"}";
|
||||
ret =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
|
||||
".reg .b32 e112; \n"
|
||||
"mov.u32 e112, 0x77807780; \n" // 2**112 represented as
|
||||
// bf16x2
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
|
||||
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
|
||||
"lop3.b32 b0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 b1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"mul.rn.bf16x2 $0, b0, e112; \n" // b0.exp += 2**7-2**4
|
||||
"mul.rn.bf16x2 $1, b1, e112; \n" // exponent compensate = 112
|
||||
"}";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@@ -129,7 +136,7 @@ static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
|
||||
"mov.b32 {a0, a1}, $1; \n"
|
||||
"cvt.f32.bf16 b0, a0; \n"
|
||||
"cvt.f32.bf16 b1, a1; \n"
|
||||
"cvt.rn.satfinite.e5m2x2.f32 $0, b0, b1; \n"
|
||||
"cvt.rn.satfinite.e5m2x2.f32 $0, b1, b0; \n"
|
||||
"}";
|
||||
}
|
||||
return ret;
|
||||
@@ -257,7 +264,7 @@ static const std::string Bf16_to_Fp8E4M3Nv =
|
||||
"mov.b32 {a0, a1}, $1; \n"
|
||||
"cvt.f32.bf16 b0, a0; \n"
|
||||
"cvt.f32.bf16 b1, a1; \n"
|
||||
"cvt.rn.satfinite.e4m3x2.f32 $0, b0, b1; \n"
|
||||
"cvt.rn.satfinite.e4m3x2.f32 $0, b1, b0; \n"
|
||||
"}";
|
||||
|
||||
/* ----- Packed integer to BF16 ------ */
|
||||
@@ -677,7 +684,7 @@ struct FpToFpOpConversion
|
||||
int inVecWidthBits = 32;
|
||||
int outVecWidthBits = 32;
|
||||
if (srcTy.isFloat8E4M3FNUZ() ||
|
||||
(computeCapability >= 90 && srcTy.isFloat8E5M2())) {
|
||||
(computeCapability >= 90 && srcTy.isFloat8E5M2() && dstTy.isF16())) {
|
||||
inVecWidthBits = 16;
|
||||
outVecWidthBits = 32;
|
||||
}
|
||||
@@ -717,7 +724,9 @@ struct FpToFpOpConversion
|
||||
if (srcElementType.isFloat8E4M3FNUZ() ||
|
||||
dstElementType.isFloat8E4M3FNUZ() ||
|
||||
(computeCapability >= 90 &&
|
||||
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) {
|
||||
((srcElementType.isFloat8E5M2() &&
|
||||
(dstElementType.isF16() || dstElementType.isF32())) ||
|
||||
dstElementType.isFloat8E5M2()))) {
|
||||
numElements = 2;
|
||||
}
|
||||
bool useFP16IntermediateSrc =
|
||||
@@ -725,9 +734,9 @@ struct FpToFpOpConversion
|
||||
!(computeCapability >= 90 &&
|
||||
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
|
||||
bool isDstFP32 = dstElementType.isF32();
|
||||
auto cvtFunc =
|
||||
getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType,
|
||||
isDstFP32 ? f16_ty : dstElementType);
|
||||
Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
|
||||
Type dstType = isDstFP32 ? f16_ty : dstElementType;
|
||||
auto cvtFunc = getConversionFunc(srcType, dstType);
|
||||
SmallVector<Value> inVals;
|
||||
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
|
||||
inVals.push_back(operands[i][0]);
|
||||
@@ -735,8 +744,7 @@ struct FpToFpOpConversion
|
||||
if (useFP16IntermediateSrc)
|
||||
for (Value &v : inVals)
|
||||
v = convertFp32ToFp16(loc, rewriter, v);
|
||||
inVals.resize(numElements,
|
||||
undef(typeConverter->convertType(srcElementType)));
|
||||
inVals.resize(numElements, undef(typeConverter->convertType(srcType)));
|
||||
SmallVector<Value> outVals = cvtFunc(loc, rewriter, inVals);
|
||||
assert(outVals.size() == inVals.size());
|
||||
outVals.resize(std::min(numElements, operands.size()));
|
||||
|
||||
@@ -1329,8 +1329,11 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
|
||||
if dtype_z in uint_dtypes:
|
||||
x = np.absolute(x)
|
||||
x_tri = to_triton(x, device=device)
|
||||
|
||||
if 'float' in dtype_z and 'float' in dtype_x:
|
||||
# make sure we use values that can be represented in both types
|
||||
x_tri = x_tri.to(getattr(torch, dtype_z)).to(getattr(torch, dtype_x))
|
||||
# triton kernel
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr):
|
||||
x_ptr = X + tl.arange(0, SIZE)
|
||||
@@ -1344,7 +1347,7 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
|
||||
if dtype_z.startswith('bfloat'):
|
||||
z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device)
|
||||
elif dtype_z.startswith('float8'):
|
||||
z_tri = torch.empty((size,), dtype=torch.float, device=device)
|
||||
z_tri = torch.empty((size,), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z))
|
||||
else:
|
||||
z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device)
|
||||
kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1, num_ctas=num_ctas)
|
||||
|
||||
@@ -188,6 +188,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: test_fp8_to_f16_conversion
|
||||
tt.func @test_fp8_to_f16_conversion(
|
||||
@@ -197,16 +198,18 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
%out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked>
|
||||
// CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
|
||||
%out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked>
|
||||
// CHECK-COUNT-2: mul.rn.bf16x2
|
||||
%out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked>
|
||||
|
||||
// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
|
||||
%out2 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked>
|
||||
%out3 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked>
|
||||
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
|
||||
%out3 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
|
||||
%out4 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
|
||||
|
||||
// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
|
||||
%out4 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked>
|
||||
%out5 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked>
|
||||
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
|
||||
%out5 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
|
||||
%out6 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user