#include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "ConvertLayoutOpToLLVM.h" #include "LoadStoreOpToLLVM.h" #include "Utility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" #include using namespace mlir; using namespace mlir::triton; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::getCTALayout; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; static CUtensorMapDataType getCUtensorMapDataType(Type ty) { if (ty.isF16()) { return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else if (ty.isBF16()) { return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else if (ty.isF32()) { return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else if (ty.getIntOrFloatBitWidth() == 8) { return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; } else { llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op"); return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } } // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass) : axisAnalysisPass(axisAnalysisPass) {} unsigned getContiguity(Value ptr) const { auto tensorTy = ptr.getType().dyn_cast(); if (!tensorTy) return 1; return axisAnalysisPass.getPtrContiguity(ptr); } unsigned getVectorSize(Value ptr) const { auto tensorTy = ptr.getType().dyn_cast(); if (!tensorTy) return 1; auto contiguity = getContiguity(ptr); auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); // The maximum vector size is 128 bits on NVIDIA GPUs. 
    return std::min(128 / pointeeBitWidth, contiguity);
  }

  unsigned getMaskAlignment(Value mask) const {
    return axisAnalysisPass.getMaskAlignment(mask);
  }

protected:
  ModuleAxisInfoAnalysis &axisAnalysisPass;
};

struct LoadOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
      public LoadStoreConversionBase {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;

  LoadOpConversion(TritonGPUToLLVMTypeConverter &converter,
                   ModuleAxisInfoAnalysis &axisAnalysisPass,
                   PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}

  LogicalResult
  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    // original values
    Value ptr = op.getPtr();
    Value mask = op.getMask();
    Value other = op.getOther();

    // adaptor values
    assert(!isTensorPointerType(ptr.getType()) &&
           "Cannot convert load with a tensor pointer into LLVM; "
           "this case should be transformed to normal load before lowering");
    Value llPtr = adaptor.getPtr();
    Value llMask = adaptor.getMask();
    Value llOther = adaptor.getOther();

    // Determine the vectorization size
    Type valueTy = op.getResult().getType();
    Type valueElemTy =
        typeConverter->convertType(getElementTypeOrSelf(valueTy));
    unsigned vec = getVectorSize(ptr);
    unsigned numElems = getTotalElemsPerThread(ptr.getType());
    if (llMask)
      vec = std::min(vec, getMaskAlignment(mask));

    // Get the LLVM values for pointers
    auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
                                                         ptr.getType());
    assert(ptrElems.size() == numElems);

    // Get the LLVM values for mask
    SmallVector<Value> maskElems;
    if (llMask) {
      maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
                                                       mask.getType());
      assert(maskElems.size() == numElems);
    }

    // Get the LLVM values for `other`
    // TODO: (goostavz) handle when other is const but not splat, which
    //       should be rarely seen
    bool otherIsSplatConstInt = false;
    DenseElementsAttr constAttr;
    int64_t splatVal = 0;
    if (other && valueElemTy.isa<IntegerType>() &&
        matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() &&
        constAttr.getElementType().isa<IntegerType>()) {
      otherIsSplatConstInt = true;
      splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
    }
    SmallVector<Value> otherElems;
    if (other) {
      otherElems = getTypeConverter()->unpackLLElements(loc, llOther, rewriter,
                                                        other.getType());
    }

    // vectorized iteration through all the pointer/mask/other elements
    const int valueElemNBits =
        std::max(8u, valueElemTy.getIntOrFloatBitWidth());
    const int numVecs = numElems / vec;

    SmallVector<Value> loadedVals;
    for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
      // TODO: optimization when ptr is GEP with constant offset
      size_t in_off = 0;

      const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
      const size_t totalWidth = valueElemNBits * vec;
      const size_t width = std::min(totalWidth, maxWordWidth);
      const size_t nWords = std::max<size_t>(1, totalWidth / width);
      const size_t wordNElems = width / valueElemNBits;
      const size_t movWidth = width < 16 ? 16 : width;
      assert(wordNElems * nWords * numVecs == numElems);

#ifdef USE_ROCM
      Value pred = mask ?
maskElems[vecStart] : int_val(1, 1); auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); auto loaded = rewriter.create( loc, pred, [&](OpBuilder& builder, Location loc) { Value ptr = addrspacecast(ptrElems[vecStart], ptr_ty(vecTy)); auto loadVal = rewriter.create(loc, ptr); builder.create(loc, ValueRange({loadVal})); }, [&](OpBuilder& builder, Location loc) { mlir::Attribute zero = builder.getZeroAttr(valueElemTy); auto denseValue = DenseElementsAttr::get(vecTy.cast(), zero); Value zeroVal = rewriter.create(loc, vecTy, denseValue); Value otherVal; if (other) { Value v = undef(vecTy); for (size_t s = 0; s < vec; ++s) { Value falseVal = otherElems[vecStart + s]; Value sVal = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), s); v = insert_element(vecTy, v, falseVal, sVal); } otherVal = v; } Value falseVal = other ? otherVal : zeroVal; builder.create(loc, ValueRange({falseVal})); }); Value loadVal = bitcast(loaded->getResult(0), vecTy); for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), ii % vec); Value loaded = extract_element(valueElemTy, loadVal, vecIdx); loadedVals.push_back(loaded); } #else // TODO(Superjomn) Add cache policy fields to StoreOp. // TODO(Superjomn) Deal with cache policy here. const bool hasL2EvictPolicy = false; PTXBuilder ptxBuilder; Value pred = mask ? maskElems[vecStart] : int_val(1, 1); const std::string readConstraint = (width == 64) ? "l" : ((width == 32) ? "r" : "c"); const std::string writeConstraint = (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); // prepare asm operands auto *dstsOpr = ptxBuilder.newListOperand(); for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { auto *opr = ptxBuilder.newOperand(writeConstraint, /*init=*/true); // =r operations dstsOpr->listAppend(opr); } auto *addrOpr = ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); // Define the instruction opcode auto &ld = ptxBuilder.create<>("ld") ->o("volatile", op.getIsVolatile()) .global() .o("ca", op.getCache() == triton::CacheModifier::CA) .o("cg", op.getCache() == triton::CacheModifier::CG) .o("L1::evict_first", op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) .o("L1::evict_last", op.getEvict() == triton::EvictionPolicy::EVICT_LAST) .o("L1::cache_hint", hasL2EvictPolicy) .v(nWords) .b(width); PTXBuilder::Operand *evictOpr{}; // Here lack a mlir::Value to bind to this operation, so disabled. 
// if (has_l2_evict_policy) // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); if (!evictOpr) ld(dstsOpr, addrOpr).predicate(pred, "b"); else ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); if (other) { for (size_t ii = 0; ii < nWords; ++ii) { // PTX doesn't support mov.u8, so we need to use mov.u16 PTXInstr &mov = ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth)); size_t size = width / valueElemNBits; auto vecTy = LLVM::getFixedVectorType(valueElemTy, size); Value v = undef(vecTy); for (size_t s = 0; s < size; ++s) { Value falseVal = otherElems[vecStart + ii * size + s]; Value sVal = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), s); v = insert_element(vecTy, v, falseVal, sVal); } v = bitcast(v, IntegerType::get(getContext(), width)); PTXInstr::Operand *opr{}; if (otherIsSplatConstInt) { for (size_t s = 0; s < 32; s += valueElemNBits) splatVal |= splatVal << valueElemNBits; opr = ptxBuilder.newConstantOperand(splatVal); } else opr = ptxBuilder.newOperand(v, readConstraint); mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b"); } } // Create inline ASM signature SmallVector retTys(nWords, IntegerType::get(getContext(), width)); Type retTy = retTys.size() > 1 ? LLVM::LLVMStructType::getLiteral(getContext(), retTys) : retTys[0]; // TODO: if (has_l2_evict_policy) // auto asmDialectAttr = // LLVM::AsmDialectAttr::get(rewriter.getContext(), // LLVM::AsmDialect::AD_ATT); Value ret = ptxBuilder.launch(rewriter, loc, retTy); // Extract and store return values SmallVector rets; for (unsigned int ii = 0; ii < nWords; ++ii) { Value curr; if (retTy.isa()) { curr = extract_val(IntegerType::get(getContext(), width), ret, ii); } else { curr = ret; } curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy, width / valueElemNBits)); rets.push_back(curr); } int tmp = width / valueElemNBits; for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp); Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx); loadedVals.push_back(loaded); } #endif } // end vec Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = getTypeConverter()->packLLElements( loc, loadedVals, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); return success(); } }; struct StoreOpConversion : public ConvertTritonGPUOpToLLVMPattern, public LoadStoreConversionBase { using ConvertTritonGPUOpToLLVMPattern< triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern; StoreOpConversion(TritonGPUToLLVMTypeConverter &converter, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value ptr = op.getPtr(); Value value = op.getValue(); Value llPtr = adaptor.getPtr(); Value llMask = adaptor.getMask(); Value llValue = adaptor.getValue(); auto loc = op->getLoc(); MLIRContext *ctx = rewriter.getContext(); auto valueTy = value.getType(); Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(valueTy)); unsigned vec = getVectorSize(ptr); unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter, ptr.getType()); auto valueElems = getTypeConverter()->unpackLLElements( loc, llValue, rewriter, value.getType()); assert(ptrElems.size() == 
valueElems.size()); // Determine the vectorization size SmallVector maskElems; if (llMask) { Value mask = op.getMask(); maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter, mask.getType()); assert(valueElems.size() == maskElems.size()); unsigned maskAlign = getMaskAlignment(mask); vec = std::min(vec, maskAlign); } Value mask = getMask(valueTy, rewriter, loc); const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); const size_t valueElemNBits = dtsize * 8; const int numVecs = elemsPerThread / vec; for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) { // TODO: optimization when ptr is AddPtr with constant offset size_t in_off = 0; const size_t maxWordWidth = std::max(32, valueElemNBits); const size_t totalWidth = valueElemNBits * vec; const size_t width = std::min(totalWidth, maxWordWidth); const size_t nWords = std::max(1, totalWidth / width); const size_t wordNElems = width / valueElemNBits; assert(wordNElems * nWords * numVecs == elemsPerThread); // TODO(Superjomn) Add cache policy fields to StoreOp. // TODO(Superjomn) Deal with cache policy here. Type valArgTy = IntegerType::get(ctx, width); auto wordTy = vec_ty(valueElemTy, wordNElems); SmallVector> asmArgs; for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { // llWord is a width-len composition Value llWord = undef(wordTy); // Insert each value element to the composition for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) { const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx; assert(elemOffset < valueElems.size()); Value elem = valueElems[elemOffset]; if (elem.getType().isInteger(1)) elem = sext(i8_ty, elem); elem = bitcast(elem, valueElemTy); llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx)); } llWord = bitcast(llWord, valArgTy); #ifdef USE_ROCM Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask; rewriter.create(loc, maskVal, [&](OpBuilder &builder, Location loc){ auto storeOp = builder.create(loc, llWord, ptrElems[vecStart + wordIdx * wordNElems]); builder.create(loc); }, nullptr); #else std::string constraint = (width == 64) ? "l" : ((width == 32) ? "r" : "c"); asmArgs.emplace_back(llWord, constraint); #endif } #ifndef USE_ROCM // Prepare the PTX inline asm. PTXBuilder ptxBuilder; auto *asmArgList = ptxBuilder.newListOperand(asmArgs); Value maskVal = llMask ? 
and_(mask, maskElems[vecStart]) : mask; auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); auto &ptxStoreInstr = ptxBuilder.create<>("st") ->global() .o("wb", op.getCache() == triton::CacheModifier::WB) .o("cg", op.getCache() == triton::CacheModifier::CG) .o("cs", op.getCache() == triton::CacheModifier::CS) .o("wt", op.getCache() == triton::CacheModifier::WT) .o("L1::evict_first", op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) .o("L1::evict_last", op.getEvict() == triton::EvictionPolicy::EVICT_LAST) .v(nWords) .b(width); ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b"); Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1)); llvm::SmallVector argTys({boolTy, ptr.getType()}); argTys.insert(argTys.end(), nWords, valArgTy); auto asmReturnTy = void_ty(ctx); ptxBuilder.launch(rewriter, loc, asmReturnTy); #endif } rewriter.eraseOp(op); return success(); } }; // TODO: refactor to save common logic with insertsliceasyncv2 struct StoreAsyncOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< triton::nvidia_gpu::StoreAsyncOp>::ConvertTritonGPUOpToLLVMPattern; StoreAsyncOpConversion(TritonGPUToLLVMTypeConverter &converter, ModuleAllocation &allocation, mlir::triton::gpu::TMAMetadataTy *tmaMetadata, const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( converter, allocation, tmaMetadata, benefit), tensorPtrMap(tensorPtrMap) {} LogicalResult matchAndRewrite(triton::nvidia_gpu::StoreAsyncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcTy = op.getSrc().getType().cast(); auto srcEncoding = srcTy.getEncoding(); if (srcEncoding.isa()) { return lowerStoreAsyncWithSlice(op, adaptor, rewriter); } else { return lowerStoreAsync(op, adaptor, rewriter); } } LogicalResult lowerStoreAsync(triton::nvidia_gpu::StoreAsyncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); MLIRContext *ctx = rewriter.getContext(); auto dst = op.getDst(); auto src = op.getSrc(); auto srcTy = src.getType().cast(); auto elemTy = srcTy.getElementType(); auto rank = srcTy.getRank(); // The sotre async op only supports tensor with ranke <= 5. 
// Reference: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format assert(rank > 0 && rank <= 5); auto moduleOp = op->getParentOfType(); assert(moduleOp && "Parent ModuleOp not found for StoreAsyncOp"); auto llFuncOp = op->getParentOfType(); assert(llFuncOp && "LLVMFuncOp not found for StoreAsyncOp"); int numTMADescs = getNumTMADescs(llFuncOp); assert(numTMADescs > 0); auto sharedLayout = srcTy.getEncoding().dyn_cast(); assert(sharedLayout && "expected shared encoding"); mlir::triton::gpu::TMAInfo tmaInfo; tmaInfo.tensorDataType = getCUtensorMapDataType(elemTy); tmaInfo.tensorRank = rank; assert(tmaMetadata); auto inOrder = sharedLayout.getOrder(); unsigned TMADescIdx = tmaMetadata->size(); unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments(); auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation()); auto dstOrder = makeTensorPtr.getOrder(); unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase()); tmaInfo.globalAddressArgIdx = globalAddressArgIdx; tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx; auto getDimOfOrder = [](ArrayRef order, int32_t i) { auto it = std::find(order.begin(), order.end(), i); assert(it != order.end()); return std::distance(order.begin(), it); }; std::vector globalDimsArgIdx; std::vector globalStridesArgIdx; // constant values are mapped to (-1 - value) for (int i = 0; i < rank; ++i) { int32_t argIdx = -1; auto dim = getDimOfOrder(dstOrder, i); argIdx = getArgIdx(makeTensorPtr.getShape()[dim]); globalDimsArgIdx.emplace_back(argIdx); // handle constant stride argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]); globalStridesArgIdx.emplace_back(argIdx); } tmaInfo.globalDimsArgIdx = globalDimsArgIdx; tmaInfo.globalStridesArgIdx = globalStridesArgIdx; std::vector boxDims; auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA(); auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder(); auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); auto tensorShape = makeTensorPtr.getResult() .getType() .cast() .getPointeeType() .cast() .getShape(); auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape); const uint32_t bytesPerCacheline = 128; uint32_t bytesPerElem = elemTy.getIntOrFloatBitWidth() / 8; uint32_t numBox{1}; for (int i = 0; i < rank; ++i) { auto dim = getDimOfOrder(dstOrder, i); auto tNumElems = shapePerCTA[dim]; if (i == 0 && tNumElems * bytesPerElem > bytesPerCacheline) { tNumElems = bytesPerCacheline / bytesPerElem; numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems; } boxDims.emplace_back(tNumElems); } std::vector elementStrides(rank, 1); tmaInfo.boxDims = boxDims; tmaInfo.elementStrides = elementStrides; CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; assert( ((elemTy.getIntOrFloatBitWidth() == 16 && sharedLayout.getVec() == 8) or (elemTy.getIntOrFloatBitWidth() == 32 && sharedLayout.getVec() == 4)) && "Unexpected shared layout for StoreAsyncOp"); if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B; else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B; else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B; else llvm::report_fatal_error("Unsupported shared layout for StoreAsyncOp"); tmaInfo.swizzle = swizzle; tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE; 
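    // The remaining descriptor fields (L2 promotion, OOB fill) are fixed
    // defaults; only the swizzle mode selected above depends on the shared
    // layout.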
tmaInfo.l2Promotion = CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B; tmaInfo.oobFill = CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; tmaMetadata->emplace_back(tmaInfo); Value llDst = adaptor.getDst(); Value llSrc = adaptor.getSrc(); auto srcShape = srcTy.getShape(); auto smemObj = getSharedMemoryObjectFromStruct(loc, llSrc, rewriter); SmallVector offsetVals; for (auto i = 0; i < srcShape.size(); ++i) { offsetVals.emplace_back(i32_val(0)); } Value tmaDesc = llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); auto ptrI8SharedTy = LLVM::LLVMPointerType::get( typeConverter->convertType(rewriter.getI8Type()), 3); auto threadId = getThreadId(rewriter, loc); Value pred = icmp_eq(threadId, i32_val(0)); auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter, dst.getType()); uint32_t boxStride = std::accumulate(boxDims.begin(), boxDims.end(), 1, std::multiplies()); Value clusterCTAId = getClusterCTAId(rewriter, loc); SmallVector multiDimClusterCTAId = delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); rewriter.create(loc, 0); for (uint32_t b = 0; b < numBox; ++b) { SmallVector coord; // raw coord for (int i = 0; i < rank; ++i) { auto dim = getDimOfOrder(dstOrder, i); coord.push_back(llCoord[dim]); } // coord with box and cta offset for (int i = 0; i < rank; ++i) { auto dim = getDimOfOrder(dstOrder, i); if (i == 0) { coord[i] = add(coord[i], i32_val(b * boxDims[i])); auto CTAOffset = mul(multiDimClusterCTAId[dim], i32_val(numBox * boxDims[i])); coord[i] = add(coord[i], CTAOffset); } else { coord[i] = add(coord[i], mul(multiDimClusterCTAId[dim], i32_val(boxDims[i]))); } } Value srcOffset = i32_val(b * boxStride); auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); Value srcPtrBase = gep(srcPtrTy, smemObj.base, srcOffset); auto addr = bitcast(srcPtrBase, ptrI8SharedTy); rewriter.create(loc, tmaDesc, addr, pred, coord); } rewriter.eraseOp(op); return success(); } LogicalResult lowerStoreAsyncWithSlice(triton::nvidia_gpu::StoreAsyncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); MLIRContext *ctx = rewriter.getContext(); auto dst = op.getDst(); auto src = op.getSrc(); auto srcTy = src.getType().cast(); auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation()); auto dstTensorTy = makeTensorPtr.getResult() .getType() .cast() .getPointeeType() .cast(); auto tensorShape = dstTensorTy.getShape(); auto dstOrder = makeTensorPtr.getOrder(); auto dstElemTy = dstTensorTy.getElementType(); auto rank = srcTy.getRank(); // The sotre async op only supports tensor with ranke <= 5. // Reference: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format assert(rank > 0 && rank <= 5); auto moduleOp = op->getParentOfType(); assert(moduleOp && "Parent ModuleOp not found for StoreAsyncOp"); auto llFuncOp = op->getParentOfType(); assert(llFuncOp && "LLVMFuncOp not found for StoreAsyncOp"); int numTMADescs = getNumTMADescs(llFuncOp); assert(numTMADescs > 0); auto ctaLayout = getCTALayout(dstTensorTy.getEncoding()); // The order of smem should be consistent with gmem. 
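    // (The shared encoding below is built with makeTensorPtr's order so that
    //  each TMA box is laid out contiguously in smem in the same order as
    //  gmem.)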
SmallVector sharedOrder; for (auto o : makeTensorPtr.getOrder()) { sharedOrder.emplace_back(o); } auto sharedLayout = SharedEncodingAttr::get(ctx, tensorShape, sharedOrder, ctaLayout, dstElemTy); mlir::triton::gpu::TMAInfo tmaInfo; tmaInfo.tensorDataType = getCUtensorMapDataType(dstElemTy); tmaInfo.tensorRank = rank; assert(tmaMetadata); unsigned TMADescIdx = tmaMetadata->size(); unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments(); unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase()); tmaInfo.globalAddressArgIdx = globalAddressArgIdx; tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx; auto getDimOfOrder = [](ArrayRef order, int32_t i) { auto it = std::find(order.begin(), order.end(), i); assert(it != order.end()); return std::distance(order.begin(), it); }; std::vector globalDimsArgIdx; std::vector globalStridesArgIdx; // constant values are mapped to (-1 - value) for (int i = 0; i < rank; ++i) { int32_t argIdx = -1; auto dim = getDimOfOrder(dstOrder, i); argIdx = getArgIdx(makeTensorPtr.getShape()[dim]); globalDimsArgIdx.emplace_back(argIdx); // handle constant stride argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]); globalStridesArgIdx.emplace_back(argIdx); } tmaInfo.globalDimsArgIdx = globalDimsArgIdx; tmaInfo.globalStridesArgIdx = globalStridesArgIdx; std::vector boxDims; auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA(); auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder(); auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape); auto srcLayout = srcTy.getEncoding(); auto mmaLayout = srcLayout.dyn_cast(); unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); auto instrShape = mmaLayout.getInstrShape(); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); uint32_t repM = ceil(shapePerCTA[0], instrShape[0] * warpsPerCTA[0]); uint32_t numElemsPerRep = numElems / repM; const uint32_t bytesPerCacheline = 128; uint32_t bytesPerElem = dstElemTy.getIntOrFloatBitWidth() / 8; uint32_t numBox{1}; for (int i = 0; i < rank; ++i) { auto dim = getDimOfOrder(dstOrder, i); auto tNumElems = shapePerCTA[dim]; if (i == 0 && tNumElems * bytesPerElem > bytesPerCacheline) { tNumElems = bytesPerCacheline / bytesPerElem; numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems; } if (i == 1) { tNumElems = tNumElems / repM / warpsPerCTA[0]; } boxDims.emplace_back(tNumElems); } std::vector elementStrides(rank, 1); tmaInfo.boxDims = boxDims; tmaInfo.elementStrides = elementStrides; CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; assert(((dstElemTy.getIntOrFloatBitWidth() == 16 && sharedLayout.getVec() == 8) or (dstElemTy.getIntOrFloatBitWidth() == 32 && sharedLayout.getVec() == 4)) && "Unexpected shared layout for StoreAsyncOp"); if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B; else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B; else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B; else llvm::report_fatal_error("Unsupported shared layout for StoreAsyncOp"); tmaInfo.swizzle = swizzle; tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE; tmaInfo.l2Promotion = CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B; tmaInfo.oobFill = CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; 
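    // The completed TMAInfo is recorded in tmaMetadata below; the host side
    // uses these records to encode the actual CUtensorMap descriptors that
    // are passed in as trailing kernel arguments (see TMADescArgIdx).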
tmaMetadata->emplace_back(tmaInfo); Value llDst = adaptor.getDst(); Value llSrc = adaptor.getSrc(); auto srcShape = srcTy.getShape(); auto dstElemPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3); Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); smemBase = bitcast(smemBase, dstElemPtrTy); SmallVector offsetVals; for (auto i = 0; i < srcShape.size(); ++i) { offsetVals.emplace_back(i32_val(0)); } Value tmaDesc = llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); auto ptrI8SharedTy = LLVM::LLVMPointerType::get( typeConverter->convertType(rewriter.getI8Type()), 3); auto threadId = getThreadId(rewriter, loc); Value pred = int_val(1, 1); auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter, dst.getType()); uint32_t boxStride = std::accumulate(boxDims.begin(), boxDims.end(), 1, std::multiplies()); boxStride = boxStride * repM * warpsPerCTA[0]; Value clusterCTAId = getClusterCTAId(rewriter, loc); SmallVector multiDimClusterCTAId = delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); // rowStride in bytes uint32_t rowStrideInBytes = shapePerCTA[dstOrder[0]] * bytesPerElem; uint32_t swizzlingByteWidth = std::min(rowStrideInBytes, bytesPerCacheline); unsigned numElemsPerSwizzlingRow = swizzlingByteWidth / bytesPerElem; unsigned leadingDimOffset = numElemsPerSwizzlingRow * shapePerCTA[dstOrder[1]]; uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0]; Value warpId = udiv(threadId, i32_val(32)); Value warpId0 = urem(urem(warpId, i32_val(warpsPerCTA[0])), i32_val(srcShape[0] / instrShape[0])); auto srcOrder = triton::gpu::getOrder(srcLayout); unsigned inVec = srcOrder == sharedLayout.getOrder() ? triton::gpu::getContigPerThread(srcLayout)[srcOrder[0]] : 1; unsigned outVec = sharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); assert(minVec == 2); auto wordTy = vec_ty(dstElemTy, minVec); auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(), rewriter, srcTy); for (uint32_t b = 0; b < numBox; ++b) { for (int rep = 0; rep < repM; ++rep) { Value rowOfWarp = add(mul(warpId0, i32_val(instrShape[0])), i32_val(rep * rowsPerRep)); uint32_t elemIdxOffset = rep * numElemsPerRep; for (unsigned idx = 0; idx < numElemsPerRep / numBox; idx += 8) { uint32_t elemIdx = elemIdxOffset + b * numElemsPerRep / numBox + idx; Value offset = rewriter.create( loc, i32_ty, threadId, rowOfWarp, i32_val(b * numElemsPerRep / numBox + idx), leadingDimOffset, numElemsPerSwizzlingRow, true); Value addr = gep(dstElemPtrTy, smemBase, offset); Value words[4]; for (unsigned i = 0; i < 8; ++i) { if (i % minVec == 0) words[i / 2] = undef(wordTy); words[i / 2] = insert_element( wordTy, words[i / 2], inVals[elemIdx + i], i32_val(i % minVec)); } rewriter.create( loc, bitcast(addr, ptrI8SharedTy), ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty), bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)}); } rewriter.create(loc, 0); SmallVector coord; // raw coord for (int i = 0; i < rank; ++i) { auto dim = getDimOfOrder(dstOrder, i); coord.push_back(llCoord[dim]); } // coord with box and cta offset for (int i = 0; i < rank; ++i) { auto dim = getDimOfOrder(dstOrder, i); if (i == 0) { coord[i] = add(coord[i], i32_val(b * boxDims[i])); auto CTAOffset = mul(multiDimClusterCTAId[dim], i32_val(numBox * boxDims[i])); coord[i] = add(coord[i], CTAOffset); } else { Value blockOffset = i32_val(rep * instrShape[0] * warpsPerCTA[0]); Value warpOffset = mul(warpId0, i32_val(instrShape[0])); coord[i] = add(add(coord[i], add(blockOffset, 
warpOffset)), mul(multiDimClusterCTAId[dim], i32_val(boxDims[i] * repM * warpsPerCTA[0]))); } } Value srcOffset = add(i32_val(b * boxStride + rep * instrShape[0] * warpsPerCTA[0] * instrShape[1] * warpsPerCTA[1] / numBox), mul(warpId0, i32_val(instrShape[0] * numElemsPerSwizzlingRow))); auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3); Value srcPtrBase = gep(srcPtrTy, smemBase, srcOffset); auto addr = bitcast(srcPtrBase, ptrI8SharedTy); rewriter.create(loc, tmaDesc, addr, pred, coord); } } rewriter.eraseOp(op); return success(); } private: unsigned getArgIdx(Value v) const { if (auto op = v.getDefiningOp()) { return -1 - op.getValue().dyn_cast().getValue().getZExtValue(); } if (!isa(v) && !isa( v.getDefiningOp())) llvm::report_fatal_error( "Operand of `MakeTensorPtrOp` is not the function's argument"); if (v.getDefiningOp() && isa(v.getDefiningOp())) { return getArgIdx(v.getDefiningOp()->getOperand(0)); } else if (v.getParentBlock()->isEntryBlock() && v.isa()) { // in entryblock and is BlockArgument; Because argument of func are // arugments of entryblock bb0 in MLIR return v.cast().getArgNumber(); } else if (v.getParentBlock()->isEntryBlock() && (!v.isa())) { // in entryblock but not BlockArgument return getArgIdx(v.getDefiningOp()->getOperand(0)); } else if (!v.getParentBlock()->isEntryBlock()) { // in non-entryblock return getArgIdx(v.getDefiningOp()->getOperand(0)); } else { llvm::report_fatal_error( "Operand of `MakeTensorPtrOp` is not the function's argument"); return 0; } } int getNumTMADescs(LLVM::LLVMFuncOp func) const { if (!func->hasAttr(kAttrNumTMALoadDescsName)) { llvm::report_fatal_error("TritonGPU module should contain a " "triton_gpu.num-tma-load attribute"); return -1; } if (!func->hasAttr(kAttrNumTMAStoreDescsName)) { llvm::report_fatal_error("TritonGPU module should contain a " "triton_gpu.num-tma-store attribute"); return -1; } return func->getAttr(kAttrNumTMAStoreDescsName) .cast() .getInt() + func->getAttr(kAttrNumTMALoadDescsName).cast().getInt(); } const TensorPtrMapT *tensorPtrMap; }; namespace { void createBarrier(ConversionPatternRewriter &rewriter, Location loc, int numCTAs) { #ifdef USE_ROCM barrier(); #else if (numCTAs == 1) { barrier(); } else { rewriter.create(loc, false); rewriter.create(loc); } #endif } } // namespace struct AtomicCASOpConversion : public ConvertTritonGPUOpToLLVMPattern, public LoadStoreConversionBase { using ConvertTritonGPUOpToLLVMPattern< triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern; AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter, ModuleAllocation &allocation, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( converter, allocation, benefit), LoadStoreConversionBase(axisAnalysisPass) {} #ifdef USE_ROCM LogicalResult matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // extract relevant info from Module auto loc = op.getLoc(); MLIRContext *ctx = rewriter.getContext(); Value ptr = op.getPtr(); Value llPtr = adaptor.getPtr(); Value llCmp = adaptor.getCmp(); Value llVal = adaptor.getVal(); // prep data by unpacking to get data ready auto ptrElements = getTypeConverter()->unpackLLElements( loc, llPtr, rewriter, op.getPtr().getType()); auto cmpElements = getTypeConverter()->unpackLLElements( loc, llCmp, rewriter, op.getCmp().getType()); auto valElements = getTypeConverter()->unpackLLElements( loc, llVal, rewriter, op.getVal().getType()); // deal with tensor or scalar auto valueTy = 
op.getResult().getType(); auto TensorTy = valueTy.dyn_cast(); Type valueElemTy = TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType()) : valueTy; auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); // vec = 1 for scalar auto vec = getVectorSize(op.getPtr()); // tensor if (TensorTy) { auto valTy = op.getVal().getType().cast(); vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); } Value mask = getMask(valueTy, rewriter, loc); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); // atomic ops for (size_t i = 0; i < elemsPerThread; i += vec) { Value casVal = undef(vecTy); for (int ii = 0; ii < vec; ++ii) { Value iiVal = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), ii); casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal); } Value casPtr = ptrElements[i]; Value casCmp = cmpElements[i]; casVal = valElements[i]; // use op if (TensorTy) { // for tensor auto retType = vec == 1 ? valueElemTy : vecTy; // TODO: USE ATOMIC CAS OP on Tensor auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; auto cmpxchg = rewriter.create( loc, casPtr, casCmp, casVal, successOrdering, failureOrdering, StringRef("agent")); // Extract the new_loaded value from the pair. Value ret = extract_val(valueElemTy, cmpxchg, i); for (int ii = 0; ii < vec; ++ii) { resultVals[i + ii] = vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii)); } } else { // for scalar // Build blocks to bypass the atomic instruction for ~rmwMask. auto *curBlock = rewriter.getInsertionBlock(); auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); auto *atomicBlock = rewriter.createBlock( curBlock->getParent(), std::next(Region::iterator(curBlock))); // Fill entry block with global memory barrier and conditional branch. rewriter.setInsertionPointToEnd(curBlock); Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3)); auto tid = tid_val(); Value pred = icmp_eq(tid, i32_val(i)); rewriter.create(loc, pred, atomicBlock, endBlock); // Build main block with atomic_cmpxchg. rewriter.setInsertionPointToEnd(atomicBlock); auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; auto cmpxchg = rewriter.create( loc, casPtr, casCmp, casVal, successOrdering, failureOrdering, StringRef("agent")); // Extract the new_loaded value from the pair. Value newLoaded = extract_val(valueElemTy, cmpxchg, 0); store(newLoaded, atomPtr); rewriter.create(loc, ValueRange(), endBlock); // Build the last block: synced load from shared memory, exit. 
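        // (Only the thread whose id matched the predicate executed the
        //  cmpxchg and stored the old value to shared memory; after the LDS
        //  wait and barrier, every thread reloads it, so the scalar result is
        //  uniform across the block.)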
rewriter.setInsertionPointToStart(endBlock); GCNBuilder BuilderMemfenceLDS; BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()(); BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx)); barrier(); Value ret = load(atomPtr); rewriter.replaceOp(op, {ret}); } } // replace op if (TensorTy) { Type structTy = getTypeConverter()->convertType(TensorTy); Value resultStruct = getTypeConverter()->packLLElements( loc, resultVals, rewriter, structTy); rewriter.replaceOp(op, {resultStruct}); } return success(); } #else // USE_ROCM LogicalResult matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); MLIRContext *ctx = rewriter.getContext(); auto moduleOp = op->getParentOfType(); assert(moduleOp && "Parent ModuleOp not found for AtomicCASOp"); int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); Value llPtr = adaptor.getPtr(); Value llCmp = adaptor.getCmp(); Value llVal = adaptor.getVal(); auto ptrElements = getTypeConverter()->unpackLLElements( loc, llPtr, rewriter, op.getPtr().getType()); auto cmpElements = getTypeConverter()->unpackLLElements( loc, llCmp, rewriter, op.getCmp().getType()); auto valElements = getTypeConverter()->unpackLLElements( loc, llVal, rewriter, op.getVal().getType()); auto valueTy = op.getResult().getType(); auto TensorTy = valueTy.dyn_cast(); Type valueElemTy = TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType()) : valueTy; auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); // vec = 1 for scalar auto vec = getVectorSize(op.getPtr()); // tensor if (TensorTy) { auto valTy = op.getVal().getType().cast(); vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); } Value mask = getMask(valueTy, rewriter, loc); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); for (size_t i = 0; i < elemsPerThread; i += vec) { Value casVal = undef(vecTy); for (int ii = 0; ii < vec; ++ii) { Value iiVal = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), ii); casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal); } Value casPtr = ptrElements[i]; Value casCmp = cmpElements[i]; casVal = valElements[i]; PTXBuilder ptxBuilderAtomicCAS; std::string tyId = valueElemNBits * vec == 64 ? "l" : (valueElemNBits * vec == 32 ? "r" : "h"); auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true); auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l"); auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, tyId); auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, tyId); auto &atom = *ptxBuilderAtomicCAS.create("atom"); auto sTy = "b" + std::to_string(valueElemNBits); std::string semStr; llvm::raw_string_ostream os(semStr); os << op.getSem(); auto scope = stringifyMemSyncScope(op.getScope()).str(); atom.global().o(semStr).o(scope).o("cas").o(sTy); atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask); if (TensorTy) { auto retType = vec == 1 ? valueElemTy : vecTy; auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType); for (int ii = 0; ii < vec; ++ii) { resultVals[i + ii] = vec == 1 ? 
ret : extract_element(valueElemTy, ret, i32_val(ii)); } } else { auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); createBarrier(rewriter, loc, numCTAs); Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3)); // Only threads with mask = True store the result PTXBuilder ptxBuilderStore; auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r"); auto *valOprStore = ptxBuilderStore.newOperand(old, "r"); auto &st = *ptxBuilderStore.create("st"); st.shared().o(sTy); st(dstOprStore, valOprStore).predicate(mask); auto ASMReturnTy = void_ty(ctx); ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); createBarrier(rewriter, loc, numCTAs); Value ret = load(atomPtr); createBarrier(rewriter, loc, numCTAs); rewriter.replaceOp(op, {ret}); } } if (TensorTy) { Type structTy = getTypeConverter()->convertType(TensorTy); Value resultStruct = getTypeConverter()->packLLElements( loc, resultVals, rewriter, structTy); rewriter.replaceOp(op, {resultStruct}); } return success(); } #endif // USE_ROCM }; struct AtomicRMWOpConversion : public ConvertTritonGPUOpToLLVMPattern, public LoadStoreConversionBase { using ConvertTritonGPUOpToLLVMPattern< triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern; AtomicRMWOpConversion(TritonGPUToLLVMTypeConverter &converter, ModuleAllocation &allocation, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( converter, allocation, benefit), LoadStoreConversionBase(axisAnalysisPass) {} #ifdef USE_ROCM /// Try to match the mlir::triton::RMWOp to LLVM::AtomicBinOp. static std::optional matchAtomicOp(RMWOp atomicOp) { switch (atomicOp) { case RMWOp::AND: return LLVM::AtomicBinOp::_and; case RMWOp::OR: return LLVM::AtomicBinOp::_or; case RMWOp::XOR: return LLVM::AtomicBinOp::_xor; case RMWOp::ADD: return LLVM::AtomicBinOp::add; case RMWOp::FADD: return LLVM::AtomicBinOp::fadd; case RMWOp::MAX: return LLVM::AtomicBinOp::max; case RMWOp::MIN: return LLVM::AtomicBinOp::min; case RMWOp::UMAX: return LLVM::AtomicBinOp::umax; case RMWOp::UMIN: return LLVM::AtomicBinOp::umin; case RMWOp::XCHG: return LLVM::AtomicBinOp::xchg; default: return std::nullopt; } llvm_unreachable("Invalid RMWOp"); } LogicalResult matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); MLIRContext *ctx = rewriter.getContext(); auto atomicRmwAttr = op.getAtomicRmwOp(); Value ptr = op.getPtr(); Value val = op.getVal(); Value llPtr = adaptor.getPtr(); Value llVal = adaptor.getVal(); Value llMask = adaptor.getMask(); auto valElements = getTypeConverter()->unpackLLElements( loc, llVal, rewriter, val.getType()); auto ptrElements = getTypeConverter()->unpackLLElements( loc, llPtr, rewriter, ptr.getType()); SmallVector maskElements; if (llMask) maskElements = getTypeConverter()->unpackLLElements( loc, llMask, rewriter, op.getMask().getType()); Value opResult = op.getResult(); auto tensorTy = opResult.getType().dyn_cast(); Type valueElemTy = tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType()) : opResult.getType(); const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); int numElems = 1; // tensor if (tensorTy) { auto valTy = val.getType().cast(); vec = std::min(vec, valTy.getElementType().isF16() ? 
2 : 1); // mask numElems = tensorTy.getNumElements(); } Value mask = int_val(1, 1); auto tid = tid_val(); mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); auto vecTy = vec_ty(valueElemTy, vec); auto retType = vec == 1 ? valueElemTy : vecTy; SmallVector resultVals(elemsPerThread); const bool f16v2 = vec == 2 && valueElemTy.isF16(); for (size_t i = 0; i < elemsPerThread; i += vec) { Value rmwPtr = ptrElements[i]; // TODO: in case llMask is zero we can create only one branch for all // elemsPerThread. Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; Value undefVal = undef(retType); // Build blocks to bypass the atomic instruction for ~rmwMask. auto *curBlock = rewriter.getInsertionBlock(); auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); auto *atomicBlock = rewriter.createBlock( curBlock->getParent(), std::next(Region::iterator(curBlock))); endBlock->addArgument({retType}, {loc}); rewriter.setInsertionPointToEnd(curBlock); rewriter.create(loc, rmwMask, atomicBlock, endBlock, undefVal); rewriter.setInsertionPointToEnd(atomicBlock); auto maybeKind = matchAtomicOp(atomicRmwAttr); // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient // atomics for MI-* series of AMD GPU. Value atom = rewriter.create( loc, *maybeKind, rmwPtr, valElements[i], LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult(); // NV for the f16v2 case generates one packed instruction. We have to // create two separate instructions since LLVM::AtomicRMWOp doesn't // support this. Can be optimized out with rocdl.raw.buffer.atomic. if (f16v2) { Value atom2 = rewriter.create( loc, *maybeKind, ptrElements[i+1], valElements[i + 1], LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult(); auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0)); atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult(); } rewriter.create(loc, atom, endBlock); rewriter.setInsertionPointToStart(endBlock); Value retVal = endBlock->getArgument(0); if (tensorTy) { for (int ii = 0; ii < vec; ++ii) { resultVals[i + ii] = vec == 1 ? 
retVal : extract_element(valueElemTy, retVal, i32_val(ii)); } } else { Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3)); store(retVal, atomPtr); Value ret = load(atomPtr); rewriter.replaceOp(op, {ret}); } } if (tensorTy) { Type structTy = getTypeConverter()->convertType(tensorTy); Value resultStruct = getTypeConverter()->packLLElements( loc, resultVals, rewriter, structTy); rewriter.replaceOp(op, {resultStruct}); } return success(); } #else // USE_ROCM LogicalResult matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); MLIRContext *ctx = rewriter.getContext(); auto moduleOp = op->getParentOfType(); assert(moduleOp && "Parent ModuleOp not found for AtomicRMWOp"); int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); auto atomicRmwAttr = op.getAtomicRmwOp(); Value val = op.getVal(); Value ptr = op.getPtr(); Value llPtr = adaptor.getPtr(); Value llVal = adaptor.getVal(); Value llMask = adaptor.getMask(); auto valElements = getTypeConverter()->unpackLLElements( loc, llVal, rewriter, val.getType()); auto ptrElements = getTypeConverter()->unpackLLElements( loc, llPtr, rewriter, ptr.getType()); SmallVector maskElements; if (llMask) maskElements = getTypeConverter()->unpackLLElements( loc, llMask, rewriter, op.getMask().getType()); auto valueTy = op.getResult().getType(); auto tensorTy = valueTy.dyn_cast(); Type valueElemTy = tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType()) : valueTy; const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); int numElems = 1; // tensor if (tensorTy) { auto valTy = val.getType().cast(); vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); // mask numElems = tensorTy.getNumElements(); } Value mask = getMask(valueTy, rewriter, loc); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); for (size_t i = 0; i < elemsPerThread; i += vec) { Value rmwVal = undef(vecTy); for (int ii = 0; ii < vec; ++ii) { Value iiVal = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), ii); rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal); } Value rmwPtr = ptrElements[i]; Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; std::string sTy; PTXBuilder ptxBuilderAtomicRMW; std::string tyId = valueElemNBits * vec == 64 ? "l" : (valueElemNBits * vec == 32 ? "r" : "h"); auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l"); auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); auto scope = stringifyMemSyncScope(op.getScope()).str(); auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope); auto rmwOp = stringifyRMWOp(atomicRmwAttr).str(); auto sBits = std::to_string(valueElemNBits); switch (atomicRmwAttr) { case RMWOp::AND: sTy = "b" + sBits; break; case RMWOp::OR: sTy = "b" + sBits; break; case RMWOp::XOR: sTy = "b" + sBits; break; case RMWOp::ADD: sTy = "u" + sBits; break; case RMWOp::FADD: rmwOp = "add"; rmwOp += (valueElemNBits == 16 ? ".noftz" : ""); sTy = "f" + sBits; sTy += (vec == 2 && valueElemNBits == 16) ? 
"x2" : ""; break; case RMWOp::MAX: sTy = "s" + sBits; break; case RMWOp::MIN: sTy = "s" + sBits; break; case RMWOp::UMAX: rmwOp = "max"; sTy = "u" + sBits; break; case RMWOp::UMIN: rmwOp = "min"; sTy = "u" + sBits; break; case RMWOp::XCHG: sTy = "b" + sBits; break; default: return failure(); } std::string semStr; llvm::raw_string_ostream os(semStr); os << op.getSem(); atom.o(semStr).o(rmwOp).o(sTy); if (tensorTy) { atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); auto retType = vec == 1 ? valueElemTy : vecTy; auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType); for (int ii = 0; ii < vec; ++ii) { resultVals[i + ii] = vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii)); } } else { auto ASMReturnTy = void_ty(ctx); atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy); if (op->user_begin() == op->user_end()) { rewriter.replaceOp(op, {old}); return success(); } Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3)); // Only threads with rmwMask = True store the result PTXBuilder ptxBuilderStore; auto &storeShared = ptxBuilderStore.create<>("st")->shared().o("b" + sBits); auto *ptrOpr = ptxBuilderStore.newAddrOperand(atomPtr, "r"); auto *valOpr = ptxBuilderStore.newOperand(old, tyId); storeShared(ptrOpr, valOpr).predicate(rmwMask); ptxBuilderStore.launch(rewriter, loc, void_ty(ctx)); createBarrier(rewriter, loc, numCTAs); Value ret = load(atomPtr); createBarrier(rewriter, loc, numCTAs); rewriter.replaceOp(op, {ret}); } } if (tensorTy) { Type structTy = getTypeConverter()->convertType(tensorTy); Value resultStruct = getTypeConverter()->packLLElements( loc, resultVals, rewriter, structTy); rewriter.replaceOp(op, {resultStruct}); } return success(); } #endif // USE_ROCM }; struct InsertSliceOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< tensor::InsertSliceOp>::ConvertTritonGPUOpToLLVMPattern; LogicalResult matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // %dst = insert_slice %src into %dst[%offsets] Location loc = op->getLoc(); Value dst = op.getDest(); Value src = op.getSource(); Value res = op.getResult(); auto funcOp = op->getParentOfType(); auto *funcAllocation = allocation->getFuncData(funcOp); assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId && "Only support in-place insert_slice for now"); auto srcTy = src.getType().dyn_cast(); auto srcLayout = srcTy.getEncoding(); assert((srcLayout.isa() && "Unexpected srcLayout in InsertSliceOpConversion")); auto srcShape = srcTy.getShape(); assert(srcLayout && "Unexpected srcLayout in InsertSliceOpConversion"); auto dstTy = dst.getType().dyn_cast(); auto dstLayout = dstTy.getEncoding().dyn_cast(); auto llDst = adaptor.getDest(); assert(dstLayout && "Unexpected dstLayout in InsertSliceOpConversion"); assert(op.hasUnitStride() && "Only unit stride supported by InsertSliceOpConversion"); // newBase = base + offset // Triton support either static and dynamic offsets auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter); SmallVector offsets; SmallVector srcStrides; auto mixedOffsets = op.getMixedOffsets(); for (auto i = 0; i < mixedOffsets.size(); ++i) { if (op.isDynamicOffset(i)) { offsets.emplace_back(adaptor.getOffsets()[i]); } else { offsets.emplace_back(i32_val(op.getStaticOffset(i))); } // Like insert_slice_async, we only support slice from one dimension, 
// which has a slice size of 1 if (op.getStaticSize(i) != 1) { srcStrides.emplace_back(smemObj.strides[i]); } } // Compute the offset based on the original strides of the shared memory // object auto offset = dot(rewriter, loc, offsets, smemObj.strides); auto elemTy = getTypeConverter()->convertType(dstTy.getElementType()); auto elemPtrTy = ptr_ty(elemTy, 3); auto smemBase = gep(elemPtrTy, smemObj.base, offset); auto llSrc = adaptor.getSource(); auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); storeDistributedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase, elemTy, loc, rewriter); // Barrier is not necessary. // The membar pass knows that it writes to shared memory and will handle it // properly. rewriter.replaceOp(op, llDst); return success(); } }; struct InsertSliceAsyncOpConversion : public ConvertTritonGPUOpToLLVMPattern, public LoadStoreConversionBase { using ConvertTritonGPUOpToLLVMPattern< triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern; InsertSliceAsyncOpConversion( TritonGPUToLLVMTypeConverter &converter, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( converter, allocation, indexCacheInfo, benefit), LoadStoreConversionBase(axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::gpu::InsertSliceAsyncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // insert_slice_async %src, %dst, %index, %mask, %other auto loc = op.getLoc(); Value src = op.getSrc(); Value dst = op.getDst(); Value res = op.getResult(); Value mask = op.getMask(); Value other = op.getOther(); auto funcOp = op->getParentOfType(); auto *funcAllocation = allocation->getFuncData(funcOp); assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId && "Only support in-place insert_slice_async for now"); auto srcTy = src.getType().cast(); auto resTy = dst.getType().cast(); auto resElemTy = getTypeConverter()->convertType(resTy.getElementType()); auto srcLayout = srcTy.getEncoding(); assert((srcLayout.isa() && "Unexpected srcLayout in InsertSliceAsyncOpConversion")); auto resSharedLayout = resTy.getEncoding().cast(); auto srcShape = srcTy.getShape(); assert((srcShape.size() == 1 || srcShape.size() == 2) && "insert_slice_async: Unexpected rank of %src"); Value llDst = adaptor.getDst(); Value llSrc = adaptor.getSrc(); Value llMask = adaptor.getMask(); Value llOther = adaptor.getOther(); Value llIndex = adaptor.getIndex(); // %src auto srcElems = getTypeConverter()->unpackLLElements(loc, llSrc, rewriter, src.getType()); // %dst auto dstTy = dst.getType().cast(); auto dstShape = dstTy.getShape(); auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter); auto axis = op->getAttrOfType("axis").getInt(); SmallVector offsetVals; SmallVector srcStrides; for (auto i = 0; i < dstShape.size(); ++i) { if (i == axis) { offsetVals.emplace_back(llIndex); } else { offsetVals.emplace_back(i32_val(0)); srcStrides.emplace_back(smemObj.strides[i]); } } // Compute the offset based on the original dimensions of the shared // memory object auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); auto dstPtrTy = ptr_ty(resElemTy, 3); Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset); // %mask SmallVector maskElems; if (llMask) { maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter, mask.getType()); assert(srcElems.size() == maskElems.size()); } // %other SmallVector 
otherElems; if (llOther) { // FIXME(Keren): always assume other is 0 for now // It's not necessary for now because the pipeline pass will skip // generating insert_slice_async if the load op has any "other" tensor. // assert(false && "insert_slice_async: Other value not supported yet"); otherElems = getTypeConverter()->unpackLLElements(loc, llOther, rewriter, other.getType()); assert(srcElems.size() == otherElems.size()); } // We don't use getVec() here because we are copying from memory to memory. // If contiguity > vector size, we can have one pointer maintaining the // start of the vector and the other pointer moving to the next vector. unsigned inVec = getContiguity(src); unsigned outVec = resSharedLayout.getVec(); unsigned minVec = inVec; if (outVec > 1) minVec = std::min(outVec, inVec); unsigned numElems = getTotalElemsPerThread(srcTy); unsigned perPhase = resSharedLayout.getPerPhase(); unsigned maxPhase = resSharedLayout.getMaxPhase(); DenseMap sharedPtrs = getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy, smemObj, rewriter, offsetVals, srcStrides); // A sharedLayout encoding has a "vec" parameter. // On the column dimension, if inVec > outVec, it means we have to divide // single vector read into multiple ones auto numVecCols = std::max(inVec / outVec, 1); for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { // 16 * 8 = 128bits auto maxBitWidth = std::max(128, resElemTy.getIntOrFloatBitWidth()); auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec; auto bitWidth = std::min(maxBitWidth, vecBitWidth); auto numWords = vecBitWidth / bitWidth; auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth(); // Tune CG and CA here. auto byteWidth = bitWidth / 8; CacheModifier srcCacheModifier = byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA; assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4); auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8; Value basePtr = sharedPtrs[elemIdx]; for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) { PTXBuilder ptxBuilder; auto wordElemIdx = wordIdx * numWordElems; auto ©AsyncOp = *ptxBuilder.create(srcCacheModifier); auto *dstOperand = ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth); auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l"); auto *copySize = ptxBuilder.newConstantOperand(byteWidth); auto *srcSize = copySize; if (op.getMask()) { // We don't use predicate in this case, setting src-size to 0 // if there's any mask. cp.async will automatically fill the // remaining slots with 0 if cp-size > src-size. // XXX(Keren): Always assume other = 0 for now. 
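          // e.g. a fully masked-off 16-byte word issues cp.async with
          // cp-size = 16 and src-size = 0, so the destination is zero-filled.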
      for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
        PTXBuilder ptxBuilder;
        auto wordElemIdx = wordIdx * numWordElems;
        auto &copyAsyncOp =
            *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
        auto *dstOperand =
            ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth);
        auto *srcOperand =
            ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
        auto *copySize = ptxBuilder.newConstantOperand(byteWidth);
        auto *srcSize = copySize;
        if (op.getMask()) {
          // We don't use a predicate in this case; instead we set src-size to
          // 0 when the mask is false. cp.async automatically fills the
          // remaining bytes with 0 if cp-size > src-size.
          // XXX(Keren): Always assume other = 0 for now.
          auto selectOp = select(maskElems[elemIdx + wordElemIdx],
                                 i32_val(byteWidth), i32_val(0));
          srcSize = ptxBuilder.newOperand(selectOp, "r");
        }
        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
        ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
      }
    }

    rewriter.replaceOp(op, llDst);
    return success();
  }
};

struct InsertSliceAsyncV2OpConversion
    : public ConvertTritonGPUOpToLLVMPattern<
          triton::nvidia_gpu::InsertSliceAsyncV2Op> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::InsertSliceAsyncV2Op>::ConvertTritonGPUOpToLLVMPattern;

  InsertSliceAsyncV2OpConversion(TritonGPUToLLVMTypeConverter &converter,
                                 ModuleAllocation &allocation,
                                 mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
                                 const TensorPtrMapT *tensorPtrMap,
                                 PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<
            triton::nvidia_gpu::InsertSliceAsyncV2Op>(converter, allocation,
                                                      tmaMetadata, benefit),
        tensorPtrMap(tensorPtrMap) {}

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::InsertSliceAsyncV2Op op,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto resultTy = op.getResult().getType().cast<RankedTensorType>();
    auto elemTy = resultTy.getElementType();
    // The result carries one extra (pipeline-stage) dimension, so the rank of
    // the copied tile is one less than the rank of the result.
    auto rank = resultTy.getRank() - 1;

    // TODO: support any valid rank in (3, 4, 5).
    // The async copy op only supports tensors with rank <= 5.
    // Reference:
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format
    assert(rank > 0 && rank <= 5);

    SmallVector<int64_t> shape;
    auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
    auto moduleOp = op->getParentOfType<ModuleOp>();
    assert(moduleOp && "Parent ModuleOp not found for InsertSliceAsyncV2Op");
    auto llFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    assert(llFuncOp && "LLVMFuncOp not found for InsertSliceAsyncV2Op");
    int numTMADescs = getNumTMADescs(llFuncOp);
    assert(numTMADescs > 0);

    auto sharedLayout = resultTy.getEncoding().dyn_cast<SharedEncodingAttr>();
    assert(sharedLayout && "unexpected layout of InsertSliceAsyncV2Op");
    auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
    auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
    auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();

    mlir::triton::gpu::TMAInfo tmaInfo;
    tmaInfo.tensorDataType = getCUtensorMapDataType(elemTy);
    tmaInfo.tensorRank = rank;

    assert(tmaMetadata);
    unsigned TMADescIdx = tmaMetadata->size();
    unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments();
    auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation());
    auto inOrder = makeTensorPtr.getOrder();
    unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase());
    tmaInfo.globalAddressArgIdx = globalAddressArgIdx;
    tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx;

    auto getDimOfOrder = [](ArrayRef<int32_t> order, int32_t i) {
      auto it = std::find(order.begin(), order.end(), i);
      assert(it != order.end());
      return std::distance(order.begin(), it);
    };

    std::vector<int32_t> globalDimsArgIdx;
    std::vector<int32_t> globalStridesArgIdx;
    // Constant values are mapped to (-1 - value).
    for (int i = 0; i < rank; ++i) {
      int32_t argIdx = -1;
      auto dim = getDimOfOrder(inOrder, i);
      argIdx = getArgIdx(makeTensorPtr.getShape()[dim]);
      globalDimsArgIdx.emplace_back(argIdx);
      // handle constant stride
      argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]);
      globalStridesArgIdx.emplace_back(argIdx);
    }
    tmaInfo.globalDimsArgIdx = globalDimsArgIdx;
    tmaInfo.globalStridesArgIdx = globalStridesArgIdx;

    std::vector<uint32_t> boxDims;
    auto tensorShape = makeTensorPtr.getResult()
                           .getType()
                           .cast<triton::PointerType>()
                           .getPointeeType()
                           .cast<RankedTensorType>()
                           .getShape();

    SmallVector<unsigned> numMcast(rank);
    unsigned accNumMcast = 1;
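    // numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i] is the number of CTAs that
    // receive the same data along dimension i (TMA multicast). Illustrative
    // example: CTAsPerCGA = [2, 2] and CTASplitNum = [2, 1] give
    // numMcast = [1, 2] and accNumMcast = 2.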
    for (unsigned i = 0; i < rank; ++i) {
      numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i];
      accNumMcast *= numMcast[i];
    }
    auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);
    for (size_t i = 0; i < rank; ++i) {
      auto dim = getDimOfOrder(inOrder, i);
      // In case of TMA multicast, we should always slice along the
      // highest-order dimension.
      if (i == rank - 1) {
        assert(shapePerCTA[dim] >= accNumMcast &&
               "the case where the highest-order dimension is smaller than "
               "accNumMcast is not implemented");
        boxDims.emplace_back(shapePerCTA[dim] / accNumMcast);
      } else {
        boxDims.emplace_back(shapePerCTA[dim]);
      }
    }

    std::vector<uint32_t> elementStrides(rank, 1);
    tmaInfo.elementStrides = elementStrides;

    CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
    if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
      swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
    else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
      swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
    else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
      swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;
    else
      llvm::report_fatal_error(
          "Unsupported shared layout for InsertSliceAsyncV2Op");
    tmaInfo.swizzle = swizzle;
    tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE;
    tmaInfo.l2Promotion =
        CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
    tmaInfo.oobFill =
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;

    uint32_t numBoxes = 1;
    uint32_t elemSizeOfBytes = elemTy.getIntOrFloatBitWidth() / 8;
    if (swizzle == CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
      // The 128B swizzle pattern covers at most 128 bytes along the innermost
      // box dimension, so wider boxes are split into several TMA copies.
      while (elemSizeOfBytes * boxDims[0] > 128) {
        boxDims[0] = boxDims[0] / 2;
        numBoxes *= 2;
      }
    }
    tmaInfo.boxDims = boxDims;
    tmaMetadata->emplace_back(tmaInfo);

    uint32_t elemsPerBox = std::accumulate(boxDims.begin(), boxDims.end(), 1,
                                           std::multiplies<uint32_t>{});

    Value clusterCTAId = getClusterCTAId(rewriter, loc);
    SmallVector<Value> multiDimClusterCTAId =
        delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);

    Value llDst = adaptor.getDst();
    Value llIndex = adaptor.getIndex();
    Value src = op.getSrc();
    Value dst = op.getDst();
    auto dstTy = dst.getType().cast<RankedTensorType>();
    auto dstShape = dstTy.getShape();
    auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);

    // The offset of the coordinates, taking multicast slicing into account
    SmallVector<Value> mcastOffsetVals;
    // The index of the slice this CTAId is responsible for
    SmallVector<Value> multiDimSliceIdx(rank);
    for (auto i = 0; i < rank; ++i)
      multiDimSliceIdx[i] =
          udiv(multiDimClusterCTAId[i], i32_val(CTASplitNum[i]));
    Value sliceIdx =
        linearize(rewriter, loc, multiDimSliceIdx, numMcast, CTAOrder);

    Value sliceCoord;
    for (auto i = 0; i < rank; ++i) {
      if (inOrder[i] == rank - 1) {
        // TODO[goostavz]: The case where the size of the highest-order
        // dimension is smaller than accNumMcast is not implemented.
        sliceCoord = mul(sliceIdx, i32_val(shapePerCTA[i] / accNumMcast));
        mcastOffsetVals.emplace_back(
            mul(sliceIdx, i32_val(shapePerCTA[i] / accNumMcast)));
      } else {
        mcastOffsetVals.emplace_back(i32_val(0));
      }
    }

    uint32_t elemsPerSlice = std::accumulate(
        shapePerCTA.begin(), shapePerCTA.end(), 1, std::multiplies<uint32_t>{});
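    // The destination offset (in elements) inside the shared-memory buffer is
    // the slice index times elemsPerSlice plus the multicast offset along the
    // contiguous dimension, i.e. for rank == 2:
    //   dstOffset = llIndex * elemsPerSlice + sliceCoord * boxDims[0]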
    Value dstOffsetCommon = mul(llIndex, i32_val(elemsPerSlice));
    // [benzh] sliceCoord should be the accumulated multiplier of the higher
    // dimensions; currently only rank == 2 is supported.
    dstOffsetCommon =
        add(dstOffsetCommon, mul(sliceCoord, i32_val(boxDims[0])));

    auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);

    Value tmaDesc =
        llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx);
    // TODO: sink this logic into Triton::NVGPU dialect and support more
    // cache-policy modes
    Value l2Desc = int_val(64, 0x1000000000000000ll);

    auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
        typeConverter->convertType(rewriter.getI8Type()), 3);

    SmallVector<Value> coordCommon;
    auto llCoord = getTypeConverter()->unpackLLElements(
        loc, adaptor.getSrc(), rewriter, src.getType());
    for (int i = 0; i < rank; ++i) {
      auto dim = getDimOfOrder(inOrder, i);
      Value coordDim = bitcast(llCoord[dim], i32_ty);
      if (CTASplitNum[dim] != 1) {
        // Add the per-CTA offset:
        // shapePerCTA[dim] * (multiDimClusterCTAId[dim] % CTASplitNum[dim])
        auto CTAOffset =
            mul(i32_val(shapePerCTA[dim]),
                urem(multiDimClusterCTAId[dim], i32_val(CTASplitNum[dim])));
        coordDim = add(coordDim, CTAOffset);
      }

      if (i == rank - 1)
        // Add offset in case of multicast slicing
        coordCommon.push_back(add(coordDim, mcastOffsetVals[dim]));
      else
        coordCommon.push_back(coordDim);
    }

    auto threadId = getThreadId(rewriter, loc);
    Value pred = icmp_eq(threadId, i32_val(0));

    auto mask = adaptor.getMask();
    if (mask) {
      // TODO(thomas): What is the right implementation for this case?
      assert(mask.getType().isInteger(1) &&
             "need to implement cases with tensor mask");
      pred = rewriter.create<arith::AndIOp>(loc, pred, mask);
    }

    Value mcastMask = getMCastMask(sharedLayout, rewriter, loc, clusterCTAId);

    for (size_t i = 0; i < numBoxes; ++i) {
      Value dstOffset =
          add(dstOffsetCommon, i32_val(i * elemsPerBox * accNumMcast));
      Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
      SmallVector<Value> coord = coordCommon;
      coord[0] = add(coordCommon[0], i32_val(i * boxDims[0]));
      rewriter.create<triton::nvgpu::TMALoadTiledOp>(
          loc, bitcast(dstPtrBase, ptrI8SharedTy), adaptor.getMbar(), tmaDesc,
          l2Desc, pred, coord, mcastMask);
    }

    rewriter.replaceOp(op, llDst);
    return success();
  }

private:
  Value getMCastMask(const SharedEncodingAttr &sharedLayout,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value clusterCTAId) const {
    auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
    auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
    auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();

    // Short path when no multicast is needed
    if (CTAsPerCGA == CTASplitNum)
      return nullptr;

    // Short path when bcastMask is a constant
    bool isConstMcastMask = true;
    for (unsigned s : CTASplitNum) {
      if (s > 1) {
        isConstMcastMask = false;
        break;
      }
    }
    if (isConstMcastMask) {
      unsigned numCTAs = std::accumulate(CTAsPerCGA.begin(), CTAsPerCGA.end(),
                                         1, std::multiplies<unsigned>{});
      return int_val(/*width*/ 16, (1u << numCTAs) - 1);
    }

    SmallVector<Value> multiDimCTAId =
        delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
    auto rank = CTAOrder.size();
    SmallVector<SmallVector<Value>> multiDimMask(rank);
    unsigned accNumMcast = 1;
    SmallVector<unsigned> numMcast(rank);
    for (unsigned i = 0; i < rank; ++i) {
      // For the i-th dimension, CTAsPerCGA[i]/CTASplitNum[i] values are
      // broadcast; for this CTAId they are:
      // multiDimCTAId[i] % CTASplitNum[i] + (0 ..
      // (CTAsPerCGA[i]/CTASplitNum[i] - 1)) * CTASplitNum[i]
      // TODO: will there be cases where CTAsPerCGA[i]/CTASplitNum[i] < 1?
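      // Example (illustrative values): CTAsPerCGA[i] = 4, CTASplitNum[i] = 2
      // and multiDimCTAId[i] = 3 give rem = 1, so the multicast targets along
      // dimension i are the CTA ids {1, 3}.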
      Value rem = urem(multiDimCTAId[i], i32_val(CTASplitNum[i]));
      numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i];
      accNumMcast *= numMcast[i];
      for (unsigned j = 0; j < numMcast[i]; ++j) {
        if (j == 0) {
          multiDimMask[i].push_back(rem);
        } else {
          multiDimMask[i].push_back(add(rem, i32_val(j * CTASplitNum[i])));
        }
      }
    }

    Value bcastMask = int_val(/*width*/ 16, 0);
    Value _1_i16 = int_val(/*width*/ 16, 1);
    for (unsigned i = 0; i < accNumMcast; ++i) {
      SmallVector<unsigned> multiDimIdx =
          getMultiDimIndex<unsigned>(i, numMcast, CTAOrder);
      SmallVector<Value> multiDimMaskedCTAId(rank);
      for (unsigned dim = 0; dim < rank; ++dim) {
        multiDimMaskedCTAId[dim] = multiDimMask[dim][multiDimIdx[dim]];
      }
      Value bcastCTAId =
          linearize(rewriter, loc, multiDimMaskedCTAId, CTAsPerCGA, CTAOrder);
      // bcastMask |= 1u << bcastCTAId;
      bcastMask = or_(bcastMask, shl(_1_i16, trunc(i16_ty, bcastCTAId)));
    }

    return bcastMask;
  }

  unsigned getArgIdx(Value v) const {
    if (auto op = v.getDefiningOp<arith::ConstantOp>()) {
      // Constant values are encoded as (-1 - value).
      return -1 -
             op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
    }
    if (!isa<BlockArgument>(v) &&
        !isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
            v.getDefiningOp()))
      llvm::report_fatal_error(
          "Operand of `MakeTensorPtrOp` is not the function's argument");
    if (v.getDefiningOp() &&
        isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
      return getArgIdx(v.getDefiningOp()->getOperand(0));
    } else if (v.getParentBlock()->isEntryBlock() && v.isa<BlockArgument>()) {
      // In the entry block and a BlockArgument: the arguments of the function
      // are the arguments of the entry block (bb0) in MLIR.
      return v.cast<BlockArgument>().getArgNumber();
    } else if (v.getParentBlock()->isEntryBlock() &&
               (!v.isa<BlockArgument>())) {
      // In the entry block but not a BlockArgument
      return getArgIdx(v.getDefiningOp()->getOperand(0));
    } else if (!v.getParentBlock()->isEntryBlock()) {
      // In a non-entry block
      return getArgIdx(v.getDefiningOp()->getOperand(0));
    } else {
      llvm::report_fatal_error(
          "Operand of `MakeTensorPtrOp` is not the function's argument");
      return 0;
    }
  }

  int getNumTMADescs(LLVM::LLVMFuncOp func) const {
    if (!func->hasAttr(kAttrNumTMALoadDescsName)) {
      llvm::report_fatal_error("TritonGPU module should contain a "
                               "triton_gpu.num-tma-load attribute");
      return -1;
    }
    if (!func->hasAttr(kAttrNumTMAStoreDescsName)) {
      llvm::report_fatal_error("TritonGPU module should contain a "
                               "triton_gpu.num-tma-store attribute");
      return -1;
    }
    return func->getAttr(kAttrNumTMAStoreDescsName)
               .cast<IntegerAttr>()
               .getInt() +
           func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
  }

  const TensorPtrMapT *tensorPtrMap;
};

void populateLoadStoreOpToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
    ModuleAllocation &allocation,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
    mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
    const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit) {
  patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  patterns.add<AtomicCASOpConversion>(typeConverter, allocation,
                                      axisInfoAnalysis, benefit);
  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation,
                                      axisInfoAnalysis, benefit);
  patterns.add<InsertSliceOpConversion>(typeConverter, allocation,
                                        indexCacheInfo, benefit);
  patterns.add<InsertSliceAsyncOpConversion>(
      typeConverter, allocation, indexCacheInfo, axisInfoAnalysis, benefit);
  patterns.add<InsertSliceAsyncV2OpConversion>(
      typeConverter, allocation, tmaMetadata, tensorPtrMap, benefit);
  patterns.add<StoreAsyncOpConversion>(typeConverter, allocation, tmaMetadata,
                                       tensorPtrMap, benefit);
}
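// Typical call site (a sketch only; the surrounding pass code is assumed and
// not part of this file): the TritonGPUToLLVM conversion pass forwards its
// analyses into this populate function, roughly as
//   populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
//                                     axisInfoAnalysis, allocation,
//                                     indexCacheInfo, &tmaMetadata,
//                                     &tensorPtrMap, /*benefit=*/10);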