mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] allow mixed precision FP8 matmul on pre-H100 hardware (#2281)
This commit is contained in:
@@ -384,7 +384,8 @@ bool supportMMA(Value value, int version) {
|
||||
// FP8 is not natively supported on all mma versions but it can always be
|
||||
// promoted to fp16 therefore we can always support it.
|
||||
bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
|
||||
elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
|
||||
elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ() ||
|
||||
elemTy.isFloat8E4M3B11FNUZ();
|
||||
return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
|
||||
(elemTy.isF32() && version >= 2) ||
|
||||
(elemTy.isInteger(8) && version >= 2);
|
||||
|
||||
@@ -845,7 +845,7 @@ private:
|
||||
bool isNativeHopperFP8 =
|
||||
AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
|
||||
bool isFP8 = isNativeHopperFP8 || AElType.isFloat8E5M2FNUZ() ||
|
||||
AElType.isFloat8E4M3FN();
|
||||
AElType.isFloat8E4M3FN() || AElType.isFloat8E4M3B11FNUZ();
|
||||
if (!isFP8 || (isNativeHopperFP8 && mmaLayout.isHopper()))
|
||||
return;
|
||||
promoteType = builder.getF16Type();
|
||||
|
||||
@@ -1266,6 +1266,8 @@ def dot(lhs: tl.tensor,
|
||||
# Checks for cuda arch
|
||||
if arch < 90:
|
||||
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90"
|
||||
if lhs_dtype.is_fp8() and rhs_dtype.is_fp8():
|
||||
return
|
||||
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
|
||||
else:
|
||||
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
|
||||
|
||||
Reference in New Issue
Block a user