[FRONTEND][BACKEND] Fix transpose for float8e4b15 (#1964)

float8e4b15 is a packed type; it is incompatible with most of our layout conversions. For now, we just convert to float16.
This commit is contained in:
Philippe Tillet
2023-07-19 11:30:39 -07:00
committed by GitHub
parent 15ab48d407
commit 68124676c9
3 changed files with 59 additions and 3 deletions

View File

@@ -439,12 +439,15 @@ private:
}
// Potentially we need to store for multiple CTAs in this replication
auto accumNumReplicates = product<unsigned>(numReplicates);
// unsigned elems = getTotalElemsPerThread(srcTy);
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
if (getElementTypeOrSelf(op.getType()).isa<mlir::Float8E4M3B11FNUZType>()) {
assert(inVec % 4 == 0 && "conversion not supported for FP8E4M3B15");
assert(outVec % 4 == 0 && "conversion not supported for FP8E4M3B15");
}
unsigned outElems = getTotalElemsPerThread(dstTy);
auto outOrd = getOrder(dstLayout);

View File

@@ -313,6 +313,7 @@ public:
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
// Preprocess
decomposeFp8e4b15Convert(mod);
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
decomposeBlockedToDotOperand(mod);
decomposeInsertSliceAsyncOp(mod);
@@ -442,6 +443,33 @@ private:
allocation.getSharedMemorySize()));
}
// fp8e4b15 is a packed element type that most layout-conversion paths cannot
// handle directly; rewrite every eligible `convert_layout` on such tensors as
// a round trip through fp16: fp8 -> fp16, convert layout in fp16, fp16 -> fp8.
void decomposeFp8e4b15Convert(ModuleOp mod) const {
  mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
    // Only conversions whose element type is fp8e4b15 (modeled in MLIR as
    // Float8E4M3B11FNUZ) need this decomposition.
    if (!getElementTypeOrSelf(cvtOp).isa<mlir::Float8E4M3B11FNUZType>())
      return;
    auto srcTensorTy = cvtOp.getOperand().getType().cast<RankedTensorType>();
    auto dstTensorTy = cvtOp.getType().cast<RankedTensorType>();
    auto srcEnc = srcTensorTy.getEncoding();
    auto dstEnc = dstTensorTy.getEncoding();
    // Conversions involving dot-operand layouts are handled elsewhere;
    // leave them untouched.
    if (srcEnc.isa<triton::gpu::DotOperandEncodingAttr>() ||
        dstEnc.isa<triton::gpu::DotOperandEncodingAttr>())
      return;
    OpBuilder builder(cvtOp);
    auto shape = dstTensorTy.getShape();
    auto f16Ty = builder.getF16Type();
    auto upcastSrcTy = RankedTensorType::get(shape, f16Ty, srcEnc);
    auto upcastDstTy = RankedTensorType::get(shape, f16Ty, dstEnc);
    auto loc = cvtOp.getLoc();
    // fp8 -> fp16 in the source layout.
    auto toF16 = builder.create<mlir::triton::FpToFpOp>(loc, upcastSrcTy,
                                                        cvtOp.getOperand());
    // Do the layout conversion on the (non-packed) fp16 tensor.
    auto cvtF16 = builder.create<mlir::triton::gpu::ConvertLayoutOp>(
        loc, upcastDstTy, toF16);
    // fp16 -> fp8 back to the original result type.
    auto toF8 = builder.create<mlir::triton::FpToFpOp>(loc, cvtOp.getType(),
                                                       cvtF16.getResult());
    cvtOp.replaceAllUsesWith(toF8.getResult());
    cvtOp.erase();
  });
}
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps,
int threadsPerWarp) const {
// Replace `mma -> dot_op` with `mma -> blocked -> dot_op`

View File

@@ -44,6 +44,9 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h
x = rs.randint(low, high, shape, dtype=dtype)
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
return x
elif dtype_str and 'float8' in dtype_str:
x = rs.randint(20, 40, shape, dtype=np.int8)
return x
elif dtype_str in float_dtypes:
return rs.normal(0, 1, shape).astype(dtype_str)
elif dtype_str == 'bfloat16':
@@ -67,6 +70,8 @@ def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrappe
x_signed = x.astype(getattr(np, signed_type_name))
return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
else:
if dst_type and 'float8' in dst_type:
return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type))
if t == 'float32' and dst_type == 'bfloat16':
return torch.tensor(x, device=device).bfloat16()
return torch.tensor(x, device=device)
@@ -1276,6 +1281,20 @@ def serialize_fp8(np_data, in_dtype):
else:
return np_data
# inverse of `serialize_fp8`
def deserialize_fp8(np_data, in_dtype):
    """Undo `serialize_fp8`: unscramble packed float8e4b15 storage back into
    one fp8 value per byte.

    float8e4b15 is stored packed, four fp8 values per 32-bit word, with the
    sign bits and the 7-bit sign-less payloads laid out at fixed bit offsets
    within each word rather than byte-contiguously. This routine extracts
    both pieces per lane and reassembles four ordinary int8 bytes.

    :param np_data: numpy byte array of packed fp8 data; its total size is
        assumed to be a multiple of 4 bytes so the uint32 view is valid
        (TODO confirm against callers).
    :param in_dtype: triton element dtype; only ``tl.float8e4b15`` triggers
        unscrambling, anything else is returned unchanged.
    :return: numpy int8 array of deserialized fp8 values (or ``np_data``
        as-is for non-packed dtypes).
    """
    if in_dtype == tl.float8e4b15:
        # Reinterpret every 4 consecutive bytes as one 32-bit word.
        f8x4 = np_data.view(np.uint32)
        # Per output byte lane: isolate the sign bit of each packed value
        # (mask shifted right by i, result shifted left by i to renormalize
        # to bit 31). The shift list encodes the packed lane order.
        s = [(f8x4 & (0x80000000 >> i)) << i for i in [0, 16, 1, 17]]
        # Same idea for the 7-bit exponent+mantissa payloads (mask 0x7f...),
        # with a different per-lane offset list.
        b = [(f8x4 & (0x7f000000 >> i)) << i for i in [1, 17, 8, 24]]
        # Drop each lane's sign/payload into its destination byte of the word.
        signs = (s[0] >> 0) | (s[1] >> 8) | (s[2] >> 16) | (s[3] >> 24)
        bits = (b[0] >> 0) | (b[1] >> 8) | (b[2] >> 16) | (b[3] >> 24)
        # Recombine and view the words as individual int8 values again.
        return (signs | bits).view(np.int8)
    else:
        # Non-packed fp8 formats are already one value per byte.
        return np_data
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
@@ -1901,7 +1920,7 @@ def test_generic_reduction(device):
@pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm)
# TODO: bfloat16
for dtype in ['float16', 'float32']
for dtype in ['float8e4b15', 'float16', 'float32']
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device):
@@ -1930,7 +1949,13 @@ def test_permute(dtype_str, shape, perm, device):
z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
BLOCK_M=shape[0], BLOCK_N=shape[1])
# numpy result
z_ref = x.transpose(*perm)
if dtype_str == 'float8e4b15':
ty = tl.float8e4b15
z_ref = serialize_fp8(deserialize_fp8(x, ty).T.copy(), ty)
z_tri = z_tri.base
z_tri_contiguous = z_tri_contiguous.base
else:
z_ref = x.transpose(*perm)
# compare
np.testing.assert_allclose(to_numpy(z_tri), z_ref)
np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref)