support gemm fp8/fp16 mixed input (#333)

* changes to support fp8/fp16 mixed inputs

* add unit test for fp8/fp16 mixed input for gemm
This commit is contained in:
Shucai Xiao
2023-09-27 08:00:31 -05:00
committed by GitHub
parent 0a7b1c7c12
commit 6e82aa8dbc
9 changed files with 205 additions and 14 deletions

View File

@@ -15,6 +15,7 @@ Type getShemPtrTy(Type elemTy) {
auto ctx = elemTy.getContext();
return ptr_ty(type::i16Ty(ctx), 3);
}
return ptr_ty(elemTy, 3);
}
@@ -439,9 +440,9 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
std::max<int>(mfmaInstrM * mfmaInstrK / iWaveSize /*wave size*/, 1);
unsigned int maxNumWarps = shape[0] / mfmaInstrM;
int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps);
aElemTy = typeConverter->convertType(aElemTy);
SmallVector<Value> ha;
if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
SmallVector<Value> offsets;
@@ -459,7 +460,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
Type smemPtrTy = getShemPtrTy(aElemTy);
Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
Type resElemTy = typeConverter->convertType(aElemTy);
int loadsPerThread = offsets.size() / (numRepM * numRepK);
const int elemsPerLoad = numOfElems / loadsPerThread;
@@ -500,7 +501,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
numReps, smemObj, sharedLayout);
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
Type resElemTy = typeConverter->convertType(aElemTy);
Type smemPtrTy = getShemPtrTy(aElemTy);
@@ -585,9 +586,9 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
unsigned int maxNumWarps = shape[1] / mfmaInstrN;
int warpsPerGroupN = std::min(warpsPerCTA[1], maxNumWarps);
bElemTy = typeConverter->convertType(bElemTy);
SmallVector<Value> hb;
if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
@@ -608,8 +609,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;
Type resElemTy = typeConverter->convertType(bElemTy);
Type smemPtrTy = getShemPtrTy(bElemTy);
const int loadsPerThread = offsets.size() / (numRepN * numRepK);
@@ -651,9 +651,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
numReps, smemObj, sharedLayout);
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;
Type resElemTy = typeConverter->convertType(bElemTy);
Type smemPtrTy = getShemPtrTy(bElemTy);
int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]);

View File

@@ -69,7 +69,6 @@ Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
a0 = bitcast(a0, i32_ty);
Value a1 = undef(fp8x4VecTy);
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
@@ -862,8 +861,10 @@ inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues,
if (!tensorTy)
return inValues;
auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
if (!(encoding && encoding.getParent().isa<MmaEncodingAttr>()))
if (!(encoding && (encoding.getParent().isa<MmaEncodingAttr>() or
encoding.getParent().isa<MfmaEncodingAttr>()))) {
return inValues;
}
SmallVector<Value> outValues;
for (auto v : inValues) {
// cast i32 to appropriate eltType vector and extract elements
@@ -997,6 +998,7 @@ public:
Location loc = op->getLoc();
// element type
auto resultElementTy = getElementTypeOrSelf(resultTy);
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
SmallVector<SmallVector<Value>> allOperands;
for (auto operand : adaptor.getOperands()) {
@@ -1025,12 +1027,15 @@ public:
}
it += curr.size();
}
if (op->getNumOperands() > 0) {
auto argTy = op->getOperand(0).getType();
resultVals = reorderValues(resultVals, argTy, resultTy);
}
resultVals =
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
resultVals = this->getTypeConverter()->packMfmaOperand(resultVals, resultTy, rewriter, loc);
Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);

View File

@@ -281,6 +281,7 @@ public:
// then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y
// This means that we can use some immediate offsets for shared memory
// operations.
resElemTy = getTypeConverter()->convertType(resElemTy);
auto dstPtrTy = ptr_ty(resElemTy, 3);
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);

View File

