Merge branch 'triton-mlir' into ifu-231117

This commit is contained in:
jayfurmanek
2023-11-27 07:44:04 -06:00
committed by GitHub
5 changed files with 265 additions and 168 deletions

View File

@@ -4,6 +4,10 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::getTotalElemsPerThread;
typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const SmallVector<Value> &)>
ConverterT;
/* ----- FP8E5M2 ------ */
// This data-type is the standard FP8E5M2 format
#ifdef USE_ROCM
@@ -50,8 +54,93 @@ static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
}
#endif
// ROCM utility functions for data type conversion
#ifdef USE_ROCM
static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
static Value cvtFp16ToFp32(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
GCNBuilder builder;
auto &cvt = *builder.create("v_cvt_f32_f16");
auto res = builder.newOperand("=v");
auto operand = builder.newOperand(v, "v");
cvt(res, operand);
return builder.launch(rewriter, loc, f32_ty, false);
}
static Value cvtFp32ToFp16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
GCNBuilder builder;
auto &cvt = *builder.create("v_cvt_f16_f32");
auto res = builder.newOperand("=v");
auto operand = builder.newOperand(v, "v");
cvt(res, operand);
return builder.launch(rewriter, loc, f16_ty, false);
}
static SmallVector<Value> convert_val_Fp16_to_Fp8(
Location loc, ConversionPatternRewriter &rewriter,
Value v0, Value v1, const std::string& fp8_format) {
assert(fp8_format == "fp8" or fp8_format == "bf8");
std::string ins_str = "v_cvt_pk_" + fp8_format + "_f32";
auto f32_0 = cvtFp16ToFp32(loc, rewriter, v0);
auto f32_1 = cvtFp16ToFp32(loc, rewriter, v1);
GCNBuilder builder;
auto &cvt = *builder.create(ins_str);
auto res = builder.newOperand("=v");
auto operand0 = builder.newOperand(f32_0, "v");
auto operand1 = builder.newOperand(f32_1, "v");
cvt(res, operand0, operand1);
auto fp8x4Vec = builder.launch(rewriter, loc, i32_ty, false);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto a1 = bitcast(fp8x4Vec, fp8x4VecTy);
SmallVector<Value> ret(2);
ret[0] = extract_element(i8_ty, a1, i32_val(0));
ret[1] = extract_element(i8_ty, a1, i32_val(1));
return ret;
}
static SmallVector<Value> convert_val_Fp8_to_Fp16(
Location loc, ConversionPatternRewriter &rewriter,
Value v0, Value v1, const std::string& fp8_format) {
assert(fp8_format == "fp8" or fp8_format == "bf8");
std::string ins_str = "v_cvt_pk_f32_" + fp8_format;
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
auto i32v = bitcast(fp8x4Vec, i32_ty);
GCNBuilder builder1;
auto &cvt = *builder1.create(ins_str);
auto res = builder1.newOperand("=v");
auto operand = builder1.newOperand(i32v, "v");
cvt(res, operand);
auto i64v = builder1.launch(rewriter, loc, i64_ty, false);
auto fp32x2VecTy = vec_ty(f32_ty, 2);
auto fp32x2Vec = bitcast(i64v, fp32x2VecTy);
auto f32_0 = extract_element(f32_ty, fp32x2Vec, i32_val(0));
auto f32_1 = extract_element(f32_ty, fp32x2Vec, i32_val(1));
SmallVector<Value> ret(2);
ret[0] = cvtFp32ToFp16(loc, rewriter, f32_0);
ret[1] = cvtFp32ToFp16(loc, rewriter, f32_1);
return ret;
}
#endif
#ifdef USE_ROCM
// Depend on whether we focus more on performance, we may skip
// the processing of submornal values
static Value Fp16_to_Fp8E5M2FNUZ_oneValue(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto vi16 = bitcast(v, i16_ty);
auto e = and_(i16_ty, vi16, int_val(16, 0x7C00));
@@ -79,28 +168,26 @@ static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
}
static SmallVector<Value>
Fp16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
Fp16_to_Fp8E5M2FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(4);
result[0] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[0]);
result[1] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[1]);
result[2] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[2]);
result[3] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[3]);
SmallVector<Value> result(2);
result[0] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[0]);
result[1] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[1]);
return result;
}
#else
const std::string Fp16_to_Fp8E5M2FNUZ =
"{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
static SmallVector<Value> Fp16_to_Fp8E5M2FNUZ_HW(
Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value>& v) {
return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "bf8");
}
ConverterT Fp16_to_Fp8E5M2FNUZ(int computeCapability) {
return computeCapability >= 300 ? Fp16_to_Fp8E5M2FNUZ_HW : Fp16_to_Fp8E5M2FNUZ_SW;
}
#endif
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -145,8 +232,7 @@ static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
#endif
#ifdef USE_ROCM
static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
static Value Fp8E5M2FNUZ_to_Fp16_oneValue(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto fp8x2VecTy = vec_ty(i8_ty, 2);
Value a = undef(fp8x2VecTy);
@@ -181,17 +267,23 @@ static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
}
static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
Fp8E5M2FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(4);
result[0] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[0]);
result[1] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[1]);
result[2] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[2]);
result[3] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[3]);
SmallVector<Value> result(2);
result[0] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
result[1] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
return result;
}
static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "bf8");
}
ConverterT Fp8E5M2FNUZ_to_Fp16(int computeCapability) {
return (computeCapability >= 300) ? Fp8E5M2FNUZ_to_Fp16_HW : Fp8E5M2FNUZ_to_Fp16_SW;
}
#endif
#ifdef USE_ROCM
@@ -662,7 +754,7 @@ static const std::string Fp16_to_Fp8E4M3B15x4 =
// more than a single NaN values.
#ifdef USE_ROCM
static Value convert_val_Fp8E4M3FNUZ_to_Fp16(
static Value Fp8E4M3FNUZ_to_Fp16_oneValue(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto fp8x2VecTy = vec_ty(i8_ty, 2);
Value a = undef(fp8x2VecTy);
@@ -693,37 +785,30 @@ static Value convert_val_Fp8E4M3FNUZ_to_Fp16(
return bitcast(io, f16_ty);
}
// Fp8E4M3FNUZ -> Fp16 (packed)
static SmallVector<Value>
Fp8E4M3FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
Fp8E4M3FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(2);
result[0] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[0]);
result[1] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[1]);
result[0] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
result[1] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
return result;
}
#else
const std::string Fp8E4M3FNUZ_to_Fp16 =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7060; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 1; \n" // b0 >>= 1
"shr.b32 b1, b1, 1; \n" // shift into fp16 position
"add.u32 b0, b0, 0x20002000; \n" // b0.exp += 2**4-2**3
// exponent compensate = 8
"add.u32 b1, b1, 0x20002000; \n" // b1 += 8<<10 | 8<<10<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
static SmallVector<Value>
Fp8E4M3FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "fp8");
}
static ConverterT
Fp8E4M3FNUZ_to_Fp16(int computeCapability) {
return computeCapability >= 300 ? Fp8E4M3FNUZ_to_Fp16_HW : Fp8E4M3FNUZ_to_Fp16_SW;
}
#endif
// Fp16 -> Fp8E4M3 (packed)
#ifdef USE_ROCM
static Value convert_val_Fp16_to_Fp8E4M3FNUZ(
static Value Fp16_to_Fp8E4M3FNUZ_oneValue(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto vi16 = bitcast(v, i16_ty);
auto e10 = and_(vi16, int_val(16, 0x7C00));
@@ -756,33 +841,25 @@ static Value convert_val_Fp16_to_Fp8E4M3FNUZ(
}
static SmallVector<Value>
Fp16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
Fp16_to_Fp8E4M3FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(2);
result[0] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[0]);
result[1] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[1]);
result[0] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[0]);
result[1] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[1]);
return result;
}
#else
const std::string Fp16_to_Fp8E4M3FNUZ =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4
"sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000
// (compensate offset)
"sub.u32 a1, $2, 0x20002000; \n" // a1 = input1 - 0x20002000
// (8 << 10 | 8 << 10 << 16)
"shl.b32 a0, a0, 1; \n" // a0 <<= 1
"shl.b32 a1, a1, 1; \n" // shift into fp8e4 position
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" // a0 &= 0x7fff7fff
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" // b0 = a0|(0x80008000&in0)
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
"prmt.b32 $0, b0, b1, 0x7531; \n" // output = b1b0
"}";
static SmallVector<Value>
Fp16_to_Fp8E4M3FNUZ_HW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "fp8");
}
static ConverterT
Fp16_to_Fp8E4M3FNUZ(int computeCapability) {
return computeCapability >= 300 ? Fp16_to_Fp8E4M3FNUZ_HW : Fp16_to_Fp8E4M3FNUZ_SW;
}
#endif
// WARN: subnormal (0bs0000xxx) are not handled
@@ -1151,10 +1228,6 @@ inline SmallVector<Value> packI32(const SmallVector<Value> &inValues,
return outValues;
}
typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const SmallVector<Value> &)>
ConverterT;
static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
Type outType,
const int inVecWidthBits = 32,
@@ -1473,12 +1546,7 @@ struct FpToFpOpConversion
ConversionPatternRewriter &rewriter,
const Value &v) {
#ifdef USE_ROCM
GCNBuilder builder;
auto &cvt = *builder.create("v_cvt_f32_f16");
auto res = builder.newOperand("=v");
auto operand = builder.newOperand(v, "v");
cvt(res, operand);
return builder.launch(rewriter, loc, f32_ty, false);
return cvtFp16ToFp32(loc, rewriter, v);
#else
PTXBuilder builder;
auto &cvt = *builder.create("cvt.f32.f16");
@@ -1524,12 +1592,7 @@ struct FpToFpOpConversion
ConversionPatternRewriter &rewriter,
const Value &v) {
#ifdef USE_ROCM
GCNBuilder builder;
auto &cvt = *builder.create("v_cvt_f16_f32");
auto res = builder.newOperand("=v");
auto operand = builder.newOperand(v, "v");
cvt(res, operand);
return builder.launch(rewriter, loc, f16_ty, false);
return cvtFp32ToFp16(loc, rewriter, v);
#else
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rz.f16.f32");
@@ -1542,7 +1605,11 @@ struct FpToFpOpConversion
ConverterT getConversionFunc(Type srcTy, Type dstTy) const {
auto F8E4M3B15TyID = TypeID::get<mlir::Float8E4M3B11FNUZType>();
#ifdef USE_ROCM
auto F8E4M3FNUZTyID = TypeID::get<mlir::Float8E4M3FNUZType>();
#else
auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNUZType>();
#endif
auto F8E4M3FNTyID = TypeID::get<mlir::Float8E4M3FNType>();
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
auto F8E5M2FNUZTyID = TypeID::get<mlir::Float8E5M2FNUZType>();
@@ -1558,37 +1625,37 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
#ifdef USE_ROCM
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16(computeCapability)},
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16(computeCapability)},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
#else
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
#endif
// F16 -> F8
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
#ifdef USE_ROCM
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ(computeCapability)},
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ(computeCapability)},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
#else
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
#endif
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
#ifdef USE_ROCM
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
#else
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
#endif
// F8 -> BF16
// F8 -> BF16
#ifdef USE_ROCM
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
#else
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
#endif
// BF16 -> F8
// BF16 -> F8
#ifdef USE_ROCM
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
#else
@@ -1599,6 +1666,16 @@ struct FpToFpOpConversion
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
#endif
};
std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
if (srcMap.count(key) == 0) {
llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy
<< "\n";
llvm_unreachable("");
}
#ifdef USE_ROCM
return srcMap.lookup(key);
#else
int inVecWidthBits = 32;
int outVecWidthBits = 32;
if (srcTy.isFloat8E4M3FNUZ() ||
@@ -1612,15 +1689,6 @@ struct FpToFpOpConversion
outVecWidthBits = 16;
}
std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
if (srcMap.count(key) == 0) {
llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy
<< "\n";
llvm_unreachable("");
}
#ifdef USE_ROCM
return srcMap.lookup(key);
#else
if (computeCapability < 90 &&
(srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
@@ -1645,16 +1713,26 @@ struct FpToFpOpConversion
size_t numElements = 4;
if (srcElementType.isFloat8E4M3FNUZ() ||
dstElementType.isFloat8E4M3FNUZ() ||
#ifdef USE_ROCM
srcElementType.isFloat8E5M2FNUZ() ||
dstElementType.isFloat8E5M2FNUZ())
#else
(computeCapability >= 90 &&
((srcElementType.isFloat8E5M2() &&
(dstElementType.isF16() || dstElementType.isF32())) ||
dstElementType.isFloat8E5M2()))) {
#endif
{
numElements = 2;
}
bool useFP16IntermediateSrc =
#ifdef USE_ROCM
srcElementType.isF32();
#else
srcElementType.isF32() &&
!(computeCapability >= 90 &&
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
#endif
bool isDstFP32 = dstElementType.isF32();
Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
Type dstType = isDstFP32 ? f16_ty : dstElementType;

View File

@@ -1075,15 +1075,16 @@ if TORCH_HAS_FP8E5B16:
if TORCH_HAS_FP8E4B8:
tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
def gen_input(M, N, d_type, seed, device='cuda'):
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if d_type == tl.float16:
@@ -1246,7 +1247,8 @@ def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'c
def test_gemm_amd_fp8_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'):
check_type_supported(out_dtype, device)
if triton.language.semantic.gpu_matrix_core_version() != 3:
backend = triton.common.backend.get_backend("hip")
if backend.get_matrix_core_version() != 3:
pytest.skip("fp8 data type is not available on hardware")
@triton.jit
@@ -1630,7 +1632,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
('float16', 'float16'),
('float16', 'float32'),
('float32', 'float32')]]
if triton.language.semantic.gpu_matrix_core_version() == 0 else
if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0 else
# MFMA Test Dot tests
[(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim)
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
@@ -1881,7 +1883,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
# added atol, to loose precision for float16xfloat16->float32 case
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
if torch.version.hip is not None:
if triton.language.semantic.gpu_matrix_core_version() > 0:
backend = triton.common.backend.get_backend("hip")
if backend.get_matrix_core_version() > 0:
ttgir = pgm.asm['ttgir']
if non_k_dim == 16:
assert "#triton_gpu.mfma<{nonKDim = 16" in ttgir
@@ -1890,9 +1893,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
assert "#triton_gpu.mfma<{nonKDim = 32" in ttgir
assert "#triton_gpu.mfma<{nonKDim = 16" not in ttgir
gcn = pgm.asm['amdgcn']
if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
assert "v_mfma_f32_32x32x16_bf8_bf8" in gcn or "v_mfma_f32_16x16x32_bf8_bf8" in gcn
if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8:
if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8:
assert "v_mfma_f32_32x32x16_fp8_fp8" in gcn or "v_mfma_f32_16x16x32_fp8_fp8" in gcn
return
# make sure ld/st are vectorized
@@ -2727,7 +2730,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
if transposeA and not transposeB:
pytest.skip()
if triton.language.semantic.gpu_matrix_core_version() == 0:
if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0:
pytest.skip("mfma is not available on hardware")
# source code for following ttgir:
@@ -2817,7 +2820,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
kernel = triton.compile(f.name, device_type="hip", cc=capabilities)
import triton.language.semantic as sem
if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0:
# if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0:
if torch.version.hip is not None and backend.get_matrix_core_version() > 0:
kernel[(1, 1, 1)](x_tri, y_tri, z_tri)
np.testing.assert_allclose(z_np, to_numpy(z_tri), rtol=0.01, atol=1e-3)

View File

@@ -65,8 +65,7 @@ def ttir_compute_capability_rewrite(mod, target):
if _is_cuda(target):
pm.add_rewrite_tensor_pointer_pass(target.capability, False)
elif is_hip():
capability = 90
pm.add_rewrite_tensor_pointer_pass(capability, True)
pm.add_rewrite_tensor_pointer_pass(target["capability"], True)
else:
assert(False, "unsupported target")
pm.run(mod)
@@ -118,14 +117,14 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
pm.add_tritongpu_accelerate_matmul_pass(capability)
# TODO change interface of accelerate_matmul_pass
if is_hip():
matrix_core_version = gpu_matrix_core_version()
matrix_core_version = target["matrix_core_version"]
matrix_inst_size = matrix_inst_type
pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size)
pm.add_tritongpu_remove_layout_conversions_pass()
if optimize_epilogue:
pm.add_tritongpu_optimize_epilogue_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
if num_stages == 0 and is_hip() and target["matrix_core_version"] != 0:
pm.add_tritongpu_stream_pipeline_pass()
pm.add_canonicalizer_pass()
ws_enabled = False
@@ -191,7 +190,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0):
if _is_cuda(target):
return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
else:
return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
return translate_triton_gpu_to_llvmir(mod, target["capability"], TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
# PTX translation
@@ -360,8 +359,6 @@ def is_hip():
raise ImportError("Triton requires PyTorch to be installed")
return torch.version.hip is not None
from ..language.semantic import gpu_matrix_core_version
def get_cuda_capability(capability):
if capability is None:
device = get_current_device()

View File

@@ -1188,32 +1188,6 @@ def is_hip():
raise ImportError("Triton requires PyTorch to be installed")
return torch.version.hip is not None
def gpu_matrix_core_version() -> int:
""" Determine matrix core type available on current GPU.
0 means no tensor cores are available
1 corresponds to MFMA in CDNA 1 architecture
2 corresponds to MFMA in CDNA 2 architecture
3 corresponds to MFMA in CDNA 3 architecture
"""
if not is_hip():
return 0
arch_info = _triton.get_arch_info()
gfx_arch_details = re.search('amd.*', arch_info)
if gfx_arch_details is None:
return 0
gfx_arch_details = gfx_arch_details.group(0).strip().split('--')
gpu_name = gfx_arch_details[1].split(':')[0]
if gpu_name in ['gfx908']:
return 1
if gpu_name in ['gfx90a']:
return 2
if gpu_name in ['gfx940', 'gfx941', 'gfx942']:
return 3
return 0
def mfma_supported_granularity(m, n, k) -> bool:
# todo make this gran_type matrix element type sensitive
for gran_type in [(32, 8), (16, 16)]:
@@ -1226,8 +1200,8 @@ def mfma_supported_granularity(m, n, k) -> bool:
return True
return False
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
matrix_core_version = gpu_matrix_core_version()
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty, target) -> bool:
matrix_core_version = target["matrix_core_version"]
if matrix_core_version not in [1, 2, 3]:
return False
if not mfma_supported_granularity(M, N ,K):
@@ -1240,10 +1214,18 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
# Checks for non-cuda archs
if not _is_cuda(target):
if is_hip():
assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \
(lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()) or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8()), \
f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
if lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8():
assert lhs.type.scalar.is_fp8e4b8() or lhs.type.scalar.is_fp8e5b16() or lhs.type.scalar.is_fp8e5(),\
f"Only hip fp8 or f8e5 types are accepted for both inputs of fp8"
assert rhs.type.scalar.is_fp8e4b8() or rhs.type.scalar.is_fp8e5b16() or rhs.type.scalar.is_fp8e5(),\
f"Only hip fp8 or f8e5 types are accepted for both inputs of fp8"
return
if not _is_cuda(target):
return
# Checks for cuda archs
@@ -1287,13 +1269,18 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
# hip for now converts fp8 to fp16 for mixed input
if is_hip():
fp8_supported = gpu_matrix_core_version() == 3
target = builder.target
assert "matrix_core_version" in target
fp8_supported = target["matrix_core_version"] == 3
# gfx940 data type
lhs_hip_fp8 = lhs.type.scalar.is_fp8e4b8() or lhs.type.scalar.is_fp8e5b16()
rhs_hip_fp8 = rhs.type.scalar.is_fp8e4b8() or rhs.type.scalar.is_fp8e5b16()
lhs_fp8 = lhs.type.scalar.is_fp8()
rhs_fp8 = rhs.type.scalar.is_fp8()
supported_fp8_dot = fp8_supported and lhs_fp8 and rhs_fp8
if not supported_fp8_dot and lhs_fp8:
supported_fp8_dot = fp8_supported and lhs_hip_fp8 and rhs_hip_fp8
if (not supported_fp8_dot) and lhs_fp8:
lhs = cast(lhs, tl.float16, builder)
if not supported_fp8_dot and rhs_fp8:
if (not supported_fp8_dot) and rhs_fp8:
rhs = cast(rhs, tl.float16, builder)
if lhs.type.scalar.is_int():
@@ -1316,7 +1303,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
N = rhs.type.shape[1]
# Cast operands of types f16 and i8 for configurations where FMA only supported.
if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty):
if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty, builder.target):
# max_num_imprecise_acc does not yet apply to hip
if is_hip():
max_num_imprecise_acc = 0
@@ -1334,7 +1321,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
ret_ty)
return cast(ret, ret_scalar_ty, builder)
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32,
ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32:
# max_num_imprecise_acc does not yet apply to hip
if is_hip():
max_num_imprecise_acc = 0

