[BACKEND] Lift the restriction that float8e4b15 only supports row-col layout (#2212)

This commit is contained in:
Philippe Tillet
2023-08-30 14:06:31 -07:00
committed by GitHub
parent 3175ee4ce7
commit ec51552fff
2 changed files with 1 addition and 19 deletions

View File

@@ -511,14 +511,6 @@ std::function<void(int, int)> getLoadMatrixFn(
const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
auto order = sharedLayout.getOrder();
if (tensor.getType()
.cast<RankedTensorType>()
.getElementType()
.isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType>()) {
bool noTrans = (isA ^ (order[0] == 0));
assert(noTrans && "float8e4b15 must have row-col layout");
}
if (kWidth != (4 / elemBytes))
assert(vecPhase == 1 || vecPhase == 4 * kWidth);

View File

@@ -85,6 +85,7 @@ def f8_to_f16(x, dtype):
("float8e4nv", "float8e4nv"),
("float8e5", "float8e4nv"),
("float8e5", "float8e5"),
("float8e4b15", "float8e4b15"),
("float8e4nv", "float16"),
("float16", "float8e5"),
("float16", "float32"),
@@ -105,17 +106,6 @@ def f8_to_f16(x, dtype):
("bfloat16", "float32"),
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
],
*[
# float8e4b15 only supports row-col layout
[
(128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE, True),
] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
("float8e4b15", "float16"),
("float16", "float8e4b15"),
("float8e5", "float8e5"),
("float8e4nv", "float8e4nv"),
("int8", "int8")]
]
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32):