Mirror of https://github.com/ROCm/ROCm.git (synced 2026-02-21 03:00:39 -05:00)
[BACKEND] Lift the restriction that float8e4b15 only supports row-col layout (#2212)
This commit is contained in:
@@ -511,14 +511,6 @@ std::function<void(int, int)> getLoadMatrixFn(
|
||||
const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
if (tensor.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType>()) {
|
||||
bool noTrans = (isA ^ (order[0] == 0));
|
||||
assert(noTrans && "float8e4b15 must have row-col layout");
|
||||
}
|
||||
|
||||
if (kWidth != (4 / elemBytes))
|
||||
assert(vecPhase == 1 || vecPhase == 4 * kWidth);
|
||||
|
||||
|
||||
@@ -85,6 +85,7 @@ def f8_to_f16(x, dtype):
|
||||
("float8e4nv", "float8e4nv"),
|
||||
("float8e5", "float8e4nv"),
|
||||
("float8e5", "float8e5"),
|
||||
("float8e4b15", "float8e4b15"),
|
||||
("float8e4nv", "float16"),
|
||||
("float16", "float8e5"),
|
||||
("float16", "float32"),
|
||||
@@ -105,17 +106,6 @@ def f8_to_f16(x, dtype):
|
||||
("bfloat16", "float32"),
|
||||
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
*[
|
||||
# float8e4b15 only supports row-col layout
|
||||
[
|
||||
(128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE, True),
|
||||
] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
|
||||
("float8e4b15", "float16"),
|
||||
("float16", "float8e4b15"),
|
||||
("float8e5", "float8e5"),
|
||||
("float8e4nv", "float8e4nv"),
|
||||
("int8", "int8")]
|
||||
]
|
||||
),
|
||||
)
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32):
|
||||
|
||||
Reference in New Issue
Block a user