fp8 type support (#357)
* Add two fp8 data types, `tl.float8e4b8` and `tl.float8e5b16`, to Triton.
* Add software type conversion between `tl.float8e4b8`/`tl.float8e5b16` and `fp16`.
* Change flash attention to support fp8 in q/k.
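For orientation, a minimal sketch (not part of the commit; kernel and tensor names are made up) of how the two new dtypes are exercised from the Python side once this lands:

import torch
import triton
import triton.language as tl

@triton.jit
def cast_to_fp8_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)   # fp16 input
    y = x.to(tl.float8e5b16)                  # new AMD fp8 type (e5m2fnuz, bias 16); tl.float8e4b8 works the same way
    tl.store(y_ptr + offsets, y, mask=mask)

x = torch.randn(1024, dtype=torch.float16, device='cuda')
y = torch.empty(1024, dtype=torch.int8, device='cuda')   # raw 8-bit storage for the fp8 result
cast_to_fp8_kernel[(1,)](x, triton.reinterpret(y, tl.float8e5b16), 1024, BLOCK_SIZE=1024)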
@@ -14,7 +14,7 @@ class TritonTypeDef<string name, string _mnemonic>
}

// Floating-point Type
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

@@ -21,16 +21,6 @@ Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
Value a0 = bitcast(fp16x2Vec0, i32_ty);
Value a1 = bitcast(fp16x2Vec1, i32_ty);
Value sign0 = and_(i32_ty, a0, i32_val(0x80008000));
Value sign1 = and_(i32_ty, a1, i32_val(0x80008000));

a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
a1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
a0 = add(i32_ty, a0, i32_val(0x00800080));
a1 = add(i32_ty, a1, i32_val(0x00800080));

a0 = or_(i32_ty, a0, sign0);
a1 = or_(i32_ty, a1, sign1);

auto fp8x4VecTy = vec_ty(i8_ty, 4);
a0 = bitcast(a0, fp8x4VecTy);

@@ -54,6 +44,57 @@ const std::string Fp16_to_Fp8E5M2 =
"}";
#endif

#ifdef USE_ROCM
static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto vi16 = bitcast(v, i16_ty);
auto e = and_(i16_ty, vi16, int_val(16, 0x7C00));
auto sign = and_(i16_ty, vi16, int_val(16, 0x8000));

// normal value
auto a = and_(i16_ty, vi16, int_val(16, 0x7FFFF));
auto a1 = add(i16_ty, a, int_val(16, 0x0400));
auto o1 = or_(i16_ty, a1, sign);

// subnormal value, e is 0
auto m = and_(i16_ty, vi16, int_val(16, 0x03FF));
auto m2 = shl(m, int_val(16, 1));
auto o2 = or_(i16_ty, sign, or_(i16_ty, int_val(16, 1), m2));

auto e_is_zero = icmp_eq(e, int_val(16, 0));
auto e_is_all1 = icmp_eq(e, int_val(16, 0x7C00));

auto ot = select(e_is_zero, o2, o1);
auto o = select(e_is_all1, vi16, ot);
auto fp8x2VecTy = vec_ty(i8_ty, 2);
auto res = bitcast(o, fp8x2VecTy);

return extract_element(i8_ty, res, i32_val(1));
}

static SmallVector<Value>
Fp16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(4);
result[0] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[0]);
result[1] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[1]);
result[2] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[2]);
result[3] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[3]);

return result;
}
#else
const std::string Fp16_to_Fp8E5M2FNUZ =
"{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
#endif
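A note on the ROCm branch above (editorial, not part of the diff): the conversion runs one fp16 value at a time in software. Adding 0x0400 raises the fp16 exponent field by one, which is exactly the re-bias from fp16 (bias 15) to the e5m2fnuz format (bias 16); the low 8 mantissa bits are then dropped by keeping only the high byte. A rough single-value model in Python (illustrative only; it simplifies the 0x7FFFF mask in the code above to 0x7FFF):

def fp16_bits_to_fp8e5b16(x):
    e = x & 0x7C00                                   # fp16 exponent field
    sign = x & 0x8000
    o_normal = ((x & 0x7FFF) + 0x0400) | sign        # re-bias 15 -> 16, then truncate mantissa
    o_subnormal = sign | 1 | ((x & 0x03FF) << 1)     # e == 0: shift mantissa up by one
    o = x if e == 0x7C00 else (o_subnormal if e == 0 else o_normal)
    return (o >> 8) & 0xFF                           # the byte picked by extract_element(..., 1)

assert fp16_bits_to_fp8e5b16(0x3C00) == 0x40         # fp16 1.0 -> e5m2fnuz 1.0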
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,

@@ -89,6 +130,61 @@ const std::string Fp8E5M2_to_Fp16 = "{ \n"
"}";
#endif

#ifdef USE_ROCM
static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto fp8x2VecTy = vec_ty(i8_ty, 2);
Value a = undef(fp8x2VecTy);
a = insert_element(fp8x2VecTy, a, int_val(8, 0), i32_val(0));
a = insert_element(fp8x2VecTy, a, v, i32_val(1));
a = bitcast(a, i16_ty);

auto e = and_(i16_ty, a, int_val(16, 0x7C00));
auto m = and_(i16_ty, a, int_val(16, 0x0300));
auto sign = and_(i16_ty, a, int_val(16, 0x8000));

// check whether all exponents are zeros
auto e_is_zero = icmp_eq(e, int_val(16, 0x0));

// case 1, e is zero, need to move m right by 1 bit
auto m1 = lshr(i16_ty, m, int_val(16, 1));
auto o0 = or_(i16_ty, sign, m1);

// case 2, e is nonzero, sub exponent by 1
auto e1 = sub(i16_ty, e, int_val(16, 0x0400));

auto e_is_one = icmp_eq(e, int_val(16, 0x0400));
auto m2 = add(i16_ty, m1, int_val(16, 0x0200));

auto o1 = or_(i16_ty, sign, or_(i16_ty, m, e1));
auto o2 = or_(i16_ty, sign, m2);

auto o12 = select(e_is_one, o2, o1);
auto o = select(e_is_zero, o0, o12);

return bitcast(o, f16_ty);
}

static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(4);
result[0] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[0]);
result[1] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[1]);
result[2] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[2]);
result[3] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[3]);

return result;
}
#else
const std::string Fp8E5M2FNUZ_to_Fp16 = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
#endif
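The reverse direction mirrors this (editorial note, not part of the diff): subtracting 0x0400 lowers the exponent field by one (bias 16 back to bias 15), with the all-zero and smallest exponent fields special-cased. A rough single-value model (illustrative only):

def fp8e5b16_bits_to_fp16(b):
    a = (b & 0xFF) << 8                    # byte goes into the high half, as the insert_element calls do
    e, m, sign = a & 0x7C00, a & 0x0300, a & 0x8000
    if e == 0:                             # zero exponent field
        return sign | (m >> 1)
    if e == 0x0400:                        # smallest exponent field -> fp16 subnormal range
        return sign | ((m >> 1) + 0x0200)
    return sign | m | (e - 0x0400)         # re-bias 16 -> 15

assert fp8e5b16_bits_to_fp16(0x40) == 0x3C00   # e5m2fnuz 1.0 -> fp16 1.0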
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,

@@ -510,36 +606,50 @@ const std::string Fp16_to_Fp8E4M3B15x4 =
// does not handle denormals and has
// more than a single NaN values.

// Fp8E4M3 -> Fp16 (packed)
#ifdef USE_ROCM
static Value convert_val_Fp8E4M3FNUZ_to_Fp16(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto fp8x2VecTy = vec_ty(i8_ty, 2);
Value a = undef(fp8x2VecTy);
a = insert_element(fp8x2VecTy, a, int_val(8, 0), i32_val(0));
a = insert_element(fp8x2VecTy, a, v, i32_val(1));
a = bitcast(a, i16_ty);

auto e_mask = int_val(16, 0x7A00);
auto e = and_(i16_ty, a, e_mask);

auto m = and_(i16_ty, a, int_val(16, 0x0700));
auto sign = and_(i16_ty, a, int_val(16, 0x8000));

// check whether all exponents are zeros
auto e_is_zero = icmp_eq(e, int_val(16, 0x0));
auto b = and_(i16_ty, a, int_val(16, 0x7FFF));
auto b1 = lshr(i16_ty, b, int_val(16, 1));

// case 1, e is nonzero, add exponent by 6
auto o0v = add(i16_ty, b1, int_val(16, 0x0C00));
auto o0 = or_(i16_ty, o0v, sign);

// case 2, e is nonzero, add exponent by 7
auto o1v = add(i16_ty, b1, int_val(16, 0x1C00));
auto o1 = or_(i16_ty, o1v, sign);

auto io = select(e_is_zero, o0, o1);
return bitcast(io, f16_ty);
}
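A note on the constants above (editorial, not part of the diff): with the fp8 byte parked in the high half of an i16, the logical shift right by one lines its exponent and mantissa up with fp16's field positions, and adding 0x1C00 (7 << 10) raises the exponent field by 7, which is the bias change from e4m3fnuz (bias 8) to fp16 (bias 15). A quick check of the normal path, using the byte 0x40, i.e. 1.0 in fp8e4b8:

a = 0x40 << 8                                     # fp8e4b8 1.0 placed in the high byte
assert ((a & 0x7FFF) >> 1) + 0x1C00 == 0x3C00     # 0x3C00 is fp16 1.0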
// Fp8E4M3FNUZ -> Fp16 (packed)
static SmallVector<Value>
Fp8E4M3_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
Fp8E4M3FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
a0 = bitcast(a0, i32_ty);
SmallVector<Value> result(2);
result[0] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[0]);
result[1] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[1]);

Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));

b0 = lshr(i32_ty, b0, i32_val(1));

b0 = add(i32_ty, b0, i32_val(0x20002000));

b0 = or_( i32_ty, b0, and_(i32_ty, a0, i32_val(0x80008000)) );

auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2Vec0 = bitcast(b0, fp16x2VecTy);

return { extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1))
};
return result;
}
#else
const std::string Fp8E4M3_to_Fp16 =
const std::string Fp8E4M3FNUZ_to_Fp16 =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400

@@ -558,32 +668,50 @@ const std::string Fp8E4M3_to_Fp16 =

// Fp16 -> Fp8E4M3 (packed)
#ifdef USE_ROCM
static Value convert_val_Fp16_to_Fp8E4M3FNUZ(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto vi16 = bitcast(v, i16_ty);
auto e10 = and_(vi16, int_val(16, 0x7C00));
auto e = lshr(i16_ty, e10, int_val(16, 10));

auto s = and_(i16_ty, vi16, int_val(16, 0x8000));

auto m7 = and_(i16_ty, vi16, int_val(16, 0x0380));
auto m = shl(i16_ty, m7, int_val(16, 1));

// three cases:
// 1) e > 21 --> e = 1111,
// 2) e <= 7 ---> e = 0,
// 3) others, normal conversion
auto e1 = int_val(16, 0x7800);
auto e2 = int_val(16, 0x0);
auto e31 = sub(i16_ty, e10, int_val(16, 0x1C00));
auto e3 = shl(i16_ty, e31, int_val(16, 1));

auto c13 = icmp_sgt(e, int_val(16, 21));
auto e13 = select(c13, e1, e3);
auto c23 = icmp_sle(e, int_val(16, 7));
auto re = select(c23, e2, e13);

auto r = or_(i16_ty, s, or_(i16_ty, re, m));
auto fp8x2VecTy = vec_ty(i8_ty, 2);
auto res = bitcast(r, fp8x2VecTy);

return extract_element(i8_ty, res, i32_val(1));
}
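And the matching note for this direction (editorial, not part of the diff): the fp16 exponent field is re-biased by -7 (the 0x1C00 subtraction) and shifted left into the 4-bit e4m3 position, with the top three mantissa bits following; exponent fields above 21 saturate to 1111 and fields at or below 7 flush to zero. Quick check with fp16 1.0:

assert ((0x3C00 - 0x1C00) << 1) >> 8 == 0x40   # fp16 1.0 -> 0x40, i.e. 1.0 in fp8e4b8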
static SmallVector<Value>
Fp16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter,
Fp16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);

fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));

fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec0 = sub(i32_ty, fp16x2Vec0, i32_val(0x20002000));
SmallVector<Value> result(2);
result[0] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[0]);
result[1] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[1]);

Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1));
a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
a0 = add(i32_ty, a0, i32_val(0x00800080));
Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 );

auto fp8x4VecTy = vec_ty(i8_ty, 4);
b0 = bitcast(b0, fp8x4VecTy);

return {extract_element(i8_ty, b0, i32_val(1)),
extract_element(i8_ty, b0, i32_val(3))
};
return result;
}
#else
const std::string Fp16_to_Fp8E4M3 =
const std::string Fp16_to_Fp8E4M3FNUZ =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4
"sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000

@@ -1215,9 +1343,10 @@ struct FpToFpOpConversion
ConverterT getConversionFunc(Type srcTy, Type dstTy) const {
auto F8E4M3B15TyID = TypeID::get<mlir::Float8E4M3B11FNUZType>();
auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNUZType>();
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
auto F8E4M3FNUZTyID = TypeID::get<mlir::Float8E4M3FNUZType>();
auto F8E4M3FNTyID = TypeID::get<mlir::Float8E4M3FNType>();
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
auto F8E5M2FNUZTyID = TypeID::get<mlir::Float8E5M2FNUZType>();
auto F16TyID = TypeID::get<mlir::Float16Type>();
auto BF16TyID = TypeID::get<mlir::BFloat16Type>();
auto F32TyID = TypeID::get<mlir::Float32Type>();

@@ -1230,8 +1359,9 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
{{F8E4M3TyID, F16TyID}, Fp8E4M3_to_Fp16},
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
// F16 -> F8
#ifdef USE_ROCM
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},

@@ -1239,8 +1369,9 @@ struct FpToFpOpConversion
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
#endif
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3},
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
// BF16 -> F8

@@ -37,6 +37,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E5M2FNUZType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 16);

@@ -400,7 +400,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
} // namespace LLVM

bool isF8(Type eType) {
return eType.isFloat8E5M2FNUZ() or eType.isFloat8E4M3FNUZ() or
return eType.isFloat8E4M3FNUZ() or eType.isFloat8E4M3FN() or
eType.isFloat8E5M2() or eType.isFloat8E5M2FNUZ();
}

@@ -811,6 +811,10 @@ void init_triton_ir(py::module &&m) {
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E4M3FNUZType>();
})
.def("get_fp8e4b8_ty",
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E4M3FNUZType>();
})
.def("get_fp8e4b15_ty",
[](TritonOpBuilder &self) -> mlir::Type {
// TODO: upstream FP8E4B15 into MLIR, or find a way to externally

@@ -827,6 +831,10 @@ void init_triton_ir(py::module &&m) {
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E5M2Type>();
})
.def("get_fp8e5b16_ty",
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E5M2FNUZType>();
})
.def("get_half_ty",
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getF16Type();
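Taken together with the Python-level changes further down, the two new bindings establish the following correspondences (summarised here for convenience; see the dtype and JITFunction hunks below):

tl.float8e5b16  ->  MLIR Float8E5M2FNUZType  ->  torch.float8_e5m2fnuz   (5 exponent bits, 2 mantissa bits, exponent bias 16)
tl.float8e4b8   ->  MLIR Float8E4M3FNUZType  ->  torch.float8_e4m3fnuz   (4 exponent bits, 3 mantissa bits, exponent bias 8)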
@@ -959,6 +959,8 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None):
# float8e4m3nv does not have infinities
output[fp == 0b01111111] = torch.nan
output[fp == 0b11111111] = torch.nan
elif dtype in [tl.float8e4b8, tl.float8e5b16]:
output[fp==0b10000000] = torch.nan
else:
output = torch.where(exp == (1 << exp_width) - 1,
((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32),

@@ -1015,7 +1017,11 @@ def deserialize_fp8(np_data, in_dtype):
return np_data

@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5])
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15,
tl.float8e4b15x4,
tl.float8e4b8,
tl.float8e5,
tl.float8e5b16])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
"""

@@ -1040,9 +1046,11 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
exp_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1)
is_subnormal = np.logical_or((ref_fp8 & exp_mask) == 0, (ref_fp8 & exp_mask) == exp_mask)

ref_fp8[is_nan] = 0
ref_fp8[is_subnormal] = 0
tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).cuda()

tri_fp16 = torch.empty(256, dtype=out_dtype, device="cuda")
copy_kernel[(1,)](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024)

@@ -1055,6 +1063,50 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
assert torch.all(tri_fp8 == ref_fp8)

TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')
TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')
tl_to_torch_types = {
tl.float16: torch.float16,
tl.float32: torch.float32,
}

if TORCH_HAS_FP8E5B16:
tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz
if TORCH_HAS_FP8E4B8:
tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz

@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)

def gen_input(M, N, d_type, seed, device='cuda'):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if d_type == tl.float16:
input = torch.randn((M, N), dtype=torch.float16, device=device)
input_f16 = input
else: # d_type is float8
raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda') * 10
if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \
(d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) :
input = raw_data.to(tl_to_torch_types[d_type])
input_f16 = input.to(torch.float16)
else:
f8_tensor = raw_data.to(torch.int8)
# keep only two bits of exponent to avoid overflow
f8_tensor = f8_tensor & 0b00111111
input = triton.reinterpret(f8_tensor, d_type)
input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
n_elements = raw_data.numel()
copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024)
return input, input_f16

@pytest.mark.parametrize("M, N, K, a_type, b_type, out_dtype",
[(*shape, *ab_type, out_dtype)
for shape in [[128, 256, 32],

@@ -1071,6 +1123,10 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
[tl.float8e4b15, tl.float16],
[tl.float8e4b15x4, tl.float16],
[tl.float8e5, tl.float16],
[tl.float8e4b8, tl.float16],
[tl.float8e5b16, tl.float16],
[tl.float16, tl.float8e5b16],
[tl.float16, tl.float8e4b8],
[tl.float16, tl.float8e4nv],
[tl.float16, tl.float8e4b15],
[tl.float16, tl.float8e4b15x4],

@@ -1078,17 +1134,8 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
for out_dtype in [torch.float16, torch.float32]
])
def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'):

check_type_supported(out_dtype, device)

@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)

@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,

@@ -1117,15 +1164,128 @@ def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'c
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b, out_dtype=compute_type)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator
c = accumulator.to(compute_type)

# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)

def matmul(a, b, c_type):
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
M, K = a.shape
K, N = b.shape

if c_type == torch.float16:
comp_type = tl.float16
else:
comp_type = tl.float32

c = torch.empty((M, N), device = a.device, dtype=c_type)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
compute_type = comp_type,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=64,
BLOCK_SIZE_K=64,
GROUP_SIZE_M=4,
num_stages=1,
num_warps=2,
)
return c

a, a_f16 = gen_input(M, K, a_type, 11, device=device)
b, b_f16 = gen_input(K, N, b_type, 22, device=device)

# call torch function to compute gold
golden = torch.matmul(a_f16, b_f16)

c = matmul(a, b, out_dtype)
torch.testing.assert_close(c.to(golden.dtype), golden, rtol=1e-2, atol=1e-2)

# @pytest.mark.skip(reason="Pytorch does not support the following types, so need to skip for now")
@pytest.mark.parametrize("M, N, K, a_type, b_type, out_dtype",
[(*shape, *ab_type, out_dtype)
for shape in [[128, 256, 32],
[128, 16, 32],
[32, 128, 64],
[128, 128, 64],
[64, 128, 128],
[32, 128, 64],
[64, 64, 32],
[32, 32, 128],
[128, 128, 64],
[64, 128, 128]]
for ab_type in [[tl.float8e4b8, tl.float8e4b8],
[tl.float8e5b16, tl.float8e4b8],
[tl.float8e4b8, tl.float8e5b16],
[tl.float8e5b16, tl.float8e5b16]]
for out_dtype in [torch.float32]
])
def test_gemm_amd_fp8_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'):
check_type_supported(out_dtype, device)

if triton.language.semantic.gpu_matrix_core_version() != 3:
pytest.skip("fp8 data type is not available on hardware")

@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
compute_type:tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(compute_type)

# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.

@@ -1168,33 +1328,14 @@ def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'c
return c

def gen_input(M, N, d_type, seed, device='cuda'):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if d_type == tl.float16:
input = torch.randn((M, K), dtype=torch.float16, device=device)
input_f16 = input
else: # d_type is float8
f8_tensor = torch.randn((M, N), dtype=torch.float32, device='cuda') * 10
f8_tensor = f8_tensor.to(torch.int8)
# keep only two bits of exponent to avoid overflow
f8_tensor = f8_tensor & 0b00111111
input = triton.reinterpret(f8_tensor, d_type)
input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
n_elements = f8_tensor.numel()
copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024)
return input, input_f16

a, a_f16 = gen_input(M, K, a_type, 11, device=device)
a, a_f16 = gen_input(M, K, a_type, 21, device=device)
b, b_f16 = gen_input(K, N, b_type, 22, device=device)

# call torch function to compute gold
golden = torch.matmul(a_f16, b_f16)

c = matmul(a, b, out_dtype)
torch.testing.assert_close(c.to(golden.dtype), golden, rtol=1e-2, atol=6e-2)
torch.testing.assert_close(golden, c.to(golden.dtype), rtol=1e-2, atol=2e-2)

# ---------------

@@ -1591,9 +1732,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
x = tl.load(Xs)
y = tl.load(Ys)
if in_dtype is tl.float8e4b15 or in_dtype is tl.float8e5:
# if in_dtype is tl.float8e4b15 or in_dtype is tl.float8e5:
# TODO change types when they are available
# if in_dtype is tl.float8e5b16 or in_dtype is tl.float8e4b8:
if in_dtype is tl.float8e5b16 or in_dtype is tl.float8e4b8:
x = x.to(in_dtype)
y = y.to(in_dtype)
z = tl.dot(x, y, allow_tf32=ALLOW_TF32, out_dtype=out_dtype)

@@ -1626,13 +1767,13 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
effective_in_dtype = tl.bfloat16
elif in_dtype == "float8e5m2fnuz":
# TODO change types when they are available
effective_in_dtype = tl.float8e5
# effective_in_dtype = tl.float8e5b16
# effective_in_dtype = tl.float8e5
effective_in_dtype = tl.float8e5b16
in_dtype = "float32"
elif in_dtype == "float8e4m3fnuz":
# TODO change types when they are available
effective_in_dtype = tl.float8e4b15
# effective_in_dtype = tl.float8e4b8
# effective_in_dtype = tl.float8e4b15
effective_in_dtype = tl.float8e4b8
in_dtype = "float32"
else:
assert("unexpected in dtype")

@@ -1655,10 +1796,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
if effective_in_dtype.is_fp8():
if effective_in_dtype.is_fp8e5():
x = x + 1
y = y + 1
if effective_in_dtype.is_fp8e5b16():
mask = 0b111111000110 << 20
else:
mask = 0b111110000111 << 20
mask = 0b101111000111 << 20
x = (x.view('uint32') & np.uint32(mask)).view('float32')
y = (y.view('uint32') & np.uint32(mask)).view('float32')
x_tri = to_triton(x, device=device)

@@ -1070,7 +1070,9 @@ def str_to_ty(name):
return language.pointer_type(ty)
tys = {
"fp8e4nv": language.float8e4nv,
"fp8e4b8": language.float8e4b8,
"fp8e5": language.float8e5,
"fp8e5b16": language.float8e5b16,
"fp8e4b15": language.float8e4b15,
"fp8e4b15x4": language.float8e4b15x4,
"fp16": language.float16,

@@ -58,7 +58,9 @@ from .core import (
float8e4b15,
float8e4b15x4,
float8e4nv,
float8e4b8,
float8e5,
float8e5b16,
function_type,
inline_asm_elementwise,
int1,

@@ -75,7 +75,7 @@ def _to_tensor(x, builder):
class dtype:
SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64']
FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
OTHER_TYPES = ['void']

@@ -107,10 +107,18 @@ class dtype:
self.fp_mantissa_width = 3
self.primitive_bitwidth = 8
self.exponent_bias = 7
elif name == 'fp8e4b8':
self.fp_mantissa_width = 3
self.primitive_bitwidth = 8
self.exponent_bias = 8
elif name == 'fp8e5':
self.fp_mantissa_width = 2
self.primitive_bitwidth = 8
self.exponent_bias = 15
elif name == 'fp8e5b16':
self.fp_mantissa_width = 2
self.primitive_bitwidth = 8
self.exponent_bias = 16
elif name == 'fp16':
self.fp_mantissa_width = 10
self.primitive_bitwidth = 16

@@ -138,6 +146,9 @@ class dtype:
def is_fp8e4nv(self):
return self.name == 'fp8e4nv'

def is_fp8e4b8(self):
return self.name == 'fp8e4b8'

def is_fp8e4b15(self):
return self.name == 'fp8e4b15'

@@ -147,6 +158,9 @@ class dtype:
def is_fp8e5(self):
return self.name == 'fp8e5'

def is_fp8e5b16(self):
return self.name == 'fp8e5b16'

def is_fp16(self):
return self.name == 'fp16'

@@ -250,8 +264,12 @@ class dtype:
return builder.get_int64_ty()
elif self.name == 'fp8e5':
return builder.get_fp8e5_ty()
elif self.name == 'fp8e5b16':
return builder.get_fp8e5b16_ty()
elif self.name == 'fp8e4nv':
return builder.get_fp8e4nv_ty()
elif self.name == 'fp8e4b8':
return builder.get_fp8e4b8_ty()
elif self.name == 'fp8e4b15':
return builder.get_fp8e4b15_ty()
elif self.name == 'fp8e4b15x4':

@@ -388,7 +406,9 @@ uint16 = dtype('uint16')
uint32 = dtype('uint32')
uint64 = dtype('uint64')
float8e5 = dtype('fp8e5')
float8e5b16 = dtype('fp8e5b16')
float8e4nv = dtype('fp8e4nv')
float8e4b8 = dtype('fp8e4b8')
float8e4b15 = dtype('fp8e4b15')
float8e4b15x4 = dtype('fp8e4b15x4')
float16 = dtype('fp16')

@@ -247,7 +247,13 @@ class JITFunction(KernelInterface[T]):
tys = {
"bool": "i1",
"float8e4nv": "fp8e4nv",
"float8_e4m3fn": "fp8e4nv",
"float8e4b8": "fp8e4b8",
"float8_e4m3fnuz": "fp8e4b8",
"float8e5": "fp8e5",
"float8_e5m2": "fp8e5",
"float8e5b16": "fp8e5b16",
"float8_e5m2fnuz": "fp8e5b16",
"float8e4b15": "fp8e4b15",
"float8e4b15x4": "fp8e4b15x4",
"float16": "fp16",

@@ -17,6 +17,9 @@ import torch
import triton
import triton.language as tl

torch_dtype:tl.constexpr = torch.float16
# torch_dtype:tl.constexpr = torch.float8_e5m2fnuz
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')

@triton.jit
def max_fn(x, y):
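The flash-attention changes that follow all serve the q/k-in-fp8 path from the commit message: hard-coded tl.float16 casts inside the kernels are replaced with casts to the dtype of the tensor actually being multiplied (v.dtype, q.dtype, do.dtype, DK.dtype.element_ty), and the host side only converts q and k. A minimal sketch of that host-side setup (illustrative, assuming a torch build that ships float8_e5m2fnuz):

torch_dtype = torch.float8_e5m2fnuz if hasattr(torch, 'float8_e5m2fnuz') else torch.float16
q = q.to(torch_dtype)     # q and k are cast to fp8
k = k.to(torch_dtype)     # v (and hence the output) stays fp16
tri_out = attention(q, k, v, causal, sm_scale)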
@@ -67,7 +70,7 @@ def _attn_fwd_inner(
acc = acc * alpha[:, None]
if not pre_load_v:
v = tl.load(V_block_ptr)
acc += tl.dot(p.to(tl.float16), v)
acc += tl.dot(p.to(v.dtype), v)
# -- update m_i and l_i
l_ij = tl.sum(p, 1)
l_i = l_i * alpha + l_ij

@@ -144,7 +147,7 @@ def _attn_fwd(
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(tl.float16)
q = (q * qk_scale).to(q.dtype)
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE

@@ -272,7 +275,7 @@ def _bwd_kernel(
p = tl.math.exp2(qk * qk_scale - l_i[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]

@@ -357,7 +360,7 @@ def _bwd_kernel_dk_dv(
qk_scale = sm_scale * 1.44269504
# load k and v: they will stay in SRAM throughout
k = tl.load(K_block_ptr)
k = (k * qk_scale).to(tl.float16)
k = (k * qk_scale).to(k.dtype)
v = tl.load(V_block_ptr)
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

@@ -378,7 +381,7 @@ def _bwd_kernel_dk_dv(
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i)
# -- compute dv ----
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di

@@ -407,7 +410,7 @@ def _bwd_kernel_dk_dv(
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
tl.store(DK_block_ptr, (dk * sm_scale).to(tl.float16))
tl.store(DK_block_ptr, (dk * sm_scale).to(DK.dtype.element_ty))
tl.store(DV_block_ptr, dv.to(tl.float16))

@triton.jit

@@ -469,7 +472,7 @@ def _bwd_kernel_dq(
qk_scale = sm_scale * 1.44269504
# load q and do: they will stay in SRAM throughout
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(tl.float16)
q = (q * qk_scale).to(q.dtype)
do = tl.load(DO_block_ptr)
Di = tl.load(D_ptrs + offs_m)
l_i = tl.load(l_ptrs + offs_m)

@@ -518,7 +521,7 @@ class _attention(torch.autograd.Function):
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
o = torch.empty_like(q, dtype=v.dtype)
if torch.version.hip is None:
BLOCK_M = 128
BLOCK_N = 64 if Lk <= 64 else 32

@@ -569,6 +572,7 @@ class _attention(torch.autograd.Function):
q, k, v, o, L = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
# dk = torch.empty_like(k, dtype=torch_dtype)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
delta = torch.empty_like(L)

@@ -648,26 +652,17 @@ attention = _attention.apply
@pytest.mark.parametrize('causal', [False, True])
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
torch.manual_seed(20)
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0., std=0.5)
.requires_grad_()
)
k = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0., std=0.5)
.requires_grad_()
)
v = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0., std=0.5)
.requires_grad_()
)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
if TORCH_HAS_FP8E5:
q = q.to(torch_dtype)
k = k.to(torch_dtype)
sm_scale = 0.5
dout = torch.randn_like(q)
dout = torch.randn_like(q, dtype=torch.float16)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale
if causal:
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()

@@ -675,7 +670,7 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
# triton implementation
tri_out = attention(q, k, v, causal, sm_scale)
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2)

@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',

@@ -690,7 +685,8 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = 0,5
sm_scale = 0.5
split_kernel = True
dout = torch.randn_like(q)
# reference implementation

@@ -777,6 +773,9 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if mode == "fwd":
q = q.to(torch_dtype)
k = k.to(torch_dtype)
sm_scale = 1.3
fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
if mode == 'bwd':