[FRONTEND] allow mixed precision FP8 matmul on pre-H100 hardware (#2281)

This commit is contained in:
Philippe Tillet
2023-09-11 20:54:29 -07:00
committed by GitHub
parent a5e483652b
commit bf4f9375a7
3 changed files with 5 additions and 2 deletions

View File

@@ -384,7 +384,8 @@ bool supportMMA(Value value, int version) {
// FP8 is not natively supported on all mma versions, but it can always be
// promoted to fp16; therefore we can always support it.
bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ() ||
elemTy.isFloat8E4M3B11FNUZ();
return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && version >= 2) ||
(elemTy.isInteger(8) && version >= 2);

View File

@@ -845,7 +845,7 @@ private:
bool isNativeHopperFP8 =
AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
bool isFP8 = isNativeHopperFP8 || AElType.isFloat8E5M2FNUZ() ||
AElType.isFloat8E4M3FN();
AElType.isFloat8E4M3FN() || AElType.isFloat8E4M3B11FNUZ();
if (!isFP8 || (isNativeHopperFP8 && mmaLayout.isHopper()))
return;
promoteType = builder.getF16Type();

View File

@@ -1266,6 +1266,8 @@ def dot(lhs: tl.tensor,
# Checks for cuda arch
if arch < 90:
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90"
if lhs_dtype.is_fp8() and rhs_dtype.is_fp8():
return
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
else:
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90"