Merge commit '5df904233c11a65bd131ead7268f84cca7804275' into ifu230810-2

Conflicts:
	include/triton/Dialect/Triton/Transforms/Passes.h
	include/triton/Dialect/TritonGPU/IR/Dialect.h
	include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
	lib/Analysis/Allocation.cpp
	lib/Analysis/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/src/triton.cc
	python/triton/compiler/compiler.py
	python/triton/ops/flash_attention.py
	python/triton/runtime/autotuner.py
	python/triton/runtime/jit.py
	python/triton/tools/aot.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
	test/Target/tritongpu_to_llvmir.mlir
	test/Target/tritongpu_to_llvmir_noinline.mlir
This commit is contained in:
Jason Furmanek
2023-09-01 03:25:33 +00:00
122 changed files with 7341 additions and 2234 deletions

View File

@@ -408,6 +408,10 @@ struct StoreOpConversion
auto &ptxStoreInstr =
ptxBuilder.create<>("st")
->global()
.o("wb", op.getCache() == triton::CacheModifier::WB)
.o("cg", op.getCache() == triton::CacheModifier::CG)
.o("cs", op.getCache() == triton::CacheModifier::CS)
.o("wt", op.getCache() == triton::CacheModifier::WT)
.o("L1::evict_first",
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last",
@@ -1050,7 +1054,9 @@ struct InsertSliceAsyncOpConversion
// start of the vector and the other pointer moving to the next vector.
unsigned inVec = getContiguity(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned minVec = inVec;
if (outVec > 1)
minVec = std::min(outVec, inVec);
unsigned numElems = getTotalElemsPerThread(srcTy);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();