diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td index 5ab1f9907..bfe1514e3 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -14,7 +14,7 @@ class TritonTypeDef } // Floating-point Type -def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">; +def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; def TT_FloatTensor : TensorOf<[TT_Float]>; def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 6fb006bb8..46aa1d82e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -21,16 +21,6 @@ Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter, Value a0 = bitcast(fp16x2Vec0, i32_ty); Value a1 = bitcast(fp16x2Vec1, i32_ty); - Value sign0 = and_(i32_ty, a0, i32_val(0x80008000)); - Value sign1 = and_(i32_ty, a1, i32_val(0x80008000)); - - a0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); - a1 = and_(i32_ty, a1, i32_val(0x7fff7fff)); - a0 = add(i32_ty, a0, i32_val(0x00800080)); - a1 = add(i32_ty, a1, i32_val(0x00800080)); - - a0 = or_(i32_ty, a0, sign0); - a1 = or_(i32_ty, a1, sign1); auto fp8x4VecTy = vec_ty(i8_ty, 4); a0 = bitcast(a0, fp8x4VecTy); @@ -54,6 +44,57 @@ const std::string Fp16_to_Fp8E5M2 = "}"; #endif +#ifdef USE_ROCM +static Value convert_val_Fp16_to_Fp8E5M2FNUZ( + Location loc, ConversionPatternRewriter &rewriter, Value v) { + auto vi16 = bitcast(v, i16_ty); + auto e = and_(i16_ty, vi16, int_val(16, 0x7C00)); + auto sign = and_(i16_ty, vi16, int_val(16, 0x8000)); + + // normal value: mask off the sign, add 1 to the exponent field (0x0400), then restore the sign + auto a = and_(i16_ty, vi16, int_val(16, 0x7FFF)); + auto a1 = add(i16_ty, a, int_val(16, 0x0400)); + auto o1 = 
or_(i16_ty, a1, sign); + + // subnormal value, e is 0 + auto m = and_(i16_ty, vi16, int_val(16, 0x03FF)); + auto m2 = shl(m, int_val(16, 1)); + auto o2 = or_(i16_ty, sign, or_(i16_ty, int_val(16, 1), m2)); + + auto e_is_zero = icmp_eq(e, int_val(16, 0)); + auto e_is_all1 = icmp_eq(e, int_val(16, 0x7C00)); + + auto ot = select(e_is_zero, o2, o1); + auto o = select(e_is_all1, vi16, ot); + auto fp8x2VecTy = vec_ty(i8_ty, 2); + auto res = bitcast(o, fp8x2VecTy); + + return extract_element(i8_ty, res, i32_val(1)); +} + +static SmallVector +Fp16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + SmallVector result(4); + result[0] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[0]); + result[1] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[1]); + result[2] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[2]); + result[3] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[3]); + + return result; +} +#else +const std::string Fp16_to_Fp8E5M2FNUZ = + "{ \n" + ".reg .b32 a<2>; \n" + "and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe + "and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit) + "add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080 + "add.u32 a1, a1, 0x00800080; \n" // (round to nearest) + "prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0 + "}"; +#endif + #ifdef USE_ROCM static SmallVector Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, @@ -89,6 +130,61 @@ const std::string Fp8E5M2_to_Fp16 = "{ \n" "}"; #endif +#ifdef USE_ROCM + +static Value convert_val_Fp8E5M2FNUZ_to_Fp16( + Location loc, ConversionPatternRewriter &rewriter, Value v) { + auto fp8x2VecTy = vec_ty(i8_ty, 2); + Value a = undef(fp8x2VecTy); + a = insert_element(fp8x2VecTy, a, int_val(8, 0), i32_val(0)); + a = insert_element(fp8x2VecTy, a, v, i32_val(1)); + a = bitcast(a, i16_ty); + + auto e = and_(i16_ty, a, int_val(16, 0x7C00)); + auto m = and_(i16_ty, a, int_val(16, 0x0300)); + auto sign = and_(i16_ty, a, int_val(16, 
0x8000)); + + // check whether all exponents are zeros + auto e_is_zero = icmp_eq(e, int_val(16, 0x0)); + + // case 1, e is zero, need to move m right by 1 bit + auto m1 = lshr(i16_ty, m, int_val(16, 1)); + auto o0 = or_(i16_ty, sign, m1); + + // case 2, e is nonzero, sub exponent by 1 + auto e1 = sub(i16_ty, e, int_val(16, 0x0400)); + + auto e_is_one = icmp_eq(e, int_val(16, 0x0400)); + auto m2 = add(i16_ty, m1, int_val(16, 0x0200)); + + auto o1 = or_(i16_ty, sign, or_(i16_ty, m, e1)); + auto o2 = or_(i16_ty, sign, m2); + + auto o12 = select(e_is_one, o2, o1); + auto o = select(e_is_zero, o0, o12); + + return bitcast(o, f16_ty); +} + +static SmallVector +Fp8E5M2FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + + SmallVector result(4); + result[0] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[0]); + result[1] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[1]); + result[2] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[2]); + result[3] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[3]); + + return result; +} +#else +const std::string Fp8E5M2FNUZ_to_Fp16 = "{ \n" + "prmt.b32 $0, 0, $2, 0x5140; \n\t" + "prmt.b32 $1, 0, $2, 0x7362; \n\t" + "}"; +#endif + #ifdef USE_ROCM static SmallVector Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, @@ -510,36 +606,50 @@ const std::string Fp16_to_Fp8E4M3B15x4 = // does not handle denormals and has // more than a single NaN values. 
-// Fp8E4M3 -> Fp16 (packed) #ifdef USE_ROCM +static Value convert_val_Fp8E4M3FNUZ_to_Fp16( + Location loc, ConversionPatternRewriter &rewriter, Value v) { + auto fp8x2VecTy = vec_ty(i8_ty, 2); + Value a = undef(fp8x2VecTy); + a = insert_element(fp8x2VecTy, a, int_val(8, 0), i32_val(0)); + a = insert_element(fp8x2VecTy, a, v, i32_val(1)); + a = bitcast(a, i16_ty); + + auto e_mask = int_val(16, 0x7800); + auto e = and_(i16_ty, a, e_mask); + + auto m = and_(i16_ty, a, int_val(16, 0x0700)); + auto sign = and_(i16_ty, a, int_val(16, 0x8000)); + + // check whether all exponents are zeros + auto e_is_zero = icmp_eq(e, int_val(16, 0x0)); + auto b = and_(i16_ty, a, int_val(16, 0x7FFF)); + auto b1 = lshr(i16_ty, b, int_val(16, 1)); + + // case 1, e is zero, add 0x0C00 (3 << 10) to the exponent field + auto o0v = add(i16_ty, b1, int_val(16, 0x0C00)); + auto o0 = or_(i16_ty, o0v, sign); + + // case 2, e is nonzero, add exponent by 7 + auto o1v = add(i16_ty, b1, int_val(16, 0x1C00)); + auto o1 = or_(i16_ty, o1v, sign); + + auto io = select(e_is_zero, o0, o1); + return bitcast(io, f16_ty); +} + +// Fp8E4M3FNUZ -> Fp16 (packed) static SmallVector -Fp8E4M3_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, +Fp8E4M3FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { - auto fp8x4VecTy = vec_ty(i8_ty, 4); - Value a0 = undef(fp8x4VecTy); - a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0)); - a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1)); - a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2)); - a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3)); - a0 = bitcast(a0, i32_ty); + SmallVector result(2); + result[0] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[0]); + result[1] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[1]); - Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); - - b0 = lshr(i32_ty, b0, i32_val(1)); - - b0 = add(i32_ty, b0, i32_val(0x20002000)); - - b0 = or_( i32_ty, b0, and_(i32_ty, a0, i32_val(0x80008000)) 
); - - auto fp16x2VecTy = vec_ty(f16_ty, 2); - auto fp16x2Vec0 = bitcast(b0, fp16x2VecTy); - - return { extract_element(f16_ty, fp16x2Vec0, i32_val(0)), - extract_element(f16_ty, fp16x2Vec0, i32_val(1)) - }; + return result; } #else -const std::string Fp8E4M3_to_Fp16 = +const std::string Fp8E4M3FNUZ_to_Fp16 = "{ \n" ".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4 "prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400 @@ -558,32 +668,50 @@ const std::string Fp8E4M3_to_Fp16 = // Fp16 -> Fp8E4M3 (packed) #ifdef USE_ROCM +static Value convert_val_Fp16_to_Fp8E4M3FNUZ( + Location loc, ConversionPatternRewriter &rewriter, Value v) { + auto vi16 = bitcast(v, i16_ty); + auto e10 = and_(vi16, int_val(16, 0x7C00)); + auto e = lshr(i16_ty, e10, int_val(16, 10)); + + auto s = and_(i16_ty, vi16, int_val(16, 0x8000)); + + auto m7 = and_(i16_ty, vi16, int_val(16, 0x0380)); + auto m = shl(i16_ty, m7, int_val(16, 1)); + + // three cases: + // 1) e > 21 --> e = 1111, + // 2) e <= 7 ---> e = 0, + // 3) others, normal conversion + auto e1 = int_val(16, 0x7800); + auto e2 = int_val(16, 0x0); + auto e31 = sub(i16_ty, e10, int_val(16, 0x1C00)); + auto e3 = shl(i16_ty, e31, int_val(16, 1)); + + auto c13 = icmp_sgt(e, int_val(16, 21)); + auto e13 = select(c13, e1, e3); + auto c23 = icmp_sle(e, int_val(16, 7)); + auto re = select(c23, e2, e13); + + auto r = or_(i16_ty, s, or_(i16_ty, re, m)); + auto fp8x2VecTy = vec_ty(i8_ty, 2); + auto res = bitcast(r, fp8x2VecTy); + + return extract_element(i8_ty, res, i32_val(1)); +} + static SmallVector -Fp16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter, +Fp16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { - auto fp16x2VecTy = vec_ty(f16_ty, 2); - Value fp16x2Vec0 = undef(fp16x2VecTy); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0)); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1)); - - fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty); - fp16x2Vec0 = 
sub(i32_ty, fp16x2Vec0, i32_val(0x20002000)); + SmallVector result(2); + result[0] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[0]); + result[1] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[1]); - Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1)); - a0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); - a0 = add(i32_ty, a0, i32_val(0x00800080)); - Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 ); - - auto fp8x4VecTy = vec_ty(i8_ty, 4); - b0 = bitcast(b0, fp8x4VecTy); - - return {extract_element(i8_ty, b0, i32_val(1)), - extract_element(i8_ty, b0, i32_val(3)) - }; + return result; } #else -const std::string Fp16_to_Fp8E4M3 = +const std::string Fp16_to_Fp8E4M3FNUZ = "{ \n" ".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4 "sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000 @@ -1215,9 +1343,10 @@ struct FpToFpOpConversion ConverterT getConversionFunc(Type srcTy, Type dstTy) const { auto F8E4M3B15TyID = TypeID::get(); - auto F8E4M3TyID = TypeID::get(); - auto F8E5M2TyID = TypeID::get(); + auto F8E4M3FNUZTyID = TypeID::get(); auto F8E4M3FNTyID = TypeID::get(); + auto F8E5M2TyID = TypeID::get(); + auto F8E5M2FNUZTyID = TypeID::get(); auto F16TyID = TypeID::get(); auto BF16TyID = TypeID::get(); auto F32TyID = TypeID::get(); @@ -1230,8 +1359,9 @@ struct FpToFpOpConversion // F8 -> F16 {{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16}, {{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16}, - {{F8E4M3TyID, F16TyID}, Fp8E4M3_to_Fp16}, + {{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16}, {{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16}, + {{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16}, // F16 -> F8 #ifdef USE_ROCM {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15}, @@ -1239,8 +1369,9 @@ struct FpToFpOpConversion {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)}, #endif {{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4}, - {{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3}, + {{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ}, {{F16TyID, F8E5M2TyID}, 
Fp16_to_Fp8E5M2}, + {{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ}, // F8 -> BF16 {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16}, // BF16 -> F8 diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 8ab561a04..f04cee67e 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -37,6 +37,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( addConversion([&](mlir::Float8E5M2Type type) -> std::optional { return IntegerType::get(type.getContext(), 8); }); + addConversion([&](mlir::Float8E5M2FNUZType type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); // Internally store bfloat16 as int16 addConversion([&](BFloat16Type type) -> std::optional { return IntegerType::get(type.getContext(), 16); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index d0c70e6ee..3cc3425b4 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -400,7 +400,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, } // namespace LLVM bool isF8(Type eType) { - return eType.isFloat8E5M2FNUZ() or eType.isFloat8E4M3FNUZ() or + return eType.isFloat8E4M3FNUZ() or eType.isFloat8E4M3FN() or eType.isFloat8E5M2() or eType.isFloat8E5M2FNUZ(); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 110f1ae3c..0fe0eea4d 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -811,6 +811,10 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self) -> mlir::Type { return self.getBuilder().getType(); }) + .def("get_fp8e4b8_ty", + [](TritonOpBuilder &self) -> mlir::Type { + return self.getBuilder().getType(); + }) .def("get_fp8e4b15_ty", [](TritonOpBuilder &self) -> mlir::Type { // TODO: upstream FP8E4B15 into MLIR, or find a way to externally @@ -827,6 +831,10 @@ void init_triton_ir(py::module &&m) { 
[](TritonOpBuilder &self) -> mlir::Type { return self.getBuilder().getType(); }) + .def("get_fp8e5b16_ty", + [](TritonOpBuilder &self) -> mlir::Type { + return self.getBuilder().getType(); + }) .def("get_half_ty", [](TritonOpBuilder &self) -> mlir::Type { return self.getBuilder().getF16Type(); diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 9042de03d..ce4bfcf2f 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -959,6 +959,8 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None): # float8e4m3nv does not have infinities output[fp == 0b01111111] = torch.nan output[fp == 0b11111111] = torch.nan + elif dtype in [tl.float8e4b8, tl.float8e5b16]: + output[fp==0b10000000] = torch.nan else: output = torch.where(exp == (1 << exp_width) - 1, ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32), @@ -1015,7 +1017,11 @@ def deserialize_fp8(np_data, in_dtype): return np_data -@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5]) +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, + tl.float8e4b15x4, + tl.float8e4b8, + tl.float8e5, + tl.float8e5b16]) @pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32]) def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): """ @@ -1040,9 +1046,11 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width exp_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1) is_subnormal = np.logical_or((ref_fp8 & exp_mask) == 0, (ref_fp8 & exp_mask) == exp_mask) + ref_fp8[is_nan] = 0 ref_fp8[is_subnormal] = 0 tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).cuda() + tri_fp16 = torch.empty(256, dtype=out_dtype, device="cuda") copy_kernel[(1,)](triton.reinterpret(tri_fp8, in_dtype), 
tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024) @@ -1055,6 +1063,50 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): assert torch.all(tri_fp8 == ref_fp8) +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') +tl_to_torch_types = { + tl.float16: torch.float16, + tl.float32: torch.float32, +} + +if TORCH_HAS_FP8E5B16: + tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz +if TORCH_HAS_FP8E4B8: + tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz + +@triton.jit +def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + +def gen_input(M, N, d_type, seed, device='cuda'): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + if d_type == tl.float16: + input = torch.randn((M, N), dtype=torch.float16, device=device) + input_f16 = input + else: # d_type is float8 + raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda') * 10 + if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \ + (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) : + input = raw_data.to(tl_to_torch_types[d_type]) + input_f16 = input.to(torch.float16) + else: + f8_tensor = raw_data.to(torch.int8) + # keep only two bits of exponent to avoid overflow + f8_tensor = f8_tensor & 0b00111111 + input = triton.reinterpret(f8_tensor, d_type) + input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + n_elements = raw_data.numel() + copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024) + return input, input_f16 + + @pytest.mark.parametrize("M, N, K, a_type, b_type, out_dtype", [(*shape, *ab_type, out_dtype) for shape in [[128, 256, 32], @@ -1071,6 +1123,10 @@ def test_fp8_fpN_roundtrip(in_dtype, 
out_dtype, device): [tl.float8e4b15, tl.float16], [tl.float8e4b15x4, tl.float16], [tl.float8e5, tl.float16], + [tl.float8e4b8, tl.float16], + [tl.float8e5b16, tl.float16], + [tl.float16, tl.float8e5b16], + [tl.float16, tl.float8e4b8], [tl.float16, tl.float8e4nv], [tl.float16, tl.float8e4b15], [tl.float16, tl.float8e4b15x4], @@ -1078,17 +1134,8 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): for out_dtype in [torch.float16, torch.float32] ]) def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'): - check_type_supported(out_dtype, device) - @triton.jit - def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - input = tl.load(input_ptr + offsets, mask=mask) - output = input - tl.store(output_ptr + offsets, output, mask=mask) - @triton.jit def matmul_kernel( a_ptr, b_ptr, c_ptr, @@ -1117,15 +1164,128 @@ def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'c a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. - accumulator += tl.dot(a, b, out_dtype=compute_type) + accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - c = accumulator + c = accumulator.to(compute_type) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. 
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + def matmul(a, b, c_type): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + + if c_type == torch.float16: + comp_type = tl.float16 + else: + comp_type = tl.float32 + + + c = torch.empty((M, N), device = a.device, dtype=c_type) + # 1D launch kernel where each block gets its own program. + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + compute_type = comp_type, + BLOCK_SIZE_M=32, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=64, + GROUP_SIZE_M=4, + num_stages=1, + num_warps=2, + ) + return c + + a, a_f16 = gen_input(M, K, a_type, 11, device=device) + b, b_f16 = gen_input(K, N, b_type, 22, device=device) + + # call torch function to compute gold + golden = torch.matmul(a_f16, b_f16) + + c = matmul(a, b, out_dtype) + torch.testing.assert_close(c.to(golden.dtype), golden, rtol=1e-2, atol=1e-2) + + +# @pytest.mark.skip(reason="Pytorch does not support the following types, so need to skip for now") +@pytest.mark.parametrize("M, N, K, a_type, b_type, out_dtype", + [(*shape, *ab_type, out_dtype) + for shape in [[128, 256, 32], + [128, 16, 32], + [32, 128, 64], + [128, 128, 64], + [64, 128, 128], + [32, 128, 64], + [64, 64, 32], + [32, 32, 128], + [128, 128, 64], + [64, 128, 128]] + for ab_type in [[tl.float8e4b8, tl.float8e4b8], + [tl.float8e5b16, tl.float8e4b8], + [tl.float8e4b8, tl.float8e5b16], + [tl.float8e5b16, tl.float8e5b16]] + for out_dtype in [torch.float32] + ]) +def test_gemm_amd_fp8_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'): + 
check_type_supported(out_dtype, device) + + if triton.language.semantic.gpu_matrix_core_version() != 3: + pytest.skip("fp8 data type is not available on hardware") + + @triton.jit + def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + compute_type:tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output matrix C with masks. 
@@ -1168,33 +1328,14 @@ def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'c return c - - def gen_input(M, N, d_type, seed, device='cuda'): - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - if d_type == tl.float16: - input = torch.randn((M, K), dtype=torch.float16, device=device) - input_f16 = input - else: # d_type is float8 - f8_tensor = torch.randn((M, N), dtype=torch.float32, device='cuda') * 10 - f8_tensor = f8_tensor.to(torch.int8) - # keep only two bits of exponent to avoid overflow - f8_tensor = f8_tensor & 0b00111111 - input = triton.reinterpret(f8_tensor, d_type) - input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16) - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - n_elements = f8_tensor.numel() - copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024) - return input, input_f16 - - a, a_f16 = gen_input(M, K, a_type, 11, device=device) + a, a_f16 = gen_input(M, K, a_type, 21, device=device) b, b_f16 = gen_input(K, N, b_type, 22, device=device) # call torch function to compute gold golden = torch.matmul(a_f16, b_f16) - c = matmul(a, b, out_dtype) - torch.testing.assert_close(c.to(golden.dtype), golden, rtol=1e-2, atol=6e-2) + + torch.testing.assert_close(golden, c.to(golden.dtype), rtol=1e-2, atol=2e-2) # --------------- @@ -1591,9 +1732,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn x = tl.load(Xs) y = tl.load(Ys) - if in_dtype is tl.float8e4b15 or in_dtype is tl.float8e5: + # if in_dtype is tl.float8e4b15 or in_dtype is tl.float8e5: # TODO change types when they are available - # if in_dtype is tl.float8e5b16 or in_dtype is tl.float8e4b8: + if in_dtype is tl.float8e5b16 or in_dtype is tl.float8e4b8: x = x.to(in_dtype) y = y.to(in_dtype) z = tl.dot(x, y, allow_tf32=ALLOW_TF32, out_dtype=out_dtype) @@ -1626,13 +1767,13 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, 
allow_tf32, in_dtype, o effective_in_dtype = tl.bfloat16 elif in_dtype == "float8e5m2fnuz": # TODO change types when they are available - effective_in_dtype = tl.float8e5 - # effective_in_dtype = tl.float8e5b16 + # effective_in_dtype = tl.float8e5 + effective_in_dtype = tl.float8e5b16 in_dtype = "float32" elif in_dtype == "float8e4m3fnuz": # TODO change types when they are available - effective_in_dtype = tl.float8e4b15 - # effective_in_dtype = tl.float8e4b8 + # effective_in_dtype = tl.float8e4b15 + effective_in_dtype = tl.float8e4b8 in_dtype = "float32" else: assert("unexpected in dtype") @@ -1655,10 +1796,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') if effective_in_dtype.is_fp8(): - if effective_in_dtype.is_fp8e5(): + x = x + 1 + y = y + 1 + if effective_in_dtype.is_fp8e5b16(): mask = 0b111111000110 << 20 else: - mask = 0b111110000111 << 20 + mask = 0b101111000111 << 20 x = (x.view('uint32') & np.uint32(mask)).view('float32') y = (y.view('uint32') & np.uint32(mask)).view('float32') x_tri = to_triton(x, device=device) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 328474642..aa7dfcbfb 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1070,7 +1070,9 @@ def str_to_ty(name): return language.pointer_type(ty) tys = { "fp8e4nv": language.float8e4nv, + "fp8e4b8": language.float8e4b8, "fp8e5": language.float8e5, + "fp8e5b16": language.float8e5b16, "fp8e4b15": language.float8e4b15, "fp8e4b15x4": language.float8e4b15x4, "fp16": language.float16, diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 6b719c684..dedb1b919 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -58,7 +58,9 @@ from .core import ( float8e4b15, 
float8e4b15x4, float8e4nv, + float8e4b8, float8e5, + float8e5b16, function_type, inline_asm_elementwise, int1, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 0c73c1f8b..c452aedfd 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -75,7 +75,7 @@ def _to_tensor(x, builder): class dtype: SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] - FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64'] + FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] OTHER_TYPES = ['void'] @@ -107,10 +107,18 @@ class dtype: self.fp_mantissa_width = 3 self.primitive_bitwidth = 8 self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 8 elif name == 'fp8e5': self.fp_mantissa_width = 2 self.primitive_bitwidth = 8 self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 16 elif name == 'fp16': self.fp_mantissa_width = 10 self.primitive_bitwidth = 16 @@ -138,6 +146,9 @@ class dtype: def is_fp8e4nv(self): return self.name == 'fp8e4nv' + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + def is_fp8e4b15(self): return self.name == 'fp8e4b15' @@ -147,6 +158,9 @@ class dtype: def is_fp8e5(self): return self.name == 'fp8e5' + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + def is_fp16(self): return self.name == 'fp16' @@ -250,8 +264,12 @@ class dtype: return builder.get_int64_ty() elif self.name == 'fp8e5': return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() elif self.name == 'fp8e4nv': return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() elif self.name == 'fp8e4b15': return 
builder.get_fp8e4b15_ty() elif self.name == 'fp8e4b15x4': @@ -388,7 +406,9 @@ uint16 = dtype('uint16') uint32 = dtype('uint32') uint64 = dtype('uint64') float8e5 = dtype('fp8e5') +float8e5b16 = dtype('fp8e5b16') float8e4nv = dtype('fp8e4nv') +float8e4b8 = dtype('fp8e4b8') float8e4b15 = dtype('fp8e4b15') float8e4b15x4 = dtype('fp8e4b15x4') float16 = dtype('fp16') diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 5aebb53e5..d346d51f1 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -247,7 +247,13 @@ class JITFunction(KernelInterface[T]): tys = { "bool": "i1", "float8e4nv": "fp8e4nv", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", "float8e5": "fp8e5", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", "float8e4b15": "fp8e4b15", "float8e4b15x4": "fp8e4b15x4", "float16": "fp16", diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index d0019b1fe..73e072ea9 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -17,6 +17,9 @@ import torch import triton import triton.language as tl +torch_dtype:tl.constexpr = torch.float16 +# torch_dtype:tl.constexpr = torch.float8_e5m2fnuz +TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') @triton.jit def max_fn(x, y): @@ -67,7 +70,7 @@ def _attn_fwd_inner( acc = acc * alpha[:, None] if not pre_load_v: v = tl.load(V_block_ptr) - acc += tl.dot(p.to(tl.float16), v) + acc += tl.dot(p.to(v.dtype), v) # -- update m_i and l_i l_ij = tl.sum(p, 1) l_i = l_i * alpha + l_ij @@ -144,7 +147,7 @@ def _attn_fwd( qk_scale = sm_scale * 1.44269504 # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs q = tl.load(Q_block_ptr) - q = (q * qk_scale).to(tl.float16) + q = (q * qk_scale).to(q.dtype) # stage 1: off-band # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE # For causal = False, 
STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE @@ -272,7 +275,7 @@ def _bwd_kernel( p = tl.math.exp2(qk * qk_scale - l_i[:, None]) # compute dv do = tl.load(do_ptrs) - dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + dv += tl.dot(tl.trans(p.to(do.dtype)), do) # compute dp = dot(v, do) Di = tl.load(D_ptrs + offs_m_curr) dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] @@ -357,7 +360,7 @@ def _bwd_kernel_dk_dv( qk_scale = sm_scale * 1.44269504 # load k and v: they will stay in SRAM throughout k = tl.load(K_block_ptr) - k = (k * qk_scale).to(tl.float16) + k = (k * qk_scale).to(k.dtype) v = tl.load(V_block_ptr) dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -378,7 +381,7 @@ def _bwd_kernel_dk_dv( l_i = tl.load(l_ptrs + offs_m_curr) p = tl.math.exp2(qk - l_i) # -- compute dv ---- - dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + dv += tl.dot(tl.trans(p.to(do.dtype)), do) # compute dp = dot(v, do) Di = tl.load(D_ptrs + offs_m_curr) dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di @@ -407,7 +410,7 @@ def _bwd_kernel_dk_dv( block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) - tl.store(DK_block_ptr, (dk * sm_scale).to(tl.float16)) + tl.store(DK_block_ptr, (dk * sm_scale).to(DK.dtype.element_ty)) tl.store(DV_block_ptr, dv.to(tl.float16)) @triton.jit @@ -469,7 +472,7 @@ def _bwd_kernel_dq( qk_scale = sm_scale * 1.44269504 # load q and do: they will stay in SRAM throughout q = tl.load(Q_block_ptr) - q = (q * qk_scale).to(tl.float16) + q = (q * qk_scale).to(q.dtype) do = tl.load(DO_block_ptr) Di = tl.load(D_ptrs + offs_m) l_i = tl.load(l_ptrs + offs_m) @@ -518,7 +521,7 @@ class _attention(torch.autograd.Function): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} - o = torch.empty_like(q) + o = torch.empty_like(q, dtype=v.dtype) if torch.version.hip is None: BLOCK_M = 128 BLOCK_N = 64 if Lk <= 64 else 32 @@ 
-569,6 +572,7 @@ class _attention(torch.autograd.Function): q, k, v, o, L = ctx.saved_tensors do = do.contiguous() dq = torch.zeros_like(q, dtype=torch.float32) + # dk = torch.empty_like(k, dtype=torch_dtype) dk = torch.empty_like(k) dv = torch.empty_like(v) delta = torch.empty_like(L) @@ -648,26 +652,17 @@ attention = _attention.apply @pytest.mark.parametrize('causal', [False, True]) def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): torch.manual_seed(20) - q = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - k = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - v = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + if TORCH_HAS_FP8E5: + q = q.to(torch_dtype) + k = k.to(torch_dtype) sm_scale = 0.5 - dout = torch.randn_like(q) + dout = torch.randn_like(q, dtype=torch.float16) # reference implementation M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale if causal: p[:, :, M == 0] = float("-inf") p = torch.softmax(p.float(), dim=-1).half() @@ -675,7 +670,7 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): # triton implementation tri_out = attention(q, k, v, causal, sm_scale) # compare - assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', @@ 
-690,7 +685,8 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - sm_scale = 0,5 + + sm_scale = 0.5 split_kernel = True dout = torch.randn_like(q) # reference implementation @@ -777,6 +773,9 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if mode == "fwd": + q = q.to(torch_dtype) + k = k.to(torch_dtype) sm_scale = 1.3 fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel) if mode == 'bwd':