View File

@@ -273,6 +273,30 @@ def get_amdgcn_bitcode_paths(gfx_arch: str):
return amdgcn_bitcode_paths
def gpu_matrix_core_version() -> int:
""" Determine matrix core type available on current GPU.
0 means no tensor cores are available
1 corresponds to MFMA in CDNA 1 architecture
2 corresponds to MFMA in CDNA 2 architecture
3 corresponds to MFMA in CDNA 3 architecture
"""
arch_info = _triton.get_arch_info()
gfx_arch_details = re.search('amd.*', arch_info)
if gfx_arch_details is None:
return 0
gfx_arch_details = gfx_arch_details.group(0).strip().split('--')
gpu_name = gfx_arch_details[1].split(':')[0]
if gpu_name in ['gfx908']:
return 1
if gpu_name in ['gfx90a']:
return 2
if gpu_name in ['gfx940', 'gfx941', 'gfx942']:
return 3
return 0
def get_amdgpu_arch_fulldetails():
# print("get_amdgpu_arch_fulldetails")
"""
@@ -294,7 +318,11 @@ def get_amdgpu_arch_fulldetails():
if gfx_arch is None:
raise RuntimeError('gfx_arch is None (not specified)')
return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features}
mat_core_ver = gpu_matrix_core_version()
capability = gpu_matrix_core_version() * 100
return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features,\
"capability": capability, "matrix_core_version": mat_core_ver}
except BaseException:
return None
@@ -487,3 +515,6 @@ class HIPBackend(BaseBackend):
return _triton.get_num_warps(module)
else:
return _triton.get_num_warps(module)
def get_matrix_core_version(self):
return gpu_matrix_core_version()