Merge branch 'triton-mlir' into ifu-231117
@@ -4,6 +4,10 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::getTotalElemsPerThread;

typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
                                         const SmallVector<Value> &)>
    ConverterT;

/* ----- FP8E5M2 ------ */
// This data-type is the standard FP8E5M2 format
#ifdef USE_ROCM
@@ -50,8 +54,93 @@ static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
}
#endif

// ROCM utility functions for data type conversion
#ifdef USE_ROCM
static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
static Value cvtFp16ToFp32(Location loc,
                           ConversionPatternRewriter &rewriter,
                           const Value &v) {
  GCNBuilder builder;
  auto &cvt = *builder.create("v_cvt_f32_f16");
  auto res = builder.newOperand("=v");
  auto operand = builder.newOperand(v, "v");
  cvt(res, operand);
  return builder.launch(rewriter, loc, f32_ty, false);
}

static Value cvtFp32ToFp16(Location loc,
                           ConversionPatternRewriter &rewriter,
                           const Value &v) {
  GCNBuilder builder;
  auto &cvt = *builder.create("v_cvt_f16_f32");
  auto res = builder.newOperand("=v");
  auto operand = builder.newOperand(v, "v");
  cvt(res, operand);
  return builder.launch(rewriter, loc, f16_ty, false);
}

static SmallVector<Value> convert_val_Fp16_to_Fp8(
    Location loc, ConversionPatternRewriter &rewriter,
    Value v0, Value v1, const std::string& fp8_format) {
  assert(fp8_format == "fp8" or fp8_format == "bf8");
  std::string ins_str = "v_cvt_pk_" + fp8_format + "_f32";

  auto f32_0 = cvtFp16ToFp32(loc, rewriter, v0);
  auto f32_1 = cvtFp16ToFp32(loc, rewriter, v1);

  GCNBuilder builder;
  auto &cvt = *builder.create(ins_str);
  auto res = builder.newOperand("=v");
  auto operand0 = builder.newOperand(f32_0, "v");
  auto operand1 = builder.newOperand(f32_1, "v");
  cvt(res, operand0, operand1);
  auto fp8x4Vec = builder.launch(rewriter, loc, i32_ty, false);

  auto fp8x4VecTy = vec_ty(i8_ty, 4);
  auto a1 = bitcast(fp8x4Vec, fp8x4VecTy);

  SmallVector<Value> ret(2);
  ret[0] = extract_element(i8_ty, a1, i32_val(0));
  ret[1] = extract_element(i8_ty, a1, i32_val(1));

  return ret;
}

static SmallVector<Value> convert_val_Fp8_to_Fp16(
    Location loc, ConversionPatternRewriter &rewriter,
    Value v0, Value v1, const std::string& fp8_format) {
  assert(fp8_format == "fp8" or fp8_format == "bf8");
  std::string ins_str = "v_cvt_pk_f32_" + fp8_format;

  auto fp8x4VecTy = vec_ty(i8_ty, 4);
  Value fp8x4Vec = undef(fp8x4VecTy);
  fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
  fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
  auto i32v = bitcast(fp8x4Vec, i32_ty);

  GCNBuilder builder1;
  auto &cvt = *builder1.create(ins_str);
  auto res = builder1.newOperand("=v");
  auto operand = builder1.newOperand(i32v, "v");
  cvt(res, operand);
  auto i64v = builder1.launch(rewriter, loc, i64_ty, false);
  auto fp32x2VecTy = vec_ty(f32_ty, 2);
  auto fp32x2Vec = bitcast(i64v, fp32x2VecTy);

  auto f32_0 = extract_element(f32_ty, fp32x2Vec, i32_val(0));
  auto f32_1 = extract_element(f32_ty, fp32x2Vec, i32_val(1));

  SmallVector<Value> ret(2);
  ret[0] = cvtFp32ToFp16(loc, rewriter, f32_0);
  ret[1] = cvtFp32ToFp16(loc, rewriter, f32_1);

  return ret;
}
#endif

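Not part of the change itself: a minimal Python sketch of the byte layout these helpers rely on. It only models how the packed `v_cvt_pk_*_f32` instructions place the two 8-bit results in bytes 0 and 1 of a 32-bit destination, and how the reverse path reads two packed 32-bit floats out of a 64-bit result; the numeric rounding is done by the hardware instruction and is not emulated here.

```python
import struct

def pack_two_bytes_as_i32(b0: int, b1: int) -> int:
    # v_cvt_pk_{fp8,bf8}_f32 writes result 0 to byte 0 and result 1 to byte 1
    # of the destination VGPR; bytes 2 and 3 are left as zero in this sketch.
    return ((b1 & 0xFF) << 8) | (b0 & 0xFF)

def unpack_two_f32_from_i64(word: int) -> tuple:
    # v_cvt_pk_f32_{fp8,bf8} produces a 64-bit result holding two IEEE f32
    # values; the code above bitcasts it to <2 x f32> and extracts both lanes.
    lo = word & 0xFFFFFFFF
    hi = (word >> 32) & 0xFFFFFFFF
    return (struct.unpack("<f", struct.pack("<I", lo))[0],
            struct.unpack("<f", struct.pack("<I", hi))[0])

# Two fp8 bytes packed exactly the way the extract_element calls read them back.
print(hex(pack_two_bytes_as_i32(0x3C, 0x40)))   # 0x403c
```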
#ifdef USE_ROCM
// Depending on whether we focus more on performance, we may skip
// the processing of subnormal values.
static Value Fp16_to_Fp8E5M2FNUZ_oneValue(
    Location loc, ConversionPatternRewriter &rewriter, Value v) {
  auto vi16 = bitcast(v, i16_ty);
  auto e = and_(i16_ty, vi16, int_val(16, 0x7C00));
@@ -79,28 +168,26 @@ static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
}

static SmallVector<Value>
Fp16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
Fp16_to_Fp8E5M2FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
                    const SmallVector<Value> &v) {
  SmallVector<Value> result(4);
  result[0] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[0]);
  result[1] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[1]);
  result[2] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[2]);
  result[3] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[3]);

  SmallVector<Value> result(2);
  result[0] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[0]);
  result[1] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[1]);
  return result;
}
#else
const std::string Fp16_to_Fp8E5M2FNUZ =
    "{ \n"
    ".reg .b32 a<2>; \n"
    "and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
    "and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
    "add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
    "add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
    "prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
    "}";

static SmallVector<Value> Fp16_to_Fp8E5M2FNUZ_HW(
    Location loc, ConversionPatternRewriter &rewriter,
    const SmallVector<Value>& v) {
  return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "bf8");
}

ConverterT Fp16_to_Fp8E5M2FNUZ(int computeCapability) {
  return computeCapability >= 300 ? Fp16_to_Fp8E5M2FNUZ_HW : Fp16_to_Fp8E5M2FNUZ_SW;
}
#endif


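For reference, a hedged Python sketch of what the software (`_SW`) path above computes for one value: quantizing an FP16-representable number to FP8 E5M2FNUZ bits (1 sign, 5 exponent bits with bias 16, 2 mantissa bits, no infinities, 0x80 reserved for NaN). It deliberately skips NaN, overflow, and subnormal outputs, mirroring the fast path noted in the comment above; the rounding helper is an approximation, not the exact bit manipulation used in the kernel.

```python
import math

def fp16_to_fp8_e5m2fnuz(x: float) -> int:
    """Sketch: encode x as FP8 E5M2FNUZ bits (normal range only)."""
    if x == 0.0:
        return 0x00                    # FNUZ: only +0 exists; 0x80 encodes NaN
    sign = 0x80 if x < 0 else 0x00
    m, e = math.frexp(abs(x))          # abs(x) = m * 2**e with 0.5 <= m < 1
    exp = e - 1 + 16                   # biased exponent, bias = 16
    if exp <= 0 or exp > 31:
        raise ValueError("subnormal/overflow handling omitted in this sketch")
    mant = round((m * 2 - 1) * 4)      # keep 2 mantissa bits, round to nearest
    if mant == 4:                      # rounding carried into the exponent
        mant, exp = 0, exp + 1
        if exp > 31:
            raise ValueError("overflow handling omitted in this sketch")
    return sign | (exp << 2) | mant

assert fp16_to_fp8_e5m2fnuz(1.0) == 0x40     # exponent 16, mantissa 0
assert fp16_to_fp8_e5m2fnuz(-1.75) == 0xC3   # sign | exponent 16 | mantissa 3
```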
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -145,8 +232,7 @@ static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
#endif

#ifdef USE_ROCM

static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
static Value Fp8E5M2FNUZ_to_Fp16_oneValue(
    Location loc, ConversionPatternRewriter &rewriter, Value v) {
  auto fp8x2VecTy = vec_ty(i8_ty, 2);
  Value a = undef(fp8x2VecTy);
@@ -181,17 +267,23 @@ static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
}

static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
Fp8E5M2FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
                    const SmallVector<Value> &v) {

  SmallVector<Value> result(4);
  result[0] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[0]);
  result[1] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[1]);
  result[2] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[2]);
  result[3] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[3]);

  SmallVector<Value> result(2);
  result[0] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
  result[1] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
  return result;
}

static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
                       const SmallVector<Value> &v) {
  return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "bf8");
}

ConverterT Fp8E5M2FNUZ_to_Fp16(int computeCapability) {
  return (computeCapability >= 300) ? Fp8E5M2FNUZ_to_Fp16_HW : Fp8E5M2FNUZ_to_Fp16_SW;
}
#endif

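The decode direction is the mirror image of the encode sketch earlier. A hedged, self-contained Python illustration of what `Fp8E5M2FNUZ_to_Fp16_oneValue` produces numerically (the real code does this with integer ops on a widened fp16 bit pattern):

```python
def fp8_e5m2fnuz_to_float(b: int) -> float:
    """Sketch: decode FP8 E5M2FNUZ bits. 0x80 is the single NaN encoding;
    the FNUZ formats have no infinities and no negative zero."""
    if b == 0x80:
        return float("nan")
    sign = -1.0 if b & 0x80 else 1.0
    exp = (b >> 2) & 0x1F
    mant = b & 0x3
    if exp == 0:                                   # subnormal
        return sign * (mant / 4.0) * 2.0 ** -15
    return sign * (1.0 + mant / 4.0) * 2.0 ** (exp - 16)

assert fp8_e5m2fnuz_to_float(0x40) == 1.0
assert fp8_e5m2fnuz_to_float(0x7F) == 57344.0      # largest finite value
```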
#ifdef USE_ROCM
@@ -662,7 +754,7 @@ static const std::string Fp16_to_Fp8E4M3B15x4 =
// more than a single NaN value.

#ifdef USE_ROCM
static Value convert_val_Fp8E4M3FNUZ_to_Fp16(
static Value Fp8E4M3FNUZ_to_Fp16_oneValue(
    Location loc, ConversionPatternRewriter &rewriter, Value v) {
  auto fp8x2VecTy = vec_ty(i8_ty, 2);
  Value a = undef(fp8x2VecTy);
@@ -693,37 +785,30 @@ static Value convert_val_Fp8E4M3FNUZ_to_Fp16(
  return bitcast(io, f16_ty);
}

// Fp8E4M3FNUZ -> Fp16 (packed)
static SmallVector<Value>
Fp8E4M3FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
Fp8E4M3FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
                    const SmallVector<Value> &v) {
  SmallVector<Value> result(2);
  result[0] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[0]);
  result[1] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[1]);

  result[0] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
  result[1] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
  return result;
}
#else
const std::string Fp8E4M3FNUZ_to_Fp16 =
    "{ \n"
    ".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
    "prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400
    "prmt.b32 a1, 0, $2, 0x7060; \n" // a1 = 0xf100f200
    "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
    "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
    "shr.b32 b0, b0, 1; \n" // b0 >>= 1
    "shr.b32 b1, b1, 1; \n" // shift into fp16 position
    "add.u32 b0, b0, 0x20002000; \n" // b0.exp += 2**4-2**3
                                     // exponent compensate = 8
    "add.u32 b1, b1, 0x20002000; \n" // b1 += 8<<10 | 8<<10<<16
    "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
    "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
    "}";

static SmallVector<Value>
Fp8E4M3FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
                       const SmallVector<Value> &v) {
  return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "fp8");
}

static ConverterT
Fp8E4M3FNUZ_to_Fp16(int computeCapability) {
  return computeCapability >= 300 ? Fp8E4M3FNUZ_to_Fp16_HW : Fp8E4M3FNUZ_to_Fp16_SW;
}
#endif

// Fp16 -> Fp8E4M3 (packed)
#ifdef USE_ROCM
static Value convert_val_Fp16_to_Fp8E4M3FNUZ(
static Value Fp16_to_Fp8E4M3FNUZ_oneValue(
    Location loc, ConversionPatternRewriter &rewriter, Value v) {
  auto vi16 = bitcast(v, i16_ty);
  auto e10 = and_(vi16, int_val(16, 0x7C00));
@@ -756,33 +841,25 @@ static Value convert_val_Fp16_to_Fp8E4M3FNUZ(
}

static SmallVector<Value>
Fp16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
Fp16_to_Fp8E4M3FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
                    const SmallVector<Value> &v) {

  SmallVector<Value> result(2);
  result[0] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[0]);
  result[1] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[1]);
  result[0] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[0]);
  result[1] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[1]);

  return result;
}
#else
const std::string Fp16_to_Fp8E4M3FNUZ =
    "{ \n"
    ".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4
    "sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000
                                     // (compensate offset)
    "sub.u32 a1, $2, 0x20002000; \n" // a1 = input1 - 0x20002000
                                     // (8 << 10 | 8 << 10 << 16)
    "shl.b32 a0, a0, 1; \n" // a0 <<= 1
    "shl.b32 a1, a1, 1; \n" // shift into fp8e4 position
    "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" // a0 &= 0x7fff7fff
    "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
    "add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
    "add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
    "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" // b0 = a0|(0x80008000&in0)
    "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
    "prmt.b32 $0, b0, b1, 0x7531; \n" // output = b1b0
    "}";

static SmallVector<Value>
Fp16_to_Fp8E4M3FNUZ_HW(Location loc, ConversionPatternRewriter &rewriter,
                       const SmallVector<Value> &v) {
  return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "fp8");
}

static ConverterT
Fp16_to_Fp8E4M3FNUZ(int computeCapability) {
  return computeCapability >= 300 ? Fp16_to_Fp8E4M3FNUZ_HW : Fp16_to_Fp8E4M3FNUZ_SW;
}
#endif

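A short note on the two FNUZ formats handled above, since the converters for "fp8" (E4M3FNUZ, Triton's tl.float8e4b8) and "bf8" (E5M2FNUZ, tl.float8e5b16) differ only in how the 7 non-sign bits are split. The sketch below, not taken from the diff, just computes their dynamic ranges from the bit layout:

```python
# FNUZ formats have no infinities, so the top exponent code is a normal value,
# and the only NaN is 0x80 (sign bit set, everything else zero).
def fnuz_limits(exp_bits: int, man_bits: int, bias: int):
    max_exp = (1 << exp_bits) - 1
    max_mant = 1.0 + ((1 << man_bits) - 1) / (1 << man_bits)
    largest = max_mant * 2.0 ** (max_exp - bias)
    smallest_normal = 2.0 ** (1 - bias)
    return largest, smallest_normal

print(fnuz_limits(4, 3, 8))    # E4M3FNUZ ("fp8"):  (240.0, 0.0078125)
print(fnuz_limits(5, 2, 16))   # E5M2FNUZ ("bf8"):  (57344.0, 3.0517578125e-05)
```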
// WARN: subnormal values (0bs0000xxx) are not handled
@@ -1151,10 +1228,6 @@ inline SmallVector<Value> packI32(const SmallVector<Value> &inValues,
  return outValues;
}

typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
                                         const SmallVector<Value> &)>
    ConverterT;

static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
                                       Type outType,
                                       const int inVecWidthBits = 32,
@@ -1473,12 +1546,7 @@ struct FpToFpOpConversion
                                 ConversionPatternRewriter &rewriter,
                                 const Value &v) {
#ifdef USE_ROCM
    GCNBuilder builder;
    auto &cvt = *builder.create("v_cvt_f32_f16");
    auto res = builder.newOperand("=v");
    auto operand = builder.newOperand(v, "v");
    cvt(res, operand);
    return builder.launch(rewriter, loc, f32_ty, false);
    return cvtFp16ToFp32(loc, rewriter, v);
#else
    PTXBuilder builder;
    auto &cvt = *builder.create("cvt.f32.f16");
@@ -1524,12 +1592,7 @@ struct FpToFpOpConversion
                                 ConversionPatternRewriter &rewriter,
                                 const Value &v) {
#ifdef USE_ROCM
    GCNBuilder builder;
    auto &cvt = *builder.create("v_cvt_f16_f32");
    auto res = builder.newOperand("=v");
    auto operand = builder.newOperand(v, "v");
    cvt(res, operand);
    return builder.launch(rewriter, loc, f16_ty, false);
    return cvtFp32ToFp16(loc, rewriter, v);
#else
    PTXBuilder builder;
    auto &cvt = *builder.create("cvt.rz.f16.f32");
@@ -1542,7 +1605,11 @@ struct FpToFpOpConversion

  ConverterT getConversionFunc(Type srcTy, Type dstTy) const {
    auto F8E4M3B15TyID = TypeID::get<mlir::Float8E4M3B11FNUZType>();
#ifdef USE_ROCM
    auto F8E4M3FNUZTyID = TypeID::get<mlir::Float8E4M3FNUZType>();
#else
    auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNUZType>();
#endif
    auto F8E4M3FNTyID = TypeID::get<mlir::Float8E4M3FNType>();
    auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
    auto F8E5M2FNUZTyID = TypeID::get<mlir::Float8E5M2FNUZType>();
@@ -1558,37 +1625,37 @@ struct FpToFpOpConversion
        // F8 -> F16
        {{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
        {{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
        {{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
#ifdef USE_ROCM
        {{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16(computeCapability)},
        {{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16(computeCapability)},
        {{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
        {{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
#else
        {{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
        {{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
#endif

        // F16 -> F8
        {{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
#ifdef USE_ROCM
        {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
        {{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ(computeCapability)},
        {{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ(computeCapability)},
        {{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
#else
        {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
#endif
        {{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
        {{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
#ifdef USE_ROCM
        {{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
        {{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
#else
        {{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
        {{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
#endif
        // F8 -> BF16

        // F8 -> BF16
#ifdef USE_ROCM
        {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
        {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
#else
        {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
        {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
        {{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
#endif
        // BF16 -> F8

        // BF16 -> F8
#ifdef USE_ROCM
        {{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
#else
@@ -1599,6 +1666,16 @@ struct FpToFpOpConversion
        {{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
#endif
    };

    std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
    if (srcMap.count(key) == 0) {
      llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy
                   << "\n";
      llvm_unreachable("");
    }
#ifdef USE_ROCM
    return srcMap.lookup(key);
#else
    int inVecWidthBits = 32;
    int outVecWidthBits = 32;
    if (srcTy.isFloat8E4M3FNUZ() ||
@@ -1612,15 +1689,6 @@ struct FpToFpOpConversion
      outVecWidthBits = 16;
    }

    std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
    if (srcMap.count(key) == 0) {
      llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy
                   << "\n";
      llvm_unreachable("");
    }
#ifdef USE_ROCM
    return srcMap.lookup(key);
#else
    if (computeCapability < 90 &&
        (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
      llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
@@ -1645,16 +1713,26 @@ struct FpToFpOpConversion
    size_t numElements = 4;
    if (srcElementType.isFloat8E4M3FNUZ() ||
        dstElementType.isFloat8E4M3FNUZ() ||
#ifdef USE_ROCM
        srcElementType.isFloat8E5M2FNUZ() ||
        dstElementType.isFloat8E5M2FNUZ())
#else
        (computeCapability >= 90 &&
         ((srcElementType.isFloat8E5M2() &&
           (dstElementType.isF16() || dstElementType.isF32())) ||
          dstElementType.isFloat8E5M2()))) {
#endif
    {
      numElements = 2;
    }
    bool useFP16IntermediateSrc =
#ifdef USE_ROCM
        srcElementType.isF32();
#else
        srcElementType.isF32() &&
        !(computeCapability >= 90 &&
          (dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
#endif
    bool isDstFP32 = dstElementType.isF32();
    Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
    Type dstType = isDstFP32 ? f16_ty : dstElementType;

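On the ROCm path, useFP16IntermediateSrc is true for every F32 source, so an F32 to FP8 conversion is really F32 to F16 to FP8 and picks up one extra rounding step. A tiny NumPy illustration (not from the diff) of that intermediate rounding:

```python
import numpy as np

x = np.float32(0.1)                     # f32 value with bits below f16 precision
via_f16 = np.float32(np.float16(x))     # what the ROCm path feeds the FP8 converter
print(x, via_f16, x - via_f16)          # non-zero difference from the f16 step
```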
@@ -1075,15 +1075,16 @@ if TORCH_HAS_FP8E5B16:
if TORCH_HAS_FP8E4B8:
    tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz

@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    input = tl.load(input_ptr + offsets, mask=mask)
    output = input
    tl.store(output_ptr + offsets, output, mask=mask)

def gen_input(M, N, d_type, seed, device='cuda'):
    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        input = tl.load(input_ptr + offsets, mask=mask)
        output = input
        tl.store(output_ptr + offsets, output, mask=mask)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    if d_type == tl.float16:
@@ -1246,7 +1247,8 @@ def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'c
def test_gemm_amd_fp8_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'):
    check_type_supported(out_dtype, device)

    if triton.language.semantic.gpu_matrix_core_version() != 3:
    backend = triton.common.backend.get_backend("hip")
    if backend.get_matrix_core_version() != 3:
        pytest.skip("fp8 data type is not available on hardware")

    @triton.jit
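The pattern the updated tests use to query the hardware is the same everywhere in this diff: instead of triton.language.semantic.gpu_matrix_core_version(), they go through the HIP backend object. A hedged sketch of that query, assuming a ROCm build of Triton with the HIP backend registered:

```python
import triton

# Sketch only: ask the HIP backend for the MFMA generation of the current GPU.
backend = triton.common.backend.get_backend("hip")
mcv = backend.get_matrix_core_version()
if mcv != 3:
    print("fp8 MFMA paths need CDNA 3 (gfx940/941/942); got version", mcv)
```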
@@ -1630,7 +1632,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
      ('float16', 'float16'),
      ('float16', 'float32'),
      ('float32', 'float32')]]
    if triton.language.semantic.gpu_matrix_core_version() == 0 else
    if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0 else
    # MFMA Test Dot tests
    [(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim)
     for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
@@ -1881,7 +1883,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
    # added atol to loosen precision for the float16 x float16 -> float32 case
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
    if torch.version.hip is not None:
        if triton.language.semantic.gpu_matrix_core_version() > 0:
        backend = triton.common.backend.get_backend("hip")
        if backend.get_matrix_core_version() > 0:
            ttgir = pgm.asm['ttgir']
            if non_k_dim == 16:
                assert "#triton_gpu.mfma<{nonKDim = 16" in ttgir
@@ -1890,9 +1893,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
                assert "#triton_gpu.mfma<{nonKDim = 32" in ttgir
                assert "#triton_gpu.mfma<{nonKDim = 16" not in ttgir
            gcn = pgm.asm['amdgcn']
            if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
            if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
                assert "v_mfma_f32_32x32x16_bf8_bf8" in gcn or "v_mfma_f32_16x16x32_bf8_bf8" in gcn
            if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8:
            if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8:
                assert "v_mfma_f32_32x32x16_fp8_fp8" in gcn or "v_mfma_f32_16x16x32_fp8_fp8" in gcn
            return
    # make sure ld/st are vectorized
@@ -2727,7 +2730,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
    if transposeA and not transposeB:
        pytest.skip()

    if triton.language.semantic.gpu_matrix_core_version() == 0:
    if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0:
        pytest.skip("mfma is not available on hardware")

    # source code for following ttgir:
@@ -2817,7 +2820,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
    kernel = triton.compile(f.name, device_type="hip", cc=capabilities)

    import triton.language.semantic as sem
    if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0:
    # if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0:
    if torch.version.hip is not None and backend.get_matrix_core_version() > 0:
        kernel[(1, 1, 1)](x_tri, y_tri, z_tri)
        np.testing.assert_allclose(z_np, to_numpy(z_tri), rtol=0.01, atol=1e-3)


@@ -65,8 +65,7 @@ def ttir_compute_capability_rewrite(mod, target):
    if _is_cuda(target):
        pm.add_rewrite_tensor_pointer_pass(target.capability, False)
    elif is_hip():
        capability = 90
        pm.add_rewrite_tensor_pointer_pass(capability, True)
        pm.add_rewrite_tensor_pointer_pass(target["capability"], True)
    else:
        assert(False, "unsupported target")
    pm.run(mod)
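The compiler driver now reads "capability" (and, below, "matrix_core_version") out of the HIP target dict produced by get_amdgpu_arch_fulldetails() instead of hard-coding capability = 90. A hedged sketch of the dict shape being threaded through these passes; the concrete values are illustrative for a gfx90a device, not captured from a real run:

```python
# Illustrative HIP target dict, matching the keys added later in this diff.
target = {
    "gfx_triple": "amdgcn-amd-amdhsa",
    "gfx_arch": "gfx90a",
    "gfx_features": "",
    "matrix_core_version": 2,
    "capability": 200,          # matrix_core_version * 100
}

def ttir_compute_capability_rewrite_sketch(pm, target):
    # Mirrors the updated call above: no more hard-coded capability = 90.
    pm.add_rewrite_tensor_pointer_pass(target["capability"], True)
```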
@@ -118,14 +117,14 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
        pm.add_tritongpu_accelerate_matmul_pass(capability)
    # TODO change interface of accelerate_matmul_pass
    if is_hip():
        matrix_core_version = gpu_matrix_core_version()
        matrix_core_version = target["matrix_core_version"]
        matrix_inst_size = matrix_inst_type
        pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size)
    pm.add_tritongpu_remove_layout_conversions_pass()
    if optimize_epilogue:
        pm.add_tritongpu_optimize_epilogue_pass()
    pm.add_tritongpu_optimize_dot_operands_pass()
    if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
    if num_stages == 0 and is_hip() and target["matrix_core_version"] != 0:
        pm.add_tritongpu_stream_pipeline_pass()
        pm.add_canonicalizer_pass()
    ws_enabled = False
@@ -191,7 +190,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0):
    if _is_cuda(target):
        return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
    else:
        return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
        return translate_triton_gpu_to_llvmir(mod, target["capability"], TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)


# PTX translation
@@ -360,8 +359,6 @@ def is_hip():
        raise ImportError("Triton requires PyTorch to be installed")
    return torch.version.hip is not None

from ..language.semantic import gpu_matrix_core_version

def get_cuda_capability(capability):
    if capability is None:
        device = get_current_device()

@@ -1188,32 +1188,6 @@ def is_hip():
        raise ImportError("Triton requires PyTorch to be installed")
    return torch.version.hip is not None


def gpu_matrix_core_version() -> int:
    """ Determine matrix core type available on current GPU.

        0 means no tensor cores are available
        1 corresponds to MFMA in CDNA 1 architecture
        2 corresponds to MFMA in CDNA 2 architecture
        3 corresponds to MFMA in CDNA 3 architecture
    """

    if not is_hip():
        return 0
    arch_info = _triton.get_arch_info()
    gfx_arch_details = re.search('amd.*', arch_info)
    if gfx_arch_details is None:
        return 0
    gfx_arch_details = gfx_arch_details.group(0).strip().split('--')
    gpu_name = gfx_arch_details[1].split(':')[0]
    if gpu_name in ['gfx908']:
        return 1
    if gpu_name in ['gfx90a']:
        return 2
    if gpu_name in ['gfx940', 'gfx941', 'gfx942']:
        return 3
    return 0

def mfma_supported_granularity(m, n, k) -> bool:
    # todo make this gran_type matrix element type sensitive
    for gran_type in [(32, 8), (16, 16)]:
@@ -1226,8 +1200,8 @@ def mfma_supported_granularity(m, n, k) -> bool:
            return True
    return False

def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
    matrix_core_version = gpu_matrix_core_version()
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty, target) -> bool:
    matrix_core_version = target["matrix_core_version"]
    if matrix_core_version not in [1, 2, 3]:
        return False
    if not mfma_supported_granularity(M, N ,K):
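With this signature change, mfma_supported no longer probes the GPU itself; callers pass builder.target and the function reads target["matrix_core_version"]. Below is a self-contained Python approximation of the gate (not the real implementation: the dtype and tf32 checks are omitted, and the divisibility rule for the (non-K, K) granularities shown above is assumed, since the body of mfma_supported_granularity is not part of this hunk):

```python
def mfma_supported_sketch(M: int, N: int, K: int, target: dict) -> bool:
    # A CDNA matrix core must be present...
    if target.get("matrix_core_version") not in (1, 2, 3):
        return False
    # ...and the tile must fit one of the supported granularities (assumed rule).
    for non_k, k_gran in ((32, 8), (16, 16)):
        if M % non_k == 0 and N % non_k == 0 and K % k_gran == 0:
            return True
    return False

print(mfma_supported_sketch(64, 64, 64, {"matrix_core_version": 2}))  # True
print(mfma_supported_sketch(64, 64, 64, {"matrix_core_version": 0}))  # False
```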
@@ -1240,10 +1214,18 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu

    def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
        # Checks for non-cuda archs
        if not _is_cuda(target):
        if is_hip():
            assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \
                (lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()) or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8()), \
                f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
            if lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8():
                assert lhs.type.scalar.is_fp8e4b8() or lhs.type.scalar.is_fp8e5b16() or lhs.type.scalar.is_fp8e5(),\
                    f"Only hip fp8 or f8e5 types are accepted for both inputs of fp8"
                assert rhs.type.scalar.is_fp8e4b8() or rhs.type.scalar.is_fp8e5b16() or rhs.type.scalar.is_fp8e5(),\
                    f"Only hip fp8 or f8e5 types are accepted for both inputs of fp8"
            return

        if not _is_cuda(target):
            return

        # Checks for cuda archs
@@ -1287,13 +1269,18 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu

    # hip for now converts fp8 to fp16 for mixed input
    if is_hip():
        fp8_supported = gpu_matrix_core_version() == 3
        target = builder.target
        assert "matrix_core_version" in target
        fp8_supported = target["matrix_core_version"] == 3
        # gfx940 data type
        lhs_hip_fp8 = lhs.type.scalar.is_fp8e4b8() or lhs.type.scalar.is_fp8e5b16()
        rhs_hip_fp8 = rhs.type.scalar.is_fp8e4b8() or rhs.type.scalar.is_fp8e5b16()
        lhs_fp8 = lhs.type.scalar.is_fp8()
        rhs_fp8 = rhs.type.scalar.is_fp8()
        supported_fp8_dot = fp8_supported and lhs_fp8 and rhs_fp8
        if not supported_fp8_dot and lhs_fp8:
        supported_fp8_dot = fp8_supported and lhs_hip_fp8 and rhs_hip_fp8
        if (not supported_fp8_dot) and lhs_fp8:
            lhs = cast(lhs, tl.float16, builder)
        if not supported_fp8_dot and rhs_fp8:
        if (not supported_fp8_dot) and rhs_fp8:
            rhs = cast(rhs, tl.float16, builder)

    if lhs.type.scalar.is_int():
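In user terms: a tl.dot on the AMD fp8 types only stays in fp8 when target["matrix_core_version"] is 3; otherwise both operands are upcast to float16 before lowering. A small self-contained sketch of that decision (plain Python, names are illustrative, not the library code):

```python
def plan_fp8_dot(lhs_is_hip_fp8: bool, rhs_is_hip_fp8: bool, target: dict) -> str:
    # Mirrors the gate above: fp8 x fp8 dot stays in fp8 only on CDNA 3,
    # otherwise the operands are cast to fp16 before the MFMA/FMA lowering.
    fp8_supported = target.get("matrix_core_version") == 3
    if fp8_supported and lhs_is_hip_fp8 and rhs_is_hip_fp8:
        return "dot in fp8 (lowers to v_mfma_*_fp8_fp8 / v_mfma_*_bf8_bf8)"
    return "cast fp8 operands to fp16, then dot"

print(plan_fp8_dot(True, True, {"matrix_core_version": 3}))
print(plan_fp8_dot(True, True, {"matrix_core_version": 2}))
```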
@@ -1316,7 +1303,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
    N = rhs.type.shape[1]

    # Cast operands of types f16 and i8 for configurations where FMA only supported.
    if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty):
    if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty, builder.target):
        # max_num_imprecise_acc does not yet apply to hip
        if is_hip():
            max_num_imprecise_acc = 0
@@ -1334,7 +1321,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
                    ret_ty)
        return cast(ret, ret_scalar_ty, builder)
    if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32,
                                   ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
                                   ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32:
        # max_num_imprecise_acc does not yet apply to hip
        if is_hip():
            max_num_imprecise_acc = 0

python/triton/third_party/hip/hip_backend.py (vendored, 33 lines changed)
@@ -273,6 +273,30 @@ def get_amdgcn_bitcode_paths(gfx_arch: str):
    return amdgcn_bitcode_paths


def gpu_matrix_core_version() -> int:
    """ Determine matrix core type available on current GPU.

        0 means no tensor cores are available
        1 corresponds to MFMA in CDNA 1 architecture
        2 corresponds to MFMA in CDNA 2 architecture
        3 corresponds to MFMA in CDNA 3 architecture
    """

    arch_info = _triton.get_arch_info()
    gfx_arch_details = re.search('amd.*', arch_info)
    if gfx_arch_details is None:
        return 0
    gfx_arch_details = gfx_arch_details.group(0).strip().split('--')
    gpu_name = gfx_arch_details[1].split(':')[0]
    if gpu_name in ['gfx908']:
        return 1
    if gpu_name in ['gfx90a']:
        return 2
    if gpu_name in ['gfx940', 'gfx941', 'gfx942']:
        return 3
    return 0


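The parsing above expects an LLVM-style offload target string. The example string below is assumed for illustration only, to show how the gfx name is extracted and mapped to a matrix core version:

```python
import re

# Assumed example of what _triton.get_arch_info() returns on a gfx90a machine.
arch_info = "amdgcn-amd-amdhsa--gfx90a:sramecc+:xnack-"

gfx_arch_details = re.search('amd.*', arch_info).group(0).strip().split('--')
gpu_name = gfx_arch_details[1].split(':')[0]
print(gpu_name)   # "gfx90a", which the table above maps to matrix core version 2
```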
def get_amdgpu_arch_fulldetails():
    # print("get_amdgpu_arch_fulldetails")
    """
@@ -294,7 +318,11 @@ def get_amdgpu_arch_fulldetails():
        if gfx_arch is None:
            raise RuntimeError('gfx_arch is None (not specified)')

        return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features}
        mat_core_ver = gpu_matrix_core_version()
        capability = gpu_matrix_core_version() * 100

        return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features,\
                "capability": capability, "matrix_core_version": mat_core_ver}
    except BaseException:
        return None

@@ -487,3 +515,6 @@ class HIPBackend(BaseBackend):
            return _triton.get_num_warps(module)
        else:
            return _triton.get_num_warps(module)

    def get_matrix_core_version(self):
        return gpu_matrix_core_version()
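This is also where the "capability" value used by the C++ converters comes from: the HIP target's capability is the matrix core version times 100, so the `computeCapability >= 300` checks in the conversion code select the hardware fp8 path exactly on CDNA 3 (gfx940/941/942). A short illustration of that mapping:

```python
# gfx name -> matrix core version (from gpu_matrix_core_version above),
# capability = version * 100, hardware fp8 path requires capability >= 300.
for gfx, version in [("gfx908", 1), ("gfx90a", 2), ("gfx942", 3)]:
    capability = version * 100
    print(gfx, "version", version, "capability", capability,
          "hardware fp8:", capability >= 300)
```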