mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] Fix trans for float8e4b15 (#1964)
float8e4b15 is a packed type; it is incompatible with most of our layout conversions. For now, we just convert to float16.
This commit is contained in:
@@ -439,12 +439,15 @@ private:
|
||||
}
|
||||
// Potentially we need to store for multiple CTAs in this replication
|
||||
auto accumNumReplicates = product<unsigned>(numReplicates);
|
||||
// unsigned elems = getTotalElemsPerThread(srcTy);
|
||||
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
unsigned inVec = 0;
|
||||
unsigned outVec = 0;
|
||||
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
|
||||
if (getElementTypeOrSelf(op.getType()).isa<mlir::Float8E4M3B11FNUZType>()) {
|
||||
assert(inVec % 4 == 0 && "conversion not supported for FP8E4M3B15");
|
||||
assert(outVec % 4 == 0 && "conversion not supported for FP8E4M3B15");
|
||||
}
|
||||
|
||||
unsigned outElems = getTotalElemsPerThread(dstTy);
|
||||
auto outOrd = getOrder(dstLayout);
|
||||
|
||||
@@ -313,6 +313,7 @@ public:
|
||||
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
|
||||
|
||||
// Preprocess
|
||||
decomposeFp8e4b15Convert(mod);
|
||||
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
@@ -442,6 +443,33 @@ private:
|
||||
allocation.getSharedMemorySize()));
|
||||
}
|
||||
|
||||
/// Rewrites each matching `ConvertLayoutOp` in `mod` into a three-op chain:
/// `FpToFpOp` (operand -> f16), `ConvertLayoutOp` on the f16 tensor, and
/// `FpToFpOp` (f16 -> original result type).
///
/// A conversion is rewritten only when its element type is
/// `Float8E4M3B11FNUZType` and neither its operand nor its result uses a
/// `DotOperandEncodingAttr`; every other conversion is left untouched.
/// NOTE(review): the function name suggests this element type stands in for
/// fp8e4b15 -- confirm against the frontend type mapping.
void decomposeFp8e4b15Convert(ModuleOp mod) const {
  mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
    // Ignore conversions of any other element type.
    if (!getElementTypeOrSelf(cvtOp).isa<mlir::Float8E4M3B11FNUZType>())
      return;

    auto dstTensorTy = cvtOp.getType().cast<RankedTensorType>();
    auto srcTensorTy = cvtOp.getOperand().getType().cast<RankedTensorType>();
    auto srcEncoding = srcTensorTy.getEncoding();
    auto dstEncoding = dstTensorTy.getEncoding();

    // Conversions touching a dot-operand layout are not decomposed here.
    const bool touchesDotOperand =
        srcEncoding.isa<triton::gpu::DotOperandEncodingAttr>() ||
        dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>();
    if (touchesDotOperand)
      return;

    OpBuilder builder(cvtOp);
    auto loc = cvtOp.getLoc();
    auto halfTy = builder.getF16Type();
    auto tensorShape = dstTensorTy.getShape();

    // Same shape and encodings as the original conversion, f16 elements.
    auto halfSrcTy = RankedTensorType::get(tensorShape, halfTy, srcEncoding);
    auto halfDstTy = RankedTensorType::get(tensorShape, halfTy, dstEncoding);

    auto widened = builder.create<mlir::triton::FpToFpOp>(loc, halfSrcTy,
                                                          cvtOp.getOperand());
    auto halfCvt = builder.create<mlir::triton::gpu::ConvertLayoutOp>(
        loc, halfDstTy, widened);
    auto narrowed = builder.create<mlir::triton::FpToFpOp>(
        loc, cvtOp.getType(), halfCvt.getResult());

    // Redirect all users to the new chain, then drop the original op.
    cvtOp.replaceAllUsesWith(narrowed.getResult());
    cvtOp.erase();
  });
}
|
||||
|
||||
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps,
|
||||
int threadsPerWarp) const {
|
||||
// Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
|
||||
|
||||
@@ -44,6 +44,9 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h
|
||||
x = rs.randint(low, high, shape, dtype=dtype)
|
||||
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
|
||||
return x
|
||||
elif dtype_str and 'float8' in dtype_str:
|
||||
x = rs.randint(20, 40, shape, dtype=np.int8)
|
||||
return x
|
||||
elif dtype_str in float_dtypes:
|
||||
return rs.normal(0, 1, shape).astype(dtype_str)
|
||||
elif dtype_str == 'bfloat16':
|
||||
@@ -67,6 +70,8 @@ def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrappe
|
||||
x_signed = x.astype(getattr(np, signed_type_name))
|
||||
return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
|
||||
else:
|
||||
if dst_type and 'float8' in dst_type:
|
||||
return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type))
|
||||
if t == 'float32' and dst_type == 'bfloat16':
|
||||
return torch.tensor(x, device=device).bfloat16()
|
||||
return torch.tensor(x, device=device)
|
||||
@@ -1276,6 +1281,20 @@ def serialize_fp8(np_data, in_dtype):
|
||||
else:
|
||||
return np_data
|
||||
|
||||
# inverse of `serialize_fp8`
|
||||
|
||||
|
||||
def deserialize_fp8(np_data, in_dtype):
    """Undo the bit shuffle applied by `serialize_fp8`.

    For `tl.float8e4b15`, `np_data` is reinterpreted as 32-bit words (four
    fp8 values per word). The sign bit and the seven remaining bits of each
    value are extracted from their positions in the word and moved into one
    byte lane each; the result is returned as an `np.int8` view. For any
    other `in_dtype`, `np_data` is returned unchanged.
    """
    if in_dtype != tl.float8e4b15:
        return np_data

    words = np_data.view(np.uint32)

    # (offset of the field inside the word, output byte lane) for each value.
    sign_offsets = (0, 16, 1, 17)
    bit_offsets = (1, 17, 8, 24)

    # Left-align each field at the top of the word, then shift it down into
    # its output byte lane (lane k occupies bits [31 - 8k, 24 - 8k]).
    sign_lanes = [((words & (0x80000000 >> off)) << off) >> (8 * lane)
                  for lane, off in enumerate(sign_offsets)]
    bit_lanes = [((words & (0x7f000000 >> off)) << off) >> (8 * lane)
                 for lane, off in enumerate(bit_offsets)]

    signs = sign_lanes[0] | sign_lanes[1] | sign_lanes[2] | sign_lanes[3]
    bits = bit_lanes[0] | bit_lanes[1] | bit_lanes[2] | bit_lanes[3]
    return (signs | bits).view(np.int8)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4, tl.float8e5])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
