[BACKEND] Fixing bug in elementwise conversion (#2517)

This commit is contained in:
Zahi Moudallal
2023-10-20 09:11:15 -07:00
committed by GitHub
parent dc9e3063d7
commit b0c166b9e3
3 changed files with 39 additions and 25 deletions

View File

@@ -64,16 +64,23 @@ static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
} else {
ret = "{ \n"
".reg .b32 a; \n"
".reg .f16 a<2>; \n"
".reg .b16 b<2>; \n"
"cvt.rn.f16x2.e5m2x2 a, $1; \n"
"mov.b32 {a0, a1}, a; \n"
"cvt.bf16.f16 b0, a0; \n"
"cvt.bf16.f16 b1, a1; \n"
"mov.b32 $0, {b0, b1}; \n"
"}";
ret =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
".reg .b32 e112; \n"
"mov.u32 e112, 0x77807780; \n" // 2**112 represented as
// bf16x2
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
"lop3.b32 b0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 b1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"mul.rn.bf16x2 $0, b0, e112; \n" // b0.exp += 2**7-2**4
"mul.rn.bf16x2 $1, b1, e112; \n" // exponent compensate = 112
"}";
}
return ret;
}
@@ -129,7 +136,7 @@ static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
"mov.b32 {a0, a1}, $1; \n"
"cvt.f32.bf16 b0, a0; \n"
"cvt.f32.bf16 b1, a1; \n"
"cvt.rn.satfinite.e5m2x2.f32 $0, b0, b1; \n"
"cvt.rn.satfinite.e5m2x2.f32 $0, b1, b0; \n"
"}";
}
return ret;
@@ -257,7 +264,7 @@ static const std::string Bf16_to_Fp8E4M3Nv =
"mov.b32 {a0, a1}, $1; \n"
"cvt.f32.bf16 b0, a0; \n"
"cvt.f32.bf16 b1, a1; \n"
"cvt.rn.satfinite.e4m3x2.f32 $0, b0, b1; \n"
"cvt.rn.satfinite.e4m3x2.f32 $0, b1, b0; \n"
"}";
/* ----- Packed integer to BF16 ------ */
@@ -677,7 +684,7 @@ struct FpToFpOpConversion
int inVecWidthBits = 32;
int outVecWidthBits = 32;
if (srcTy.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 && srcTy.isFloat8E5M2())) {
(computeCapability >= 90 && srcTy.isFloat8E5M2() && dstTy.isF16())) {
inVecWidthBits = 16;
outVecWidthBits = 32;
}
@@ -717,7 +724,9 @@ struct FpToFpOpConversion
if (srcElementType.isFloat8E4M3FNUZ() ||
dstElementType.isFloat8E4M3FNUZ() ||
(computeCapability >= 90 &&
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) {
((srcElementType.isFloat8E5M2() &&
(dstElementType.isF16() || dstElementType.isF32())) ||
dstElementType.isFloat8E5M2()))) {
numElements = 2;
}
bool useFP16IntermediateSrc =
@@ -725,9 +734,9 @@ struct FpToFpOpConversion
!(computeCapability >= 90 &&
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
bool isDstFP32 = dstElementType.isF32();
auto cvtFunc =
getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType,
isDstFP32 ? f16_ty : dstElementType);
Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
Type dstType = isDstFP32 ? f16_ty : dstElementType;
auto cvtFunc = getConversionFunc(srcType, dstType);
SmallVector<Value> inVals;
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
inVals.push_back(operands[i][0]);
@@ -735,8 +744,7 @@ struct FpToFpOpConversion
if (useFP16IntermediateSrc)
for (Value &v : inVals)
v = convertFp32ToFp16(loc, rewriter, v);
inVals.resize(numElements,
undef(typeConverter->convertType(srcElementType)));
inVals.resize(numElements, undef(typeConverter->convertType(srcType)));
SmallVector<Value> outVals = cvtFunc(loc, rewriter, inVals);
assert(outVals.size() == inVals.size());
outVals.resize(std::min(numElements, operands.size()));

View File

@@ -1329,8 +1329,11 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
if dtype_z in uint_dtypes:
x = np.absolute(x)
x_tri = to_triton(x, device=device)
if 'float' in dtype_z and 'float' in dtype_x:
# make sure we use values that can be represented in both types
x_tri = x_tri.to(getattr(torch, dtype_z)).to(getattr(torch, dtype_x))
# triton kernel
@triton.jit
def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr):
x_ptr = X + tl.arange(0, SIZE)
@@ -1344,7 +1347,7 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
if dtype_z.startswith('bfloat'):
z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device)
elif dtype_z.startswith('float8'):
z_tri = torch.empty((size,), dtype=torch.float, device=device)
z_tri = torch.empty((size,), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z))
else:
z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device)
kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1, num_ctas=num_ctas)

View File

@@ -188,6 +188,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_fp8_to_f16_conversion
tt.func @test_fp8_to_f16_conversion(
@@ -197,16 +198,18 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked>
// CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
%out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked>
// CHECK-COUNT-2: mul.rn.bf16x2
%out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
%out2 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked>
%out3 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
%out3 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
%out4 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
%out4 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked>
%out5 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
%out5 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
%out6 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
tt.return
}
}