@@ -79,6 +79,37 @@ Value TritonGPUToLLVMTypeConverter::packLLElements(
return llvmStruct;
}
SmallVector<Value> TritonGPUToLLVMTypeConverter::packMfmaOperand(
const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = srcTy.dyn_cast<RankedTensorType>();
if (!tensorTy)
return inValues;
auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
if (!(encoding && encoding.getParent().isa<MfmaEncodingAttr>())) {
return inValues;
}
auto structType = this->convertType(srcTy).dyn_cast<LLVM::LLVMStructType>();
auto elementTypes = structType.getBody();
assert(elementTypes.size() > 0);
mlir::VectorType vecTy = elementTypes[0].dyn_cast<mlir::VectorType>();
if (!vecTy) return inValues;
unsigned size = vecTy.getNumElements();
SmallVector<Value> result;
for (int i = 0; i < inValues.size(); i += size) {
Value valVec = undef(vecTy);
for (unsigned j = 0; j < size; ++j) {
valVec = insert_element(vecTy, valVec, inValues[i + j], i32_val(j));
}
result.push_back(valVec);
}
return result;
}
SmallVector<Value> TritonGPUToLLVMTypeConverter::unpackLLElements(
Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter,
Type type) {

View File

@@ -26,6 +26,10 @@ public:
Type type);
Type convertTritonTensorType(RankedTensorType type);
SmallVector<Value> packMfmaOperand(
const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc);
};
#endif

View File

@@ -289,4 +289,10 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
}
} // namespace LLVM
bool isF8(Type eType) {
return eType.isFloat8E5M2FNUZ() or eType.isFloat8E4M3FNUZ() or
eType.isFloat8E5M2() or eType.isFloat8E5M2FNUZ();
}
} // namespace mlir

View File

@@ -300,6 +300,9 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
StringRef key, StringRef content);
} // namespace LLVM
bool isF8(Type eType);
} // namespace mlir
#endif

View File

@@ -1043,11 +1043,148 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024)
assert torch.all(tri_fp8 == ref_fp8)
@pytest.mark.parametrize("M, N, K, a_type, b_type, out_dtype",
[(*shape, *ab_type, out_dtype)
for shape in [[128, 256, 32],
[128, 16, 32],
[32, 128, 64],
[128, 128, 64],
[64, 128, 128],
[32, 128, 64],
[64, 64, 32],
[32, 32, 128],
[128, 128, 64],
[64, 128, 128]]
for ab_type in [[tl.float8e4, tl.float16],
[tl.float8e5, tl.float16],
[tl.float16, tl.float8e4],
[tl.float16, tl.float8e5]]
for out_dtype in [torch.float16, torch.float32]
])
def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'):
check_type_supported(out_dtype, device)
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
compute_type:tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b, out_dtype=compute_type)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def matmul(a, b, c_type):
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
M, K = a.shape
K, N = b.shape
if c_type == torch.float16:
comp_type = tl.float16
else:
comp_type = tl.float32
c = torch.empty((M, N), device = a.device, dtype=c_type)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
compute_type = comp_type,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=64,
BLOCK_SIZE_K=64,
GROUP_SIZE_M=4,
num_stages=1,
num_warps=2,
)
return c
def gen_input(M, N, d_type, seed, device='cuda'):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if d_type == tl.float16:
input = torch.randn((M, K), dtype=torch.float16, device=device)
input_f16 = input
else: # d_type is float8
f8_tensor = torch.randn((M, N), dtype=torch.float32, device='cuda') * 10
f8_tensor = f8_tensor.to(torch.int8)
# keep only two bits of exponent to avoid overflow
f8_tensor = f8_tensor & 0b00111111
input = triton.reinterpret(f8_tensor, d_type)
input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
n_elements = f8_tensor.numel()
copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024)
return input, input_f16
a, a_f16 = gen_input(M, K, a_type, 11, device=device)
b, b_f16 = gen_input(K, N, b_type, 22, device=device)
# call torch function to compute gold
golden = torch.matmul(a_f16, b_f16)
c = matmul(a, b, out_dtype)
torch.testing.assert_close(c.to(golden.dtype), golden, rtol=1e-2, atol=6e-2)
# ---------------
# test reduce
# ---------------
def get_reduced_dtype(dtype_str, op):
if op in ('argmin', 'argmax'):
return 'int32'

View File

@@ -1301,13 +1301,19 @@ def dot(lhs: tl.tensor,
out_dtype: tl.dtype,
builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
assert lhs.dtype == rhs.dtype, f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or (lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()), f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16,\
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
if lhs.type.scalar.is_fp8():
lhs = cast(lhs, tl.float16, builder)
elif rhs.type.scalar.is_fp8():
rhs = cast(rhs, tl.float16, builder)
if lhs.type.scalar.is_int():
assert lhs.type.scalar == tl.int8, "only int8 supported!"
# TODO: This is CUDA specific, check if ROCm has the same limitation