@@ -1901,7 +1920,7 @@ def test_generic_reduction(device):
|
||||
@pytest.mark.parametrize("dtype_str, shape, perm",
|
||||
[(dtype, shape, perm)
|
||||
# TODO: bfloat16
|
||||
for dtype in ['float16', 'float32']
|
||||
for dtype in ['float8e4b15', 'float16', 'float32']
|
||||
for shape in [(64, 64), (128, 128)]
|
||||
for perm in [(1, 0)]])
|
||||
def test_permute(dtype_str, shape, perm, device):
|
||||
@@ -1930,7 +1949,13 @@ def test_permute(dtype_str, shape, perm, device):
|
||||
z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
|
||||
BLOCK_M=shape[0], BLOCK_N=shape[1])
|
||||
# numpy result
|
||||
z_ref = x.transpose(*perm)
|
||||
if dtype_str == 'float8e4b15':
|
||||
ty = tl.float8e4b15
|
||||
z_ref = serialize_fp8(deserialize_fp8(x, ty).T.copy(), ty)
|
||||
z_tri = z_tri.base
|
||||
z_tri_contiguous = z_tri_contiguous.base
|
||||
else:
|
||||
z_ref = x.transpose(*perm)
|
||||
# compare
|
||||
np.testing.assert_allclose(to_numpy(z_tri), z_ref)
|
||||
np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref)
|
||||
|
||||
Reference in New Issue
Block a user