From 6e82aa8dbca46c1656fe45e405ed57f2efe7d97c Mon Sep 17 00:00:00 2001
From: Shucai Xiao
Date: Wed, 27 Sep 2023 08:00:31 -0500
Subject: [PATCH] support gemm fp8/fp16 mixed input (#333)

* changes to support fp8/fp16 mixed inputs

* add unit test for fp8/fp16 mixed input for gemm
---
 .../SharedToDotOperandMFMA.cpp                |  16 +-
 .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp   |   9 +-
 .../TritonGPUToLLVM/TritonGPUToLLVMBase.h     |   1 +
 .../TritonGPUToLLVM/TypeConverter.cpp         |  31 ++++
 .../TritonGPUToLLVM/TypeConverter.h           |   4 +
 lib/Conversion/TritonGPUToLLVM/Utility.cpp    |   6 +
 lib/Conversion/TritonGPUToLLVM/Utility.h      |   3 +
 python/test/unit/language/test_core_amd.py    | 141 +++++++++++++++++-
 python/triton/language/semantic.py            |   8 +-
 9 files changed, 205 insertions(+), 14 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
index 01fa4c238..c3cc48748 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
@@ -15,6 +15,7 @@ Type getShemPtrTy(Type elemTy) {
     auto ctx = elemTy.getContext();
     return ptr_ty(type::i16Ty(ctx), 3);
   }
+
   return ptr_ty(elemTy, 3);
 }
 
@@ -439,9 +440,9 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
       std::max(mfmaInstrM * mfmaInstrK / iWaveSize /*wave size*/, 1);
   unsigned int maxNumWarps = shape[0] / mfmaInstrM;
   int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps);
+  aElemTy = typeConverter->convertType(aElemTy);
 
   SmallVector<Value> ha;
-
   if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
     Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
     SmallVector<Value> offsets;
@@ -459,7 +460,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
     Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
 
     Type smemPtrTy = getShemPtrTy(aElemTy);
-    Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
+    Type resElemTy = typeConverter->convertType(aElemTy);
 
     int loadsPerThread = offsets.size() / (numRepM * numRepK);
     const int elemsPerLoad = numOfElems / loadsPerThread;
@@ -500,7 +501,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
                     numReps, smemObj, sharedLayout);
 
     Value smemBase = computeBasePtr(rewriter, loc, smemObj);
-    Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
+    Type resElemTy = typeConverter->convertType(aElemTy);
 
     Type smemPtrTy = getShemPtrTy(aElemTy);
 
@@ -585,9 +586,9 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
   unsigned int maxNumWarps = shape[1] / mfmaInstrN;
   int warpsPerGroupN = std::min(warpsPerCTA[1], maxNumWarps);
 
+  bElemTy = typeConverter->convertType(bElemTy);
 
   SmallVector<Value> hb;
-
   if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
     Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
 
@@ -608,8 +609,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
 
     Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
 
-    Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;
-
+    Type resElemTy = typeConverter->convertType(bElemTy);
     Type smemPtrTy = getShemPtrTy(bElemTy);
 
     const int loadsPerThread = offsets.size() / (numRepN * numRepK);
@@ -651,9 +651,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
                     numReps, smemObj, sharedLayout);
 
     Value smemBase = computeBasePtr(rewriter, loc, smemObj);
-
-    Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;
-
+    Type resElemTy = typeConverter->convertType(bElemTy);
     Type smemPtrTy = getShemPtrTy(bElemTy);
 
     int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]);
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index 7c6e95aff..fdfa8767d 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -69,7 +69,6 @@ Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
   a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
   a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
   a0 = bitcast(a0, i32_ty);
-
   Value a1 = undef(fp8x4VecTy);
   a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
   a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
@@ -862,8 +861,10 @@ inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues,
   if (!tensorTy)
     return inValues;
   auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
-  if (!(encoding && encoding.getParent().isa<MmaEncodingAttr>()))
+  if (!(encoding && (encoding.getParent().isa<MmaEncodingAttr>() or
+                     encoding.getParent().isa<MfmaEncodingAttr>()))) {
     return inValues;
+  }
   SmallVector<Value> outValues;
   for (auto v : inValues) {
     // cast i32 to appropriate eltType vector and extract elements
@@ -997,6 +998,7 @@ public:
     Location loc = op->getLoc();
     // element type
     auto resultElementTy = getElementTypeOrSelf(resultTy);
+    Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
     SmallVector<SmallVector<Value>> allOperands;
     for (auto operand : adaptor.getOperands()) {
@@ -1025,12 +1027,15 @@ public:
       }
       it += curr.size();
     }
+
     if (op->getNumOperands() > 0) {
       auto argTy = op->getOperand(0).getType();
       resultVals = reorderValues(resultVals, argTy, resultTy);
     }
     resultVals = packI32(resultVals, resultTy, rewriter, loc,
                          this->getTypeConverter());
+    resultVals = this->getTypeConverter()->packMfmaOperand(resultVals, resultTy, rewriter, loc);
+
     Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
                                                           rewriter, resultTy);
     rewriter.replaceOp(op, view);
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
index 689dc498a..b09701272 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -281,6 +281,7 @@ public:
     // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y
     // This means that we can use some immediate offsets for shared memory
    // operations.
+    resElemTy = getTypeConverter()->convertType(resElemTy);
     auto dstPtrTy = ptr_ty(resElemTy, 3);
     auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
     Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
index 3f20337e8..29a2ebab8 100644
--- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -79,6 +79,37 @@ Value TritonGPUToLLVMTypeConverter::packLLElements(
   return llvmStruct;
 }
 
+SmallVector<Value> TritonGPUToLLVMTypeConverter::packMfmaOperand(
+    const SmallVector<Value> &inValues, Type srcTy,
+    ConversionPatternRewriter &rewriter, Location loc) {
+  auto tensorTy = srcTy.dyn_cast<RankedTensorType>();
+  if (!tensorTy)
+    return inValues;
+  auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
+  if (!(encoding && encoding.getParent().isa<MfmaEncodingAttr>())) {
+    return inValues;
+  }
+
+  auto structType = this->convertType(srcTy).dyn_cast<LLVM::LLVMStructType>();
+  auto elementTypes = structType.getBody();
+  assert(elementTypes.size() > 0);
+  mlir::VectorType vecTy = elementTypes[0].dyn_cast<mlir::VectorType>();
+  if (!vecTy)
+    return inValues;
+
+  unsigned size = vecTy.getNumElements();
+
+  SmallVector<Value> result;
+  for (int i = 0; i < inValues.size(); i += size) {
+    Value valVec = undef(vecTy);
+    for (unsigned j = 0; j < size; ++j) {
+      valVec = insert_element(vecTy, valVec, inValues[i + j], i32_val(j));
+    }
+    result.push_back(valVec);
+  }
+
+  return result;
+}
+
 SmallVector<Value> TritonGPUToLLVMTypeConverter::unpackLLElements(
     Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter,
     Type type) {
diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.h b/lib/Conversion/TritonGPUToLLVM/TypeConverter.h
index 038363754..975808bb1 100644
--- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.h
+++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.h
@@ -26,6 +26,10 @@ public:
                         Type type);
 
   Type convertTritonTensorType(RankedTensorType type);
+
+  SmallVector<Value> packMfmaOperand(
+      const SmallVector<Value> &inValues, Type srcTy,
+      ConversionPatternRewriter &rewriter, Location loc);
 };
 
 #endif
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index 351d31d71..e7c56c942 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -289,4 +289,10 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
 }
 
 } // namespace LLVM
+
+bool isF8(Type eType) {
+  return eType.isFloat8E5M2FNUZ() or eType.isFloat8E4M3FNUZ() or
+         eType.isFloat8E5M2();
+}
+
 } // namespace mlir
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h
index be47ce4e2..1142951b1 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.h
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h
@@ -300,6 +300,9 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
                         StringRef key, StringRef content);
 
 } // namespace LLVM
+
+bool isF8(Type eType);
+
 } // namespace mlir
 
 #endif
diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py
index 191e76f71..c1facb092 100644
--- a/python/test/unit/language/test_core_amd.py
+++ b/python/test/unit/language/test_core_amd.py
@@ -1043,11 +1043,148 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
     copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024)
     assert torch.all(tri_fp8 == ref_fp8)
+
+@pytest.mark.parametrize("M, N, K, a_type, b_type, out_dtype",
out_dtype", + [(*shape, *ab_type, out_dtype) + for shape in [[128, 256, 32], + [128, 16, 32], + [32, 128, 64], + [128, 128, 64], + [64, 128, 128], + [32, 128, 64], + [64, 64, 32], + [32, 32, 128], + [128, 128, 64], + [64, 128, 128]] + for ab_type in [[tl.float8e4, tl.float16], + [tl.float8e5, tl.float16], + [tl.float16, tl.float8e4], + [tl.float16, tl.float8e5]] + for out_dtype in [torch.float16, torch.float32] + ]) +def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'): + + check_type_supported(out_dtype, device) + + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + @triton.jit + def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + compute_type:tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b, out_dtype=compute_type) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + def matmul(a, b, c_type): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + + if c_type == torch.float16: + comp_type = tl.float16 + else: + comp_type = tl.float32 + + + c = torch.empty((M, N), device = a.device, dtype=c_type) + # 1D launch kernel where each block gets its own program. 
+        grid = lambda META: (
+            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
+        )
+        matmul_kernel[grid](
+            a, b, c,
+            M, N, K,
+            a.stride(0), a.stride(1),
+            b.stride(0), b.stride(1),
+            c.stride(0), c.stride(1),
+            compute_type=comp_type,
+            BLOCK_SIZE_M=32,
+            BLOCK_SIZE_N=64,
+            BLOCK_SIZE_K=64,
+            GROUP_SIZE_M=4,
+            num_stages=1,
+            num_warps=2,
+        )
+
+        return c
+
+    def gen_input(M, N, d_type, seed, device='cuda'):
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        if d_type == tl.float16:
+            input = torch.randn((M, N), dtype=torch.float16, device=device)
+            input_f16 = input
+        else:  # d_type is float8
+            f8_tensor = torch.randn((M, N), dtype=torch.float32, device=device) * 10
+            f8_tensor = f8_tensor.to(torch.int8)
+            # keep only two bits of exponent to avoid overflow
+            f8_tensor = f8_tensor & 0b00111111
+            input = triton.reinterpret(f8_tensor, d_type)
+            input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
+            n_elements = f8_tensor.numel()
+            grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+            copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024)
+        return input, input_f16
+
+    a, a_f16 = gen_input(M, K, a_type, 11, device=device)
+    b, b_f16 = gen_input(K, N, b_type, 22, device=device)
+
+    # call torch function to compute gold
+    golden = torch.matmul(a_f16, b_f16)
+
+    c = matmul(a, b, out_dtype)
+    torch.testing.assert_close(c.to(golden.dtype), golden, rtol=1e-2, atol=6e-2)
+
+
 # ---------------
 # test reduce
 # ---------------
-
-
 def get_reduced_dtype(dtype_str, op):
     if op in ('argmin', 'argmax'):
         return 'int32'
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 94067d8ef..57a738593 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1301,13 +1301,19 @@ def dot(lhs: tl.tensor, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
     assert lhs.type.is_block() and rhs.type.is_block()
-    assert lhs.dtype == rhs.dtype, f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
+    assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or (lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()), f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype or be an fp8/fp16 mix!"
     assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
     assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
     assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
     assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
         and rhs.shape[1].value >= 16,\
         f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
+
+    if lhs.type.scalar.is_fp8():
+        lhs = cast(lhs, tl.float16, builder)
+    elif rhs.type.scalar.is_fp8():
+        rhs = cast(rhs, tl.float16, builder)
+
     if lhs.type.scalar.is_int():
         assert lhs.type.scalar == tl.int8, "only int8 supported!"
         # TODO: This is CUDA specific, check if ROCm has the same limitation
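
For reference, a minimal sketch of the user-facing behavior this patch enables, assuming a
Triton build that includes this change: tl.dot now accepts one fp8 and one fp16 operand, and
semantic.dot upcasts the fp8 side to fp16 before lowering. The snippet below is illustrative
only; the name mixed_dot_kernel and the chosen shapes are not part of the patch.

    # Illustrative sketch, not part of the patch. Assumes a GPU build of Triton
    # with this change applied; mixed_dot_kernel is a hypothetical example name.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def mixed_dot_kernel(a_ptr, b_ptr, c_ptr,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        # A is fp8 (reinterpreted from int8 on the host), B is fp16.
        a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
        # Before this patch, semantic.dot asserted that both operands share one dtype;
        # with it, the fp8 operand is cast to fp16 and the dot proceeds as fp16 x fp16.
        c = tl.dot(a, b)  # accumulates in fp32 by default
        tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c)

    # Example launch: reinterpret an int8 tensor as fp8 for the A operand, as the
    # unit test above does, keeping the exponent bits small to avoid overflow.
    a_i8 = torch.randint(0, 64, (32, 32), dtype=torch.int8, device='cuda')
    a_fp8 = triton.reinterpret(a_i8, tl.float8e5)
    b = torch.randn((32, 32), dtype=torch.float16, device='cuda')
    c = torch.empty((32, 32), dtype=torch.float32, device='cuda')
    mixed_dot_kernel[(1,)](a_fp8, b, c, BLOCK_M=32, BLOCK_N=32, BLOCK_K=32)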