support gemm fp8/fp16 mixed input (#333)
* changes to support fp8/fp16 mixed inputs
* add unit test for fp8/fp16 mixed input for gemm
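In user terms, the change lets `tl.dot` take one fp8 operand and one fp16 operand; as the `def dot(...)` hunk at the end of this diff shows, the fp8 side is cast to fp16 before the dot is emitted. A minimal single-tile sketch of that usage, assuming a build with this patch (the kernel name, shapes, and host setup below are illustrative, not part of the commit):

import torch
import triton
import triton.language as tl


@triton.jit
def mixed_dot_kernel(x_ptr, w_ptr, out_ptr,
                     BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_N: tl.constexpr):
    # One fp8 tile times one fp16 tile, accumulated in fp32.
    offs_m = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)
    offs_n = tl.arange(0, BLOCK_N)
    x = tl.load(x_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])  # fp8 operand
    w = tl.load(w_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])  # fp16 operand
    acc = tl.dot(x, w, out_dtype=tl.float32)  # mixed dtypes accepted with this patch
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)


M = K = N = 32
x_int8 = torch.randint(0, 64, (M, K), dtype=torch.int8, device='cuda')  # small-exponent fp8 bit patterns
x_fp8 = triton.reinterpret(x_int8, tl.float8e4)
w_fp16 = torch.randn((K, N), dtype=torch.float16, device='cuda')
out = torch.empty((M, N), dtype=torch.float32, device='cuda')
mixed_dot_kernel[(1,)](x_fp8, w_fp16, out, BLOCK_M=M, BLOCK_K=K, BLOCK_N=N)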
@@ -15,6 +15,7 @@ Type getShemPtrTy(Type elemTy) {
    auto ctx = elemTy.getContext();
    return ptr_ty(type::i16Ty(ctx), 3);
  }

  return ptr_ty(elemTy, 3);
}

@@ -439,9 +440,9 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
      std::max<int>(mfmaInstrM * mfmaInstrK / iWaveSize /*wave size*/, 1);
  unsigned int maxNumWarps = shape[0] / mfmaInstrM;
  int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps);
  aElemTy = typeConverter->convertType(aElemTy);

  SmallVector<Value> ha;

  if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
    Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
    SmallVector<Value> offsets;
@@ -459,7 +460,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
    Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);

    Type smemPtrTy = getShemPtrTy(aElemTy);
    Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
    Type resElemTy = typeConverter->convertType(aElemTy);

    int loadsPerThread = offsets.size() / (numRepM * numRepK);
    const int elemsPerLoad = numOfElems / loadsPerThread;
@@ -500,7 +501,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
        numReps, smemObj, sharedLayout);

    Value smemBase = computeBasePtr(rewriter, loc, smemObj);
    Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
    Type resElemTy = typeConverter->convertType(aElemTy);

    Type smemPtrTy = getShemPtrTy(aElemTy);

@@ -585,9 +586,9 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,

  unsigned int maxNumWarps = shape[1] / mfmaInstrN;
  int warpsPerGroupN = std::min(warpsPerCTA[1], maxNumWarps);
  bElemTy = typeConverter->convertType(bElemTy);

  SmallVector<Value> hb;

  if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
    Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);

@@ -608,8 +609,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,

    Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);

    Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;

    Type resElemTy = typeConverter->convertType(bElemTy);
    Type smemPtrTy = getShemPtrTy(bElemTy);

    const int loadsPerThread = offsets.size() / (numRepN * numRepK);
@@ -651,9 +651,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
        numReps, smemObj, sharedLayout);

    Value smemBase = computeBasePtr(rewriter, loc, smemObj);

    Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;

    Type resElemTy = typeConverter->convertType(bElemTy);
    Type smemPtrTy = getShemPtrTy(bElemTy);

    int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]);

@@ -69,7 +69,6 @@ Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
  a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
  a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
  a0 = bitcast(a0, i32_ty);

  Value a1 = undef(fp8x4VecTy);
  a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
  a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
@@ -862,8 +861,10 @@ inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues,
  if (!tensorTy)
    return inValues;
  auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
  if (!(encoding && encoding.getParent().isa<MmaEncodingAttr>()))
  if (!(encoding && (encoding.getParent().isa<MmaEncodingAttr>() or
                     encoding.getParent().isa<MfmaEncodingAttr>()))) {
    return inValues;
  }
  SmallVector<Value> outValues;
  for (auto v : inValues) {
    // cast i32 to appropriate eltType vector and extract elements
@@ -997,6 +998,7 @@ public:
    Location loc = op->getLoc();
    // element type
    auto resultElementTy = getElementTypeOrSelf(resultTy);

    Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
    SmallVector<SmallVector<Value>> allOperands;
    for (auto operand : adaptor.getOperands()) {
@@ -1025,12 +1027,15 @@ public:
      }
      it += curr.size();
    }

    if (op->getNumOperands() > 0) {
      auto argTy = op->getOperand(0).getType();
      resultVals = reorderValues(resultVals, argTy, resultTy);
    }
    resultVals =
        packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
    resultVals = this->getTypeConverter()->packMfmaOperand(resultVals, resultTy, rewriter, loc);

    Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
                                                          rewriter, resultTy);
    rewriter.replaceOp(op, view);

@@ -281,6 +281,7 @@ public:
    // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y
    // This means that we can use some immediate offsets for shared memory
    // operations.
    resElemTy = getTypeConverter()->convertType(resElemTy);
    auto dstPtrTy = ptr_ty(resElemTy, 3);
    auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
    Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);

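The XOR identity in the comment above (pre-existing context in this hunk) holds whenever x and z occupy only the low bits and y occupies only higher bits, so the addition cannot carry into y's bit range. A quick standalone check of that identity (the 4-bit split is an illustrative assumption):

# (x + y) ^ z == (x ^ z) + y when x, z < 16 and y is a multiple of 16,
# because x + y never carries into y's bits and z only touches the low 4 bits.
for x in range(16):
    for z in range(16):
        for y in range(0, 256, 16):
            assert (x + y) ^ z == (x ^ z) + y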
@@ -79,6 +79,37 @@ Value TritonGPUToLLVMTypeConverter::packLLElements(
  return llvmStruct;
}

SmallVector<Value> TritonGPUToLLVMTypeConverter::packMfmaOperand(
    const SmallVector<Value> &inValues, Type srcTy,
    ConversionPatternRewriter &rewriter, Location loc) {
  auto tensorTy = srcTy.dyn_cast<RankedTensorType>();
  if (!tensorTy)
    return inValues;
  auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
  if (!(encoding && encoding.getParent().isa<MfmaEncodingAttr>())) {
    return inValues;
  }

  auto structType = this->convertType(srcTy).dyn_cast<LLVM::LLVMStructType>();
  auto elementTypes = structType.getBody();
  assert(elementTypes.size() > 0);
  mlir::VectorType vecTy = elementTypes[0].dyn_cast<mlir::VectorType>();
  if (!vecTy) return inValues;

  unsigned size = vecTy.getNumElements();

  SmallVector<Value> result;
  for (int i = 0; i < inValues.size(); i += size) {
    Value valVec = undef(vecTy);
    for (unsigned j = 0; j < size; ++j) {
      valVec = insert_element(vecTy, valVec, inValues[i + j], i32_val(j));
    }
    result.push_back(valVec);
  }

  return result;
}

SmallVector<Value> TritonGPUToLLVMTypeConverter::unpackLLElements(
    Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter,
    Type type) {

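`packMfmaOperand` above regroups the flat list of per-element LLVM values produced by elementwise lowering into vectors whose width matches the MFMA dot-operand struct's vector element type. A rough Python analogue of just the regrouping step (illustration only; `pack_mfma_operand` is not a Triton API):

def pack_mfma_operand(values, vec_width):
    # Group a flat list of scalars into consecutive chunks of vec_width,
    # mirroring how the C++ code builds one vector per `size` input values.
    assert len(values) % vec_width == 0
    return [values[i:i + vec_width] for i in range(0, len(values), vec_width)]


# Eight scalars with a vector width of 4 become two 4-wide operands.
print(pack_mfma_operand(list(range(8)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]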
@@ -26,6 +26,10 @@ public:
                               Type type);

  Type convertTritonTensorType(RankedTensorType type);

  SmallVector<Value> packMfmaOperand(
      const SmallVector<Value> &inValues, Type srcTy,
      ConversionPatternRewriter &rewriter, Location loc);
};

#endif

@@ -289,4 +289,10 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
}

} // namespace LLVM

bool isF8(Type eType) {
  return eType.isFloat8E5M2FNUZ() or eType.isFloat8E4M3FNUZ() or
         eType.isFloat8E5M2() or eType.isFloat8E5M2FNUZ();
}

} // namespace mlir

@@ -300,6 +300,9 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
                        StringRef key, StringRef content);

} // namespace LLVM

bool isF8(Type eType);

} // namespace mlir

#endif

@@ -1043,11 +1043,148 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
    copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024)
    assert torch.all(tri_fp8 == ref_fp8)


@pytest.mark.parametrize("M, N, K, a_type, b_type, out_dtype",
                         [(*shape, *ab_type, out_dtype)
                          for shape in [[128, 256, 32],
                                        [128, 16, 32],
                                        [32, 128, 64],
                                        [128, 128, 64],
                                        [64, 128, 128],
                                        [32, 128, 64],
                                        [64, 64, 32],
                                        [32, 32, 128],
                                        [128, 128, 64],
                                        [64, 128, 128]]
                          for ab_type in [[tl.float8e4, tl.float16],
                                          [tl.float8e5, tl.float16],
                                          [tl.float16, tl.float8e4],
                                          [tl.float16, tl.float8e5]]
                          for out_dtype in [torch.float16, torch.float32]
                          ])
def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'):

    check_type_supported(out_dtype, device)

    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        input = tl.load(input_ptr + offsets, mask=mask)
        output = input
        tl.store(output_ptr + offsets, output, mask=mask)

    @triton.jit
    def matmul_kernel(
            a_ptr, b_ptr, c_ptr,
            M, N, K,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            compute_type: tl.constexpr,
            # Meta-parameters
            BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
            GROUP_SIZE_M: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
            # We accumulate along the K dimension.
            accumulator += tl.dot(a, b, out_dtype=compute_type)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        c = accumulator

        # -----------------------------------------------------------
        # Write back the block of the output matrix C with masks.
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)

    def matmul(a, b, c_type):
        assert a.shape[1] == b.shape[0], "Incompatible dimensions"
        M, K = a.shape
        K, N = b.shape

        if c_type == torch.float16:
            comp_type = tl.float16
        else:
            comp_type = tl.float32


        c = torch.empty((M, N), device = a.device, dtype=c_type)
        # 1D launch kernel where each block gets its own program.
        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
        )
        matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            compute_type = comp_type,
            BLOCK_SIZE_M=32,
            BLOCK_SIZE_N=64,
            BLOCK_SIZE_K=64,
            GROUP_SIZE_M=4,
            num_stages=1,
            num_warps=2,
        )

        return c


    def gen_input(M, N, d_type, seed, device='cuda'):
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        if d_type == tl.float16:
            input = torch.randn((M, N), dtype=torch.float16, device=device)
            input_f16 = input
        else: # d_type is float8
            f8_tensor = torch.randn((M, N), dtype=torch.float32, device='cuda') * 10
            f8_tensor = f8_tensor.to(torch.int8)
            # keep only two bits of exponent to avoid overflow
            f8_tensor = f8_tensor & 0b00111111
            input = triton.reinterpret(f8_tensor, d_type)
            input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
            grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
            n_elements = f8_tensor.numel()
            copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024)
        return input, input_f16

    a, a_f16 = gen_input(M, K, a_type, 11, device=device)
    b, b_f16 = gen_input(K, N, b_type, 22, device=device)

    # call torch function to compute gold
    golden = torch.matmul(a_f16, b_f16)

    c = matmul(a, b, out_dtype)
    torch.testing.assert_close(c.to(golden.dtype), golden, rtol=1e-2, atol=6e-2)


# ---------------
# test reduce
# ---------------


def get_reduced_dtype(dtype_str, op):
    if op in ('argmin', 'argmax'):
        return 'int32'

@@ -1301,13 +1301,19 @@ def dot(lhs: tl.tensor,
        out_dtype: tl.dtype,
        builder: ir.builder) -> tl.tensor:
    assert lhs.type.is_block() and rhs.type.is_block()
    assert lhs.dtype == rhs.dtype, f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
    assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or (lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()), f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
    assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
    assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
    assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
    assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
        and rhs.shape[1].value >= 16,\
        f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"

    if lhs.type.scalar.is_fp8():
        lhs = cast(lhs, tl.float16, builder)
    elif rhs.type.scalar.is_fp8():
        rhs = cast(rhs, tl.float16, builder)

    if lhs.type.scalar.is_int():
        assert lhs.type.scalar == tl.int8, "only int8 supported!"
        # TODO: This is CUDA specific, check if ROCm has the same limitation
