mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Initial commit to resolve merge conflicts
rename tl.float8e4 to tl.float8e4nv to align with upstream ROCM IFU: Fix python arch issues ROCM IFU: Fix kernel launcher ROCM IFU: Fix merge conflicts fix debug build Set correct threadsPerCTA
This commit is contained in:
9
.gitignore
vendored
9
.gitignore
vendored
@@ -25,14 +25,5 @@ venv.bak/
|
||||
.idea
|
||||
cmake-build-*
|
||||
|
||||
<<<<<<< HEAD
|
||||
# cache dumps
|
||||
triton_cache*
|
||||
log_*
|
||||
|
||||
#
|
||||
python/triton/third_party/cuda/bin/ptxas
|
||||
=======
|
||||
# Third-party binaries
|
||||
ptxas
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
@@ -200,6 +200,10 @@ include_directories(${LLVM_INCLUDE_DIRS})
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
|
||||
|
||||
set(ROCM_LIBRARIES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lib/rocm/libhsa-runtime64.so
|
||||
)
|
||||
|
||||
# link_directories(${LLVM_LIBRARY_DIR})
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(lib)
|
||||
@@ -239,10 +243,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
|
||||
MLIRIR
|
||||
)
|
||||
|
||||
set(ROCM_LIBRARIES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lib/rocm/libhsa-runtime64.so
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
target_link_libraries(triton PRIVATE ${ROCM_LIBRARIES} ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS}
|
||||
${TRITON_LIBRARIES}
|
||||
|
||||
@@ -124,20 +124,14 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||
}
|
||||
|
||||
llvm::LLVMContext llvmContext;
|
||||
<<<<<<< HEAD
|
||||
mlir::triton::gpu::TMAMetadataTy tmaInfos;
|
||||
#ifdef USE_ROCM
|
||||
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
|
||||
SMArch.getValue(), true /*isRocm*/);
|
||||
SMArch.getValue(), tmaInfos, Target::ROCDL);
|
||||
#else
|
||||
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
|
||||
SMArch.getValue(), false /*isRocm*/);
|
||||
SMArch.getValue(), tmaInfos, Target::Default);
|
||||
#endif
|
||||
=======
|
||||
mlir::triton::gpu::TMAMetadataTy tmaInfos;
|
||||
auto llvmir = translateTritonGPUToLLVMIR(
|
||||
&llvmContext, *module, SMArch.getValue(), tmaInfos, Target::Default);
|
||||
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
if (!llvmir) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
}
|
||||
|
||||
@@ -21,18 +21,8 @@ enum Target { NVVM, ROCDL, Default = NVVM };
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
|
||||
bool isROCM = true);
|
||||
#else
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
|
||||
bool isROCM = false);
|
||||
#endif
|
||||
=======
|
||||
createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options);
|
||||
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -147,11 +147,11 @@ compared to 1*64 when the hasLeadingOffset is false.
|
||||
int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
|
||||
int maxPhase = SIMDWidth / perPhase;
|
||||
|
||||
return $_get(context, vecSize, perPhase, maxPhase, order);
|
||||
return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
|
||||
} else {
|
||||
// Do not swizzle in case k dimension is not innermost.
|
||||
// In this case accesses will go in different banks even without swizzling.
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
return get(context, 1, 1, 1, order, CTALayout);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -185,20 +185,12 @@ compared to 1*64 when the hasLeadingOffset is false.
|
||||
|
||||
// ---- begin Ampere ----
|
||||
if (mmaEnc.isAmpere()) {
|
||||
<<<<<<< HEAD
|
||||
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getKWidth());
|
||||
=======
|
||||
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
|
||||
<<<<<<< HEAD
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
=======
|
||||
return get(context, 1, 1, 1, order, CTALayout);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
// --- handle A operand ---
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
|
||||
@@ -52,7 +52,6 @@ def TritonGPU_Dialect : Dialect {
|
||||
}
|
||||
return threadsPerWarp.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
<<<<<<< HEAD
|
||||
static int getSharedSize(ModuleOp mod) {
|
||||
Attribute sharedAttr = mod->getDiscardableAttr("triton_gpu.shared");
|
||||
if(!sharedAttr) {
|
||||
@@ -61,8 +60,6 @@ def TritonGPU_Dialect : Dialect {
|
||||
return sharedAttr.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
|
||||
=======
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
}];
|
||||
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
|
||||
@@ -356,9 +356,6 @@ bool supportMMA(triton::DotOp op, int version) {
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
|
||||
auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
|
||||
auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
|
||||
<<<<<<< HEAD
|
||||
|
||||
=======
|
||||
if (version == 3) {
|
||||
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
|
||||
return false;
|
||||
@@ -374,7 +371,6 @@ bool supportMMA(triton::DotOp op, int version) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
if (aElemTy.isF32() && bElemTy.isF32()) {
|
||||
return (op.getAllowTF32() && version == 2) || version == 3;
|
||||
}
|
||||
@@ -446,7 +442,6 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
!srcTy.getElementType().isF32();
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
@@ -464,7 +459,7 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
|
||||
}
|
||||
#endif
|
||||
=======
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto src = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto dst = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
|
||||
@@ -475,7 +470,6 @@ bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
|
||||
srcElemsPerThread == dstElemsPerThread;
|
||||
}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
bool isSingleValue(Value value) {
|
||||
// Don't consider load as expensive if it is loading a scalar.
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
add_library(rocm_libraries SHARED IMPORTED )
|
||||
set_target_properties(rocm_libraries PROPERTIES IMPORTED_LOCATION ${ROCM_LIBRARIES})
|
||||
|
||||
add_mlir_conversion_library(TritonGPUToLLVM
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
|
||||
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp
|
||||
@@ -62,4 +65,5 @@ add_mlir_conversion_library(TritonGPUToLLVM
|
||||
TritonGPUTransforms
|
||||
TritonNvidiaGPUTransforms
|
||||
NVGPUIR
|
||||
rocm_libraries
|
||||
)
|
||||
|
||||
@@ -253,7 +253,7 @@ private:
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type);
|
||||
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type, false);
|
||||
SmallVector<SmallVector<unsigned>> offsets;
|
||||
assert(rank == 2);
|
||||
SmallVector<Value> multiDimOffset(rank);
|
||||
|
||||
@@ -573,11 +573,7 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
|
||||
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
|
||||
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
|
||||
|
||||
<<<<<<< HEAD
|
||||
auto numRep = encoding.getMMAv2Rep(tensorTy.getShape(), bitwidth);
|
||||
=======
|
||||
auto numRep = encoding.getMMAv2Rep(shapePerCTA, bitwidth);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
int kWidth = encoding.getKWidth();
|
||||
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
@@ -25,13 +25,11 @@ LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
#endif
|
||||
=======
|
||||
LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Value thread);
|
||||
@@ -41,7 +39,6 @@ LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Value thread);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
@@ -72,13 +69,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
return convertMMA1688(op, adaptor, getTypeConverter(), rewriter);
|
||||
if (mmaLayout.isAmpere())
|
||||
return convertMMA16816(op, adaptor, getTypeConverter(), rewriter);
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
if (mmaLayout.isHopper())
|
||||
return convertWGMMA(op, adaptor, getTypeConverter(), rewriter,
|
||||
getThreadId(rewriter, loc));
|
||||
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
llvm::report_fatal_error(
|
||||
"Unsupported MMA kind found when converting DotOp to LLVM.");
|
||||
}
|
||||
|
||||
@@ -10,15 +10,14 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
Value fp16x2Vec0 = undef(fp16x2VecTy);
|
||||
Value fp16x2Vec1 = undef(fp16x2VecTy);
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));
|
||||
|
||||
Value a0 = bitcast(fp16x2Vec0, i32_ty);
|
||||
Value a1 = bitcast(fp16x2Vec1, i32_ty);
|
||||
@@ -58,20 +57,19 @@ const std::string Fp16_to_Fp8E5M2 =
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value a0 = undef(fp8x4VecTy);
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
|
||||
a0 = bitcast(a0, i32_ty);
|
||||
Value a1 = undef(fp8x4VecTy);
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
|
||||
a1 = bitcast(a1, i32_ty);
|
||||
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
@@ -94,21 +92,20 @@ const std::string Fp8E5M2_to_Fp16 = "{ \n"
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value a0 = undef(fp8x4VecTy);
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
|
||||
a0 = bitcast(a0, i32_ty);
|
||||
|
||||
Value a1 = undef(fp8x4VecTy);
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
|
||||
a1 = bitcast(a1, i32_ty);
|
||||
|
||||
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
|
||||
@@ -155,15 +152,14 @@ const std::string Fp8E5M2_to_Bf16 =
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Bf16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
auto bf16x2VecTy = vec_ty(i16_ty, 2);
|
||||
Value bf16x2Vec0 = undef(bf16x2VecTy);
|
||||
Value bf16x2Vec1 = undef(bf16x2VecTy);
|
||||
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
|
||||
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
|
||||
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
|
||||
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
|
||||
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0));
|
||||
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[1], i32_val(1));
|
||||
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[2], i32_val(0));
|
||||
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[3], i32_val(1));
|
||||
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
|
||||
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
|
||||
|
||||
@@ -276,21 +272,20 @@ const std::string Bf16_to_Fp8E5M2 =
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value a0 = undef(fp8x4VecTy);
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
|
||||
a0 = bitcast(a0, i32_ty);
|
||||
|
||||
Value a1 = undef(fp8x4VecTy);
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
|
||||
a1 = bitcast(a1, i32_ty);
|
||||
|
||||
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
|
||||
@@ -325,22 +320,20 @@ const std::string Fp8E4M3B15_to_Fp16 =
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
|
||||
"shl.b32 $1, b1, 7; \n"
|
||||
"} \n";
|
||||
<<<<<<< HEAD
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
Value fp16x2Vec0 = undef(fp16x2VecTy);
|
||||
Value fp16x2Vec1 = undef(fp16x2VecTy);
|
||||
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));
|
||||
|
||||
Value fp16x2VecMin = i32_val(0xBF80BF80);
|
||||
Value fp16x2VecMax = i32_val(0x3F803F80);
|
||||
@@ -374,29 +367,6 @@ Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
|
||||
};
|
||||
}
|
||||
#else
|
||||
const std::string Fp16_to_Fp8E4M3B15 =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n"
|
||||
".reg .b32 min_val, max_val; \n"
|
||||
"mov.b32 min_val, 0xBF80BF80; \n"
|
||||
"mov.b32 max_val, 0x3F803F80; \n"
|
||||
"max.f16x2 $1, $1, min_val; \n"
|
||||
"min.f16x2 $1, $1, max_val; \n"
|
||||
"max.f16x2 $2, $2, min_val; \n"
|
||||
"min.f16x2 $2, $2, max_val; \n"
|
||||
"shl.b32 a0, $1, 1; \n"
|
||||
"shl.b32 a1, $2, 1; \n"
|
||||
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
|
||||
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
|
||||
"add.u32 a0, a0, 0x00800080; \n"
|
||||
"add.u32 a1, a1, 0x00800080; \n"
|
||||
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
|
||||
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
|
||||
"prmt.b32 $0, b0, b1, 0x7531; \n"
|
||||
"}";
|
||||
#endif
|
||||
=======
|
||||
|
||||
const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
|
||||
std::string ret;
|
||||
ret += "{ \n"
|
||||
@@ -431,7 +401,7 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
|
||||
"}";
|
||||
return ret;
|
||||
}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
#endif
|
||||
|
||||
/* ----- FP8E4M3B15X4 ------ */
|
||||
// NOTE: NOT USED RIGHT NOW
|
||||
@@ -446,8 +416,7 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
return {};
|
||||
}
|
||||
#else
|
||||
@@ -472,8 +441,7 @@ const std::string Fp8E4M3B15x4_to_Fp16 =
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
return {};
|
||||
}
|
||||
#else
|
||||
@@ -488,7 +456,6 @@ const std::string Fp16_to_Fp8E4M3B15x4 =
|
||||
"}";
|
||||
#endif
|
||||
|
||||
<<<<<<< HEAD
|
||||
/* ----- FP8E4M3 ------ */
|
||||
// Note: when handled by software, this format
|
||||
// does not handle denormals and has
|
||||
@@ -498,21 +465,20 @@ const std::string Fp16_to_Fp8E4M3B15x4 =
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E4M3_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value a0 = undef(fp8x4VecTy);
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
|
||||
a0 = bitcast(a0, i32_ty);
|
||||
|
||||
Value a1 = undef(fp8x4VecTy);
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
|
||||
a1 = bitcast(a1, i32_ty);
|
||||
|
||||
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
|
||||
@@ -559,16 +525,15 @@ const std::string Fp8E4M3_to_Fp16 =
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
Value fp16x2Vec0 = undef(fp16x2VecTy);
|
||||
Value fp16x2Vec1 = undef(fp16x2VecTy);
|
||||
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));
|
||||
|
||||
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
|
||||
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
|
||||
@@ -618,21 +583,20 @@ const std::string Fp16_to_Fp8E4M3 =
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E4M3_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value a0 = undef(fp8x4VecTy);
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
|
||||
a0 = bitcast(a0, i32_ty);
|
||||
|
||||
Value a1 = undef(fp8x4VecTy);
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
|
||||
a1 = bitcast(a1, i32_ty);
|
||||
|
||||
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
|
||||
@@ -679,15 +643,14 @@ const std::string Fp8E4M3_to_Bf16 =
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Bf16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
const SmallVector<Value> &v) {
|
||||
auto bf16x2VecTy = vec_ty(i16_ty, 2);
|
||||
Value bf16x2Vec0 = undef(bf16x2VecTy);
|
||||
Value bf16x2Vec1 = undef(bf16x2VecTy);
|
||||
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
|
||||
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
|
||||
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
|
||||
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
|
||||
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0));
|
||||
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[1], i32_val(1));
|
||||
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[2], i32_val(0));
|
||||
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[3], i32_val(1));
|
||||
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
|
||||
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
|
||||
|
||||
@@ -788,7 +751,7 @@ const std::string Bf16_to_Fp8E4M3 =
|
||||
"or.b32 $0, nosign, sign; \n" // restore sign
|
||||
"}";
|
||||
#endif
|
||||
=======
|
||||
|
||||
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
|
||||
const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
|
||||
"cvt.rn.f16x2.e4m3x2 $0, $1; \n"
|
||||
@@ -797,7 +760,6 @@ const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
|
||||
const std::string Fp16_to_Fp8E4M3Nv = "{ \n"
|
||||
"cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
|
||||
"}";
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
/* ----- Packed integer to BF16 ------ */
|
||||
#ifndef USE_ROCM
|
||||
@@ -1245,18 +1207,23 @@ struct FpToFpOpConversion
|
||||
// F8 -> F16
|
||||
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
|
||||
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
|
||||
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
|
||||
{{F8E4M3TyID, F16TyID}, Fp8E4M3_to_Fp16},
|
||||
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
|
||||
// F16 -> F8
|
||||
#ifdef USE_ROCM
|
||||
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
|
||||
#else
|
||||
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
|
||||
#endif
|
||||
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
|
||||
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
|
||||
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3},
|
||||
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
|
||||
// F8 -> BF16
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
|
||||
// BF16 -> F8
|
||||
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
|
||||
};
|
||||
|
||||
int inVecWidthBits = 32;
|
||||
int outVecWidthBits = 32;
|
||||
if (srcTy.isFloat8E4M3FNUZ()) {
|
||||
@@ -1274,15 +1241,9 @@ struct FpToFpOpConversion
|
||||
<< "\n";
|
||||
llvm_unreachable("");
|
||||
}
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
return srcMap.lookup(key);
|
||||
#else
|
||||
return makeConverterFromPtx(srcMap.lookup(key),
|
||||
getTypeConverter()->convertType(srcTy),
|
||||
getTypeConverter()->convertType(dstTy));
|
||||
#endif
|
||||
=======
|
||||
if (computeCapability < 90 &&
|
||||
(srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
|
||||
llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
|
||||
@@ -1294,7 +1255,7 @@ struct FpToFpOpConversion
|
||||
getTypeConverter()->convertType(srcTy),
|
||||
getTypeConverter()->convertType(dstTy),
|
||||
inVecWidthBits, outVecWidthBits);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
#endif
|
||||
}
|
||||
|
||||
SmallVector<Value> createDestOps(triton::FpToFpOp op, OpAdaptor adaptor,
|
||||
@@ -1712,9 +1673,8 @@ struct FSubOpConversion
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
S8_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
SmallVector<Value> inValues = {v0, v1, v2, v3};
|
||||
const SmallVector<Value> &v) {
|
||||
SmallVector<Value> inValues = {v[0], v[1], v[2], v[3]};
|
||||
SmallVector<Value> outValues = {};
|
||||
for (Value inVal : inValues) {
|
||||
Value i32Val = sext(i32_ty, inVal);
|
||||
@@ -1751,21 +1711,17 @@ struct SIToFPOpConversion
|
||||
Type outElemTy = getElementType(op.getOut());
|
||||
if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) {
|
||||
#if USE_ROCM
|
||||
auto outVals = S8_to_Bf16(loc, rewriter, operands[0][0], operands[1][0],
|
||||
operands[2][0], operands[3][0]);
|
||||
SmallVector<Value> inVals = {operands[0][0], operands[1][0],
|
||||
operands[2][0], operands[3][0]};
|
||||
auto outVals = S8_to_Bf16(loc, rewriter, inVals);
|
||||
#else
|
||||
auto cvtFunc = makeConverterFromPtx(
|
||||
S8_to_Bf16, getTypeConverter()->convertType(inElemTy),
|
||||
getTypeConverter()->convertType(outElemTy));
|
||||
<<<<<<< HEAD
|
||||
auto outVals = cvtFunc(loc, rewriter, operands[0][0], operands[1][0],
|
||||
operands[2][0], operands[3][0]);
|
||||
auto cvtFunc = makeConverterFromPtx(
|
||||
S8_to_Bf16, getTypeConverter()->convertType(inElemTy),
|
||||
getTypeConverter()->convertType(outElemTy));
|
||||
SmallVector<Value> inVals = {operands[0][0], operands[1][0],
|
||||
operands[2][0], operands[3][0]};
|
||||
auto outVals = cvtFunc(loc, rewriter, inVals);
|
||||
#endif
|
||||
=======
|
||||
SmallVector<Value> inVals = {operands[0][0], operands[1][0],
|
||||
operands[2][0], operands[3][0]};
|
||||
auto outVals = cvtFunc(loc, rewriter, inVals);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
assert(outVals.size() == 4);
|
||||
return outVals;
|
||||
} else if (outElemTy.isBF16()) {
|
||||
|
||||
@@ -425,8 +425,21 @@ private:
|
||||
|
||||
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
|
||||
SmallVector<Value> shfl(acc.size());
|
||||
unsigned shuffleIdx = N;
|
||||
#ifdef USE_ROCM
|
||||
auto srcTys = op.getInputTypes();
|
||||
auto inputTy = srcTys[0].cast<RankedTensorType>();
|
||||
auto inMfma =
|
||||
inputTy.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
|
||||
if (inMfma && inMfma.getIsTransposed()) {
|
||||
//assert(sizeIntraWarps == 2);
|
||||
// Adjecant threads in y dimension in transposed MFMA layout are 32
|
||||
// apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
|
||||
shuffleIdx = 32;
|
||||
}
|
||||
#endif
|
||||
for (unsigned i = 0; i < acc.size(); ++i) {
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], N);
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx);
|
||||
}
|
||||
accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
|
||||
}
|
||||
@@ -491,11 +504,11 @@ private:
|
||||
triton::ReduceOp op = helper.getOperation();
|
||||
Location loc = op.getLoc();
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout);
|
||||
Value warpSize = i32_val(wavefront_size);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
auto srcShape = helper.getSrcShape();
|
||||
unsigned axis = op.getAxis();
|
||||
auto smemShapes = helper.getScratchConfigsFast();
|
||||
@@ -511,6 +524,7 @@ private:
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
auto srcTys = op.getInputTypes();
|
||||
auto inputTy = srcTys[0].cast<RankedTensorType>();
|
||||
auto inMfma =
|
||||
inputTy.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
|
||||
@@ -532,35 +546,8 @@ private:
|
||||
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
<<<<<<< HEAD
|
||||
SmallVector<Value> acc = it.second;
|
||||
|
||||
// Reduce within warps
|
||||
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
|
||||
SmallVector<Value> shfl(op.getNumOperands());
|
||||
unsigned shuffleIdx = N;
|
||||
#ifdef USE_ROCM
|
||||
if (inMfma && inMfma.getIsTransposed()) {
|
||||
assert(sizeIntraWarps == 2);
|
||||
// Adjecant threads in y dimension in transposed MFMA layout are 32
|
||||
// apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
|
||||
shuffleIdx = 32;
|
||||
}
|
||||
#endif
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx);
|
||||
}
|
||||
accumulate(rewriter, *combineOp, acc, shfl, false);
|
||||
}
|
||||
|
||||
if (isWarpSync) {
|
||||
finalAccs[key] = acc;
|
||||
continue;
|
||||
}
|
||||
=======
|
||||
SmallVector<Value> &acc = it.second;
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
SmallVector<Value> writeIdx = indices[key];
|
||||
writeIdx[axis] = warpIdAxis;
|
||||
Value writeOffset =
|
||||
@@ -620,6 +607,9 @@ private:
|
||||
icmp_eq(laneIdModSizeInterWarps, zero);
|
||||
Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
|
||||
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout);
|
||||
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
#if USE_ROCM
|
||||
// This barrier is known to be critical for Navi 2x/3x
|
||||
|
||||
@@ -446,6 +446,16 @@ struct GetProgramIdOpConversion
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
#ifdef USE_ROCM
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxisAsInt() < 3);
|
||||
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
|
||||
return success();
|
||||
#else
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
|
||||
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
|
||||
@@ -462,7 +472,11 @@ struct GetProgramIdOpConversion
|
||||
Value programId = getSRegValue(rewriter, loc, sreg);
|
||||
rewriter.replaceOp(op, programId);
|
||||
return success();
|
||||
#endif
|
||||
}
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
mlir::gpu::Dimension::z};
|
||||
};
|
||||
|
||||
struct GetNumProgramsOpConversion
|
||||
@@ -473,6 +487,26 @@ struct GetNumProgramsOpConversion
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxis() < 3);
|
||||
// Seem like GridDimOp returns the number of threads (not the number of
|
||||
// workgroups) in a kernel (a bug in llvm https://reviews.llvm.org/D156009),
|
||||
// so as a workaround here, we divide by the number of threads
|
||||
// per workgroup to get the number of workgroups in a kernel.
|
||||
// TODO: when we do upstream to include llvm fix, we can remove this workaround
|
||||
// The unit test added in this PR can guarantee that.
|
||||
Value threadsPerGrid =
|
||||
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
|
||||
Value threadsPerBlock =
|
||||
rewriter.create<::mlir::gpu::BlockDimOp>(loc, dims[op.getAxis()]);
|
||||
Value threadNumPerGrid = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerGrid);
|
||||
Value threadNumPerBlock = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerBlock);
|
||||
rewriter.replaceOpWithNewOp<LLVM::UDivOp>(op, threadNumPerGrid, threadNumPerBlock);
|
||||
return success();
|
||||
#else
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetNumProgramsOp. If numCTAs = 1, then
|
||||
// GetNumProgramsOp is converted to "%nctaid", otherwise it is converted to
|
||||
@@ -486,33 +520,17 @@ struct GetNumProgramsOpConversion
|
||||
std::string sreg = numCTAs == 1 ? "%nctaid." : "%nclusterid.";
|
||||
sreg.append(1, 'x' + op.getAxis()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
|
||||
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
// Seem like GridDimOp returns the number of threads (not the number of
|
||||
// workgroups) in a kernel (a bug in llvm https://reviews.llvm.org/D156009),
|
||||
// so as a workaround here, we divide by the number of threads
|
||||
// per workgroup to get the number of workgroups in a kernel.
|
||||
// TODO: when we do upstream to include llvm fix, we can remove this workaround
|
||||
// The unit test added in this PR can guarantee that.
|
||||
Value threadsPerGrid =
|
||||
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
|
||||
Value threadsPerBlock =
|
||||
rewriter.create<::mlir::gpu::BlockDimOp>(loc, dims[op.getAxis()]);
|
||||
Value threadNumPerGrid = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerGrid);
|
||||
Value threadNumPerBlock = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerBlock);
|
||||
rewriter.replaceOpWithNewOp<LLVM::UDivOp>(op, threadNumPerGrid, threadNumPerBlock);
|
||||
#else
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
|
||||
#endif // USE_ROCM
|
||||
=======
|
||||
Value numPrograms = getSRegValue(rewriter, loc, sreg);
|
||||
rewriter.replaceOp(op, numPrograms);
|
||||
return success();
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
mlir::gpu::Dimension::z};
|
||||
};
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
// TODO[goostavz]: GetThreadIdOp/GetClusterCTAIdOp is a temporary solution
|
||||
// before async dialect is done. These concepts should appear in ttgpu
|
||||
|
||||
@@ -324,12 +324,8 @@ public:
|
||||
// then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y
|
||||
// This means that we can use some immediate offsets for shared memory
|
||||
// operations.
|
||||
<<<<<<< HEAD
|
||||
resElemTy = getTypeConverter()->convertType(resElemTy);
|
||||
auto dstPtrTy = ptr_ty(resElemTy, 3);
|
||||
=======
|
||||
auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(resElemTy), 3);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
|
||||
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
|
||||
|
||||
@@ -555,13 +551,8 @@ public:
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
|
||||
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
|
||||
auto order = triton::gpu::getOrder(layout);
|
||||
<<<<<<< HEAD
|
||||
auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
|
||||
Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
|
||||
=======
|
||||
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
|
||||
Value warpSize = i32_val(32);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
|
||||
Value laneId = urem(tid, warpSize);
|
||||
Value warpId = udiv(tid, warpSize);
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
@@ -694,19 +685,13 @@ public:
|
||||
blockedLayout, type);
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isVolta())
|
||||
<<<<<<< HEAD
|
||||
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type);
|
||||
if (mmaLayout.isAmpere())
|
||||
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type);
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
|
||||
=======
|
||||
result = emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter,
|
||||
mmaLayout, type);
|
||||
if (mmaLayout.isAmpere() || mmaLayout.isHopper())
|
||||
result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter,
|
||||
mmaLayout, type);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
|
||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parentLayout = sliceLayout.getParent();
|
||||
auto parentShape = sliceLayout.paddedShape(type.getShape());
|
||||
@@ -764,7 +749,7 @@ public:
|
||||
const unsigned elemsPerThreadPerGroup = 4;
|
||||
auto warpSize = getWarpSize(mfmaLayout);
|
||||
assert(warpSize == 64);
|
||||
auto shapePerCta = getShapePerCTA(mfmaLayout);
|
||||
auto shapePerCta = getShapePerCTATile(mfmaLayout);
|
||||
for (unsigned block = 0; block < numGroups; block++) {
|
||||
unsigned rowOrColOffset = block * elemsPerThreadPerGroup * warpSize / 32;
|
||||
for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
|
||||
@@ -802,14 +787,11 @@ public:
|
||||
result = emitIndicesForDistributedLayout(loc, b, blocked, type,
|
||||
withCTAOffset);
|
||||
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
<<<<<<< HEAD
|
||||
result = emitIndicesForDistributedLayout(loc, b, mma, type);
|
||||
} else if (auto mfma = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
result = emitIndicesForDistributedLayout(loc, b, mfma, type);
|
||||
=======
|
||||
result =
|
||||
emitIndicesForDistributedLayout(loc, b, mma, type, withCTAOffset);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
} else if (auto mfma = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
result =
|
||||
emitIndicesForDistributedLayout(loc, b, mfma, type, withCTAOffset);
|
||||
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
result =
|
||||
emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset);
|
||||
@@ -848,7 +830,7 @@ private:
|
||||
const BlockedEncodingAttr &blockedLayout, RankedTensorType type) const {
|
||||
auto shape = type.getShape();
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(triton::gpu::getWarpSize(blocked_layout));
|
||||
Value warpSize = i32_val(triton::gpu::getWarpSize(blockedLayout));
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
auto sizePerThread = blockedLayout.getSizePerThread();
|
||||
@@ -1208,7 +1190,7 @@ private:
|
||||
|
||||
auto tensorShape = type.getShape();
|
||||
SmallVector<SmallVector<unsigned>> offsets;
|
||||
auto shapePerCta = getShapePerCTA(mfmaLayout);
|
||||
auto shapePerCta = getShapePerCTA(mfmaLayout, tensorShape);
|
||||
|
||||
SmallVector<unsigned> numCTAPerDim(2);
|
||||
for (unsigned d = 0; d < 2; ++d) {
|
||||
|
||||
@@ -20,11 +20,11 @@
|
||||
#include "triton/Analysis/Membar.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
#ifndef USE_ROCM
|
||||
#else
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#endif
|
||||
#include "triton/Tools/Sys/GetPlatform.hpp"
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
#include "BarrierOpToLLVM.h"
|
||||
#include "ClusterOpsToLLVM.h"
|
||||
@@ -70,19 +70,13 @@ public:
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<index::IndexDialect>();
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
<<<<<<< HEAD
|
||||
if (isROCM) {
|
||||
addLegalDialect<ROCDL::ROCDLDialect>();
|
||||
addLegalDialect<mlir::scf::SCFDialect>();
|
||||
} else {
|
||||
=======
|
||||
switch (target) {
|
||||
case Target::NVVM:
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
break;
|
||||
case Target::ROCDL:
|
||||
addLegalDialect<ROCDL::ROCDLDialect>();
|
||||
addLegalDialect<mlir::scf::SCFDialect>();
|
||||
break;
|
||||
}
|
||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
@@ -377,19 +371,13 @@ public:
|
||||
explicit TritonLLVMConversionTarget(MLIRContext &ctx, Target target)
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
<<<<<<< HEAD
|
||||
if (isROCM) {
|
||||
addLegalDialect<ROCDL::ROCDLDialect>();
|
||||
addLegalDialect<mlir::scf::SCFDialect>();
|
||||
} else {
|
||||
=======
|
||||
switch (target) {
|
||||
case Target::NVVM:
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
break;
|
||||
case Target::ROCDL:
|
||||
addLegalDialect<ROCDL::ROCDLDialect>();
|
||||
addLegalDialect<mlir::scf::SCFDialect>();
|
||||
break;
|
||||
}
|
||||
addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
|
||||
@@ -419,14 +407,10 @@ struct ConvertTritonGPUToLLVM
|
||||
|
||||
// Preprocess
|
||||
decomposeFp8e4b15Convert(mod);
|
||||
<<<<<<< HEAD
|
||||
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
|
||||
#ifdef USE_ROCM
|
||||
decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp);
|
||||
#endif
|
||||
=======
|
||||
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
#ifdef USE_ROCM
|
||||
decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
|
||||
#endif
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
decomposeMixedModeDotOp(mod);
|
||||
@@ -699,8 +683,8 @@ private:
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
void decomposeMfmaToDotOperand(ModuleOp mod, int numWarps,
|
||||
int threadsPerWarp) const {
|
||||
void decomposeMfmaToDotOperand(ModuleOp mod, int numWarps, int threadsPerWarp,
|
||||
int numCTAs) const {
|
||||
// Replace `mfma -> dot_op` with `mfma -> blocked -> dot_op`
|
||||
// unless certain conditions are met
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
@@ -716,7 +700,7 @@ private:
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::BlockedEncodingAttr::get(
|
||||
mod.getContext(), srcType.getShape(), getSizePerThread(srcMfma),
|
||||
getOrder(srcMfma), numWarps, threadsPerWarp));
|
||||
getOrder(srcMfma), numWarps, threadsPerWarp, numCTAs));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
|
||||
@@ -331,13 +331,9 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
|
||||
int i);
|
||||
Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
|
||||
<<<<<<< HEAD
|
||||
int i, Value laneId);
|
||||
|
||||
=======
|
||||
int i);
|
||||
Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
|
||||
StringRef key, StringRef content);
|
||||
|
||||
|
||||
@@ -475,11 +475,7 @@ public:
|
||||
|
||||
void runOnOperation() override {
|
||||
// Only rewrite if the hardware does not support
|
||||
<<<<<<< HEAD
|
||||
if (!isROCM && computeCapability >= 90)
|
||||
=======
|
||||
if (computeCapability >= 90)
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
return;
|
||||
|
||||
// NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because
|
||||
|
||||
@@ -249,7 +249,6 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
return {};
|
||||
}
|
||||
} else if (parentLayout.isa<MfmaEncodingAttr>()) {
|
||||
auto parentShapePerCTA = getShapePerCTA(parentLayout);
|
||||
auto opIdx = dotLayout.getOpIdx();
|
||||
if (opIdx == 0) {
|
||||
return {4, 1};
|
||||
@@ -381,7 +380,7 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
|
||||
}
|
||||
} else if (auto parentMfmaLayout =
|
||||
parentLayout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
auto parentShapePerCTA = getShapePerCTA(parentLayout, tensorShape);
|
||||
auto parentShapePerCTA = getShapePerCTATile(parentLayout, tensorShape);
|
||||
auto opIdx = dotLayout.getOpIdx();
|
||||
|
||||
if (opIdx == 0) {
|
||||
@@ -691,11 +690,7 @@ static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr,
|
||||
bool &value, StringRef desc) {
|
||||
auto boolAttr = attr.dyn_cast<BoolAttr>();
|
||||
if (!boolAttr) {
|
||||
<<<<<<< HEAD
|
||||
parser.emitError(parser.getNameLoc(), "expected bool type in ") << desc;
|
||||
=======
|
||||
parser.emitError(parser.getNameLoc(), "expected an bool type in ") << desc;
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
return failure();
|
||||
}
|
||||
value = boolAttr.getValue();
|
||||
@@ -862,11 +857,11 @@ MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
|
||||
return elemsPerThread;
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
unsigned MfmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
return product<unsigned>(getElemsPerThread(shape, eltTy));
|
||||
=======
|
||||
}
|
||||
|
||||
unsigned
|
||||
MmaEncodingAttr::getElemsPerThreadOfOperand(int opIdx,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
@@ -896,7 +891,6 @@ MmaEncodingAttr::getElemsPerThreadOfOperand(int opIdx,
|
||||
}
|
||||
}
|
||||
return res;
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
}
|
||||
|
||||
unsigned MmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
@@ -971,7 +965,6 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
|
||||
unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
<<<<<<< HEAD
|
||||
if (auto mfmaParent = getParent().dyn_cast<MfmaEncodingAttr>()) {
|
||||
int warpsPerCTAM = mfmaParent.getWarpsPerCTA()[0];
|
||||
int warpsPerCTAN = mfmaParent.getWarpsPerCTA()[1];
|
||||
@@ -980,9 +973,7 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
auto rep = getMFMARep(shape, eltTy);
|
||||
return rep[0] * rep[1];
|
||||
}
|
||||
=======
|
||||
auto shapePerCTA = getShapePerCTA(*this, shape);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
if (auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>()) {
|
||||
int warpsPerCTAM = mmaParent.getWarpsPerCTA()[0];
|
||||
int warpsPerCTAN = mmaParent.getWarpsPerCTA()[1];
|
||||
@@ -1247,7 +1238,7 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
|
||||
unsigned nonKDim = 0;
|
||||
SmallVector<unsigned, 2> warpsPerCTA;
|
||||
SmallVector<unsigned> warpsPerCTA;
|
||||
bool isTransposed;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
@@ -1438,11 +1429,7 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
|
||||
printer << "<{"
|
||||
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
|
||||
<<<<<<< HEAD
|
||||
if ((mmaParent && mmaParent.isAmpere()) || getParent().isa<MfmaEncodingAttr>())
|
||||
=======
|
||||
if (mmaParent && mmaParent.isAmpere())
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
printer << ", kWidth = " << getKWidth();
|
||||
printer << "}>";
|
||||
}
|
||||
|
||||
@@ -75,9 +75,8 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
SmallVector<unsigned, 2> warpsPerTileMI200(triton::DotOp dotOp,
|
||||
SmallVector<unsigned, 2> warpsPerTileMI200(tt::DotOp dotOp,
|
||||
const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
// TODO: needs to be updated with appropriate shapePerWarp etc.
|
||||
@@ -86,7 +85,7 @@ SmallVector<unsigned, 2> warpsPerTileMI200(triton::DotOp dotOp,
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, filter);
|
||||
for (Operation *op : slices)
|
||||
if (isa<triton::DotOp>(op) && (op != dotOp))
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp))
|
||||
return {(unsigned)numWarps, 1};
|
||||
|
||||
SmallVector<int64_t, 2> tensorShape = {shape[0], shape[1]};
|
||||
@@ -104,7 +103,18 @@ SmallVector<unsigned, 2> warpsPerTileMI200(triton::DotOp dotOp,
|
||||
ret[0] *= 2;
|
||||
} else
|
||||
ret[1] *= 2;
|
||||
=======
|
||||
} else {
|
||||
ret[1] *= 2;
|
||||
}
|
||||
} while (true);
|
||||
|
||||
if (ret[1] * shapePerWarp[1] > tensorShape[1]) {
|
||||
return {ret[1], ret[0]};
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
warpsPerTileV3(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
|
||||
const SmallVector<unsigned, 3> &instrShape) {
|
||||
@@ -122,17 +132,10 @@ warpsPerTileV3(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
|
||||
break;
|
||||
if (shape[0] > shapePerWarp[0] * ret[0]) {
|
||||
ret[0] *= 2;
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
} else {
|
||||
ret[1] *= 2;
|
||||
}
|
||||
} while (true);
|
||||
<<<<<<< HEAD
|
||||
|
||||
if (ret[1] * shapePerWarp[1] > tensorShape[1]) {
|
||||
return {ret[1], ret[0]};
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -140,15 +143,15 @@ class BlockedToMFMA : public mlir::RewritePattern {
|
||||
int mfmaVersion;
|
||||
public:
|
||||
BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion)
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context), mfmaVersion(mfmaVersion) {}
|
||||
: mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context), mfmaVersion(mfmaVersion) {}
|
||||
|
||||
bool isChainDot(triton::DotOp &dotOp) const {
|
||||
bool isChainDot(tt::DotOp &dotOp) const {
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, filter);
|
||||
for (Operation *op : slices) {
|
||||
if (isa<triton::DotOp>(op) && (op != dotOp))
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@@ -158,7 +161,7 @@ public:
|
||||
/// @param dot target dot operation
|
||||
/// @param mfmaVersion
|
||||
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
|
||||
std::pair<int64_t, int64_t> chooseMfmaDimensions(triton::DotOp dot, int mfmaVersion) const {
|
||||
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot, int mfmaVersion) const {
|
||||
int64_t nonKDim = 32;
|
||||
// number of matrix elements along k dim per one MFMA intruction
|
||||
int64_t kDim = -1;
|
||||
@@ -183,11 +186,11 @@ public:
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto dotOp = cast<triton::DotOp>(op);
|
||||
auto dotOp = cast<tt::DotOp>(op);
|
||||
|
||||
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
|
||||
if (!oldRetType.getEncoding() ||
|
||||
!oldRetType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>())
|
||||
!oldRetType.getEncoding().isa<ttg::BlockedEncodingAttr>())
|
||||
return failure();
|
||||
|
||||
if (!supportMFMA(dotOp))
|
||||
@@ -196,7 +199,7 @@ public:
|
||||
// get MFMA encoding for the given number of warps
|
||||
auto retShape = oldRetType.getShape();
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
|
||||
|
||||
// operands
|
||||
Value a = dotOp.getA();
|
||||
@@ -205,14 +208,14 @@ public:
|
||||
auto oldBType = b.getType().cast<RankedTensorType>();
|
||||
auto ctx = oldAType.getContext();
|
||||
|
||||
triton::gpu::MfmaEncodingAttr mfmaEnc;
|
||||
ttg::MfmaEncodingAttr mfmaEnc;
|
||||
|
||||
auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp, mfmaVersion);
|
||||
|
||||
auto warpsPerTile = warpsPerTileMI200(dotOp, retShape, numWarps);
|
||||
|
||||
bool isTransposed = isChainDot(dotOp);
|
||||
mfmaEnc = triton::gpu::MfmaEncodingAttr::get(
|
||||
mfmaEnc = ttg::MfmaEncodingAttr::get(
|
||||
oldRetType.getContext(), nonKDim, warpsPerTile, isTransposed);
|
||||
|
||||
auto newRetType =
|
||||
@@ -220,44 +223,39 @@ public:
|
||||
|
||||
// convert accumulator
|
||||
auto oldAcc = dotOp.getOperand(2);
|
||||
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(
|
||||
oldAcc.getLoc(), newRetType, oldAcc);
|
||||
auto oldAOrder = oldAType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.cast<ttg::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.cast<ttg::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
auto oldBOrder = oldBType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.cast<ttg::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.cast<ttg::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
|
||||
// kWidth is a number of consecutive elements per one instruction per one thread
|
||||
auto kWidth = kDim / 2;
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth));
|
||||
ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth));
|
||||
auto newBType = RankedTensorType::get(
|
||||
oldBType.getShape(), oldBType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth));
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth));
|
||||
a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
auto newDot = rewriter.create<tt::DotOp>(
|
||||
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(
|
||||
op, oldRetType, newDot.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
=======
|
||||
return ret;
|
||||
}
|
||||
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
|
||||
|
||||
@@ -283,12 +283,6 @@ static bool isLayoutAnchor(Operation *op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
// this may generate unsupported conversions in the LLVM codegen
|
||||
if (newEncoding.isa<triton::gpu::MmaEncodingAttr>() ||
|
||||
newEncoding.isa<triton::gpu::MfmaEncodingAttr>()) {
|
||||
return failure();
|
||||
=======
|
||||
void LayoutPropagation::initAnchorLayout() {
|
||||
funcOp.walk([&](Operation *op) {
|
||||
if (isLayoutAnchor(op)) {
|
||||
@@ -304,7 +298,6 @@ void LayoutPropagation::initAnchorLayout() {
|
||||
layouts.insert({result, tensorType.getEncoding()});
|
||||
}
|
||||
}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -498,32 +491,6 @@ void LayoutPropagation::map(Value old, Value newV) {
|
||||
newV;
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
// op(cvt(arg_0), arg_1, ..., arg_n)
|
||||
// -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n)))
|
||||
void pushConversionForward(triton::gpu::ConvertLayoutOp cvt,
|
||||
SetVector<Operation *> &cvtSlices,
|
||||
PatternSharedInfo &sharedInfo,
|
||||
mlir::PatternRewriter &rewriter) {
|
||||
auto srcEncoding =
|
||||
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
|
||||
auto dstEncoding =
|
||||
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
|
||||
IRMapping mapping;
|
||||
auto op = cvtSlices.front();
|
||||
for (Value arg : op->getOperands()) {
|
||||
if (arg.getDefiningOp() == cvt)
|
||||
mapping.map(arg, cvt.getOperand());
|
||||
else {
|
||||
auto oldType = arg.getType().cast<RankedTensorType>();
|
||||
auto newType = RankedTensorType::get(
|
||||
oldType.getShape(), oldType.getElementType(), srcEncoding);
|
||||
auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(arg.getLoc(),
|
||||
newType, arg);
|
||||
if (Operation *argOp = arg.getDefiningOp())
|
||||
cvtI->moveAfter(argOp);
|
||||
mapping.map(arg, cvtI);
|
||||
=======
|
||||
Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
|
||||
if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
|
||||
Value rewrittenValue;
|
||||
@@ -538,7 +505,6 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
|
||||
rewrittenValue = value;
|
||||
else
|
||||
rewrittenValue = rewriteMapping[{value, encodingPicked}];
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
}
|
||||
assert(rewrittenValue);
|
||||
if (rewrittenValue.getType().cast<RankedTensorType>().getEncoding() ==
|
||||
@@ -572,18 +538,7 @@ Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,
|
||||
origType.getElementType(), encoding);
|
||||
newOp->getResult(i).setType(newType);
|
||||
}
|
||||
<<<<<<< HEAD
|
||||
auto *newOp = cloneWithInferType(rewriter, op, mapping);
|
||||
auto newType = newOp->getResult(0).getType().cast<RankedTensorType>();
|
||||
auto newCvtType = RankedTensorType::get(
|
||||
newType.getShape(), newType.getElementType(), dstEncoding);
|
||||
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
newOp->getLoc(), newCvtType, newOp->getResult(0));
|
||||
sharedInfo.cvtsPushedForwardMap[newCvt] = newCvt->getOperand(0).getDefiningOp();
|
||||
rewriter.replaceOp(op, newCvt->getResults());
|
||||
=======
|
||||
return newOp;
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
}
|
||||
|
||||
Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) {
|
||||
@@ -649,16 +604,6 @@ Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) {
|
||||
returnTypes.push_back(newType);
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
//
|
||||
class RematerializeForward : public mlir::RewritePattern {
|
||||
PatternSharedInfo &sharedInfo;
|
||||
|
||||
public:
|
||||
explicit RematerializeForward(mlir::MLIRContext *context, PatternSharedInfo &sharedInfo)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context), sharedInfo(sharedInfo) {}
|
||||
=======
|
||||
auto newWhileOp =
|
||||
rewriter.create<scf::WhileOp>(whileOp.getLoc(), returnTypes, operands);
|
||||
SmallVector<Type> argsTypesBefore;
|
||||
@@ -670,7 +615,6 @@ public:
|
||||
rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore,
|
||||
bbArgLocsBefore);
|
||||
rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
for (int i = 0; i < whileOp.getNumRegions(); ++i) {
|
||||
newWhileOp->getRegion(i).front().getOperations().splice(
|
||||
@@ -741,10 +685,6 @@ void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) {
|
||||
}
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
pushConversionForward(cvt, cvtSlices, sharedInfo, rewriter);
|
||||
return success();
|
||||
=======
|
||||
void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) {
|
||||
scf::WhileOp whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
|
||||
for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) {
|
||||
@@ -755,58 +695,9 @@ void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) {
|
||||
continue;
|
||||
Value newOperand = getValueAs(operand.get(), tensorType.getEncoding());
|
||||
conditionOp->setOperand(operand.getOperandNumber(), newOperand);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
}
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
// Layout conversions are expensive. They require going through
|
||||
// shared memory, which is orders of magnitude slower than
|
||||
// other non-i/o operations in the dialect.
|
||||
// It therefore makes sense to remove them whenever possible,
|
||||
// even if it means rematerializing all values whose definitions
|
||||
// are reachable from it without passing through any memory operation.
|
||||
class RematerializeBackward : public mlir::RewritePattern {
|
||||
PatternSharedInfo &sharedInfo;
|
||||
|
||||
public:
|
||||
explicit RematerializeBackward(mlir::MLIRContext *context, PatternSharedInfo &sharedInfo)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
3, context), sharedInfo(sharedInfo) {}
|
||||
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *cvt,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(cvt))
|
||||
return mlir::failure();
|
||||
|
||||
auto it = sharedInfo.cvtsPushedForwardMap.find(cvt);
|
||||
if (it != sharedInfo.cvtsPushedForwardMap.end() &&
|
||||
it->second == cvt->getOperand(0).getDefiningOp())
|
||||
return mlir::failure();
|
||||
|
||||
// we don't touch block arguments
|
||||
Operation *op = cvt->getOperand(0).getDefiningOp();
|
||||
if (!op)
|
||||
return mlir::failure();
|
||||
// we don't want to rematerialize any conversion to/from shared
|
||||
if (triton::gpu::isSharedEncoding(cvt->getResults()[0]) ||
|
||||
triton::gpu::isSharedEncoding(cvt->getOperand(0)))
|
||||
return mlir::failure();
|
||||
// we don't handle conversions to DotOperandEncodingAttr
|
||||
// this is a heuristics to accommodate fused attention
|
||||
auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
|
||||
if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
|
||||
return mlir::failure();
|
||||
// DFS
|
||||
SetVector<Operation *> processed;
|
||||
SetVector<Attribute> layout;
|
||||
llvm::MapVector<Value, Attribute> toConvert;
|
||||
if (simulateBackwardRematerialization(cvt, processed, layout, toConvert,
|
||||
targetType.getEncoding()) > 0)
|
||||
return mlir::failure();
|
||||
=======
|
||||
void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) {
|
||||
OpBuilder rewriter(reduceOp);
|
||||
Attribute srcEncoding;
|
||||
@@ -875,7 +766,6 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
|
||||
assert(0 && "unexpected op in rewrite");
|
||||
return nullptr;
|
||||
}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
static bool canBeRemat(Operation *op) {
|
||||
if (isa<triton::LoadOp, triton::StoreOp>(op))
|
||||
@@ -898,42 +788,6 @@ static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter,
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(loop);
|
||||
|
||||
<<<<<<< HEAD
|
||||
class MoveConvertOutOfLoop : public mlir::RewritePattern {
|
||||
PatternSharedInfo &sharedInfo;
|
||||
|
||||
public:
|
||||
explicit MoveConvertOutOfLoop(mlir::MLIRContext *context,
|
||||
PatternSharedInfo &sharedInfo)
|
||||
: mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context),
|
||||
sharedInfo(sharedInfo) {}
|
||||
|
||||
SmallVector<Value, 4>
|
||||
rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
|
||||
size_t i, RankedTensorType newType,
|
||||
triton::gpu::ConvertLayoutOp origConversion) const {
|
||||
// Rewrite init argument
|
||||
Type origType = forOp.getInitArgs()[i].getType();
|
||||
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
|
||||
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
|
||||
// Clone for loop
|
||||
auto newForOp = rewriter.create<scf::ForOp>(
|
||||
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
||||
forOp.getStep(), newInitArgs);
|
||||
newForOp->moveBefore(forOp);
|
||||
rewriter.setInsertionPointToStart(newForOp.getBody());
|
||||
IRMapping mapping;
|
||||
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
mapping.map(origConversion.getResult(), newForOp.getRegionIterArgs()[i]);
|
||||
|
||||
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
if (&op == (Operation *)(&origConversion))
|
||||
continue;
|
||||
Operation *newOp = rewriter.clone(op, mapping);
|
||||
=======
|
||||
// Create a new loop before the existing one, with the extra operands.
|
||||
rewriter.setInsertionPoint(loop);
|
||||
auto operands = llvm::to_vector<4>(loop.getIterOperands());
|
||||
@@ -966,36 +820,10 @@ static void rewriteSlice(SetVector<Value> &slice,
|
||||
opsToRewrite.insert(v.cast<BlockArgument>().getOwner()->getParentOp());
|
||||
// We also need to rewrite the yield op.
|
||||
opsToRewrite.insert(v.cast<BlockArgument>().getOwner()->getTerminator());
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
}
|
||||
}
|
||||
opsToRewrite = multiRootTopologicalSort(opsToRewrite);
|
||||
|
||||
<<<<<<< HEAD
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
auto iterArgs = forOp.getRegionIterArgs();
|
||||
for (const auto &iterArg : llvm::enumerate(iterArgs)) {
|
||||
// skip non-tensor types
|
||||
if (!iterArg.value().getType().isa<RankedTensorType>())
|
||||
continue;
|
||||
SmallVector<Operation *> cvts;
|
||||
if (canMoveOutOfLoop(iterArg.value(), cvts).failed())
|
||||
continue;
|
||||
// check
|
||||
for (auto *op : cvts) {
|
||||
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto it = sharedInfo.cvtsPushedForwardMap.find(cvt);
|
||||
if (it != sharedInfo.cvtsPushedForwardMap.end())
|
||||
return mlir::failure();
|
||||
auto targetType = op->getResultTypes()[0].cast<RankedTensorType>();
|
||||
auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(),
|
||||
targetType, cvt);
|
||||
rewriter.replaceOp(forOp, newFor);
|
||||
return success();
|
||||
=======
|
||||
SmallVector<Operation *> deadLoops;
|
||||
OpBuilder builder(slice.begin()->getContext());
|
||||
for (Operation *op : opsToRewrite) {
|
||||
@@ -1032,7 +860,6 @@ static void rewriteSlice(SetVector<Value> &slice,
|
||||
if (slice.count(operand) == 0)
|
||||
continue;
|
||||
yieldOperands.push_back(mapping.lookup(operand));
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
}
|
||||
builder.create<scf::YieldOp>(op->getLoc(), yieldOperands);
|
||||
op->erase();
|
||||
@@ -1214,19 +1041,6 @@ public:
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp m = getOperation();
|
||||
|
||||
<<<<<<< HEAD
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
PatternSharedInfo sharedInfo;
|
||||
|
||||
patterns.add<SimplifyConversion>(context);
|
||||
patterns.add<SimplifyReduceCvt>(context);
|
||||
patterns.add<RematerializeBackward>(context, sharedInfo);
|
||||
patterns.add<RematerializeForward>(context, sharedInfo);
|
||||
patterns.add<MoveConvertOutOfLoop>(context, sharedInfo);
|
||||
patterns.add<MoveConvertOutOfIf>(context);
|
||||
patterns.add<DecomposeDotOperand>(context);
|
||||
patterns.add<ConvertDotConvert>(context);
|
||||
=======
|
||||
// 1. Propagate layout forward starting from "anchor" ops.
|
||||
m.walk([](triton::FuncOp funcOp) {
|
||||
LayoutPropagation layoutPropagation(funcOp);
|
||||
@@ -1249,7 +1063,6 @@ public:
|
||||
// 3. For converts left try to hoist them above cast generating larger size
|
||||
// types in order to reduce the cost of the convert op.
|
||||
hoistConvert(m);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
mlir::RewritePatternSet decomposePatterns(context);
|
||||
decomposePatterns.add<DecomposeDotOperand>(context);
|
||||
|
||||
@@ -401,12 +401,13 @@ void LoopPipeliner::createBufferTypes() {
|
||||
ty.getShape().end());
|
||||
Type eType = ty.getElementType();
|
||||
auto blockedEnc = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();
|
||||
auto CTALayout = ttg::getCTALayout(ty.getEncoding());
|
||||
// unsigned bitWidth = dotOpEnc.getMMAv2kWidth()
|
||||
// ? 32 / dotOpEnc.getMMAv2kWidth()
|
||||
// : ty.getElementType().getIntOrFloatBitWidth();
|
||||
auto sharedEnc =
|
||||
ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(),
|
||||
ttg::getOrder(ty.getEncoding()), eType);
|
||||
ttg::getOrder(ty.getEncoding()), CTALayout, eType);
|
||||
loadsBufferType[loadOp] = RankedTensorType::get(bufferShape, eType, sharedEnc);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
|
||||
#include "mlir/Conversion/Passes.h"
|
||||
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
||||
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
@@ -58,11 +59,7 @@ struct NVVMMetadata {
|
||||
|
||||
// Add the nvvm related metadata to LLVM IR.
|
||||
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
<<<<<<< HEAD
|
||||
bool isROCM, const int threadsPerCTA) {
|
||||
=======
|
||||
Target target) {
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
Target target, const int threadsPerCTA) {
|
||||
auto *module = func->getParent();
|
||||
auto &ctx = func->getContext();
|
||||
|
||||
@@ -92,18 +89,8 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
}
|
||||
|
||||
if (metadata.isKernel) {
|
||||
<<<<<<< HEAD
|
||||
if (isROCM) {
|
||||
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
|
||||
func->addFnAttr("amdgpu-flat-work-group-size",
|
||||
"1, " + std::to_string(threadsPerCTA));
|
||||
func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
|
||||
func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
|
||||
} else {
|
||||
=======
|
||||
switch (target) {
|
||||
case Target::NVVM: {
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
llvm::Metadata *mdArgs[] = {
|
||||
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
|
||||
llvm::ValueAsMetadata::get(
|
||||
@@ -113,7 +100,10 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
} break;
|
||||
case Target::ROCDL: {
|
||||
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
|
||||
func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
|
||||
func->addFnAttr("amdgpu-flat-work-group-size",
|
||||
"1, " + std::to_string(threadsPerCTA));
|
||||
func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
|
||||
func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
@@ -341,11 +331,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
for (auto &func : llvmModule->functions()) {
|
||||
auto it = nvvmMetadata.find(func.getName());
|
||||
if (it != nvvmMetadata.end())
|
||||
<<<<<<< HEAD
|
||||
amendLLVMFunc(&func, it->second, isROCM, threadsPerCTA);
|
||||
=======
|
||||
amendLLVMFunc(&func, it->second, target);
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
amendLLVMFunc(&func, it->second, target, threadsPerCTA);
|
||||
}
|
||||
|
||||
return llvmModule;
|
||||
@@ -379,7 +365,9 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
pm.addPass(mlir::createConvertIndexToLLVMPass());
|
||||
pm.addPass(
|
||||
createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, target}));
|
||||
#ifndef USE_ROCM
|
||||
pm.addPass(createConvertNVGPUToLLVMPass());
|
||||
#endif
|
||||
pm.addPass(mlir::createArithToLLVMConversionPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
// Simplify the IR
|
||||
|
||||
@@ -36,11 +36,8 @@
|
||||
#include "triton/Target/HSACO/HSACOTranslation.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include "triton/Target/PTX/PTXTranslation.h"
|
||||
<<<<<<< HEAD
|
||||
#include "triton/Target/HSACO/HSACOTranslation.h"
|
||||
=======
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include "triton/Tools/Sys/GetPlatform.hpp"
|
||||
|
||||
|
||||
@@ -1005,7 +1005,7 @@ def deserialize_fp8(np_data, in_dtype):
|
||||
# return np_data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4, tl.float8e5])
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
|
||||
def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
|
||||
"""
|
||||
@@ -1056,9 +1056,9 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
|
||||
[32, 32, 128],
|
||||
[128, 128, 64],
|
||||
[64, 128, 128]]
|
||||
for ab_type in [[tl.float8e4, tl.float16],
|
||||
for ab_type in [[tl.float8e4nv, tl.float16],
|
||||
[tl.float8e5, tl.float16],
|
||||
[tl.float16, tl.float8e4],
|
||||
[tl.float16, tl.float8e4nv],
|
||||
[tl.float16, tl.float8e5]]
|
||||
for out_dtype in [torch.float16, torch.float32]
|
||||
])
|
||||
|
||||
@@ -53,14 +53,8 @@ def compile_fn(config, device_type, cc):
|
||||
|
||||
|
||||
def test_compile_in_subproc() -> None:
|
||||
<<<<<<< HEAD
|
||||
cc, device_type = get_device_type()
|
||||
config = instance_descriptor(tuple(range(4)), ())
|
||||
=======
|
||||
major, minor = torch.cuda.get_device_capability(0)
|
||||
cc = major * 10 + minor
|
||||
config = instance_descriptor(tuple(range(4)), (), (), ())
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
multiprocessing.set_start_method('fork')
|
||||
proc = multiprocessing.Process(
|
||||
@@ -92,14 +86,8 @@ def compile_fn_dot(config, device_type, cc):
|
||||
|
||||
def test_compile_in_forked_subproc() -> None:
|
||||
reset_tmp_dir()
|
||||
<<<<<<< HEAD
|
||||
cc, device_type = get_device_type()
|
||||
config = instance_descriptor(tuple(range(1)), ())
|
||||
=======
|
||||
major, minor = torch.cuda.get_device_capability(0)
|
||||
cc = major * 10 + minor
|
||||
config = instance_descriptor(tuple(range(1)), (), (), ())
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
assert multiprocessing.get_start_method() == 'fork'
|
||||
proc = multiprocessing.Process(
|
||||
|
||||
@@ -73,17 +73,13 @@ def optimize_ttir(mod, arch):
|
||||
return mod
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def ttir_to_ttgir(mod, num_warps, warpsize):
|
||||
def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch):
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize)
|
||||
=======
|
||||
def ttir_to_ttgir(mod, num_warps, num_ctas, arch):
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, arch)
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
if is_hip():
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, 0)
|
||||
else:
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, arch)
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
@@ -109,26 +105,16 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
|
||||
if optimize_epilogue:
|
||||
pm.add_tritongpu_optimize_epilogue_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
<<<<<<< HEAD
|
||||
if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
|
||||
pm.add_tritongpu_stream_pipeline_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
else:
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
pm.add_tritongpu_prefetch_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
pm.add_tritongpu_decompose_conversions_pass()
|
||||
if num_stages != 0:
|
||||
pm.add_tritongpu_reorder_instructions_pass()
|
||||
=======
|
||||
ws_enabled = False
|
||||
# `num_warps` does not mean the total number of warps of a CTA when
|
||||
# warp specialization is enabled.
|
||||
# it's the responsibility of the compiler to figure out the exact
|
||||
# `num_warps` to use.
|
||||
# TODO: support the case where `num_warps` from user is not 4.
|
||||
if arch // 10 >= 9 and enable_warp_specialization and num_warps == 4:
|
||||
if _is_cuda(arch) and arch // 10 >= 9 and enable_warp_specialization and num_warps == 4:
|
||||
pm.add_tritongpu_ws_feasibility_checking_pass(arch)
|
||||
pm.run(mod)
|
||||
ws_enabled = ir.is_ws_supported(mod)
|
||||
@@ -142,20 +128,27 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
|
||||
pm.add_tritongpu_wsmaterialization_pass(arch)
|
||||
pm.add_cse_pass()
|
||||
else:
|
||||
pm.add_tritongpu_pipeline_pass(
|
||||
num_stages, num_warps, num_ctas, arch)
|
||||
pm.add_tritongpu_materialize_load_store_pass(num_warps, arch)
|
||||
if arch // 10 <= 8:
|
||||
if is_hip():
|
||||
pm.add_tritongpu_pipeline_pass(
|
||||
num_stages, num_warps, num_ctas, 0)
|
||||
else:
|
||||
pm.add_tritongpu_pipeline_pass(
|
||||
num_stages, num_warps, num_ctas, arch)
|
||||
if is_hip():
|
||||
pm.add_tritongpu_materialize_load_store_pass(num_warps, 0)
|
||||
else:
|
||||
pm.add_tritongpu_materialize_load_store_pass(num_warps, arch)
|
||||
if _is_cuda(arch) and arch // 10 <= 8:
|
||||
pm.add_tritongpu_prefetch_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
pm.add_tritongpu_decompose_conversions_pass()
|
||||
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
|
||||
pm.add_tritongpu_reorder_instructions_pass()
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
if num_stages != 0:
|
||||
pm.add_tritongpu_reorder_instructions_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_symbol_dce_pass()
|
||||
if arch // 10 >= 9:
|
||||
if _is_cuda(arch) and arch // 10 >= 9:
|
||||
pm.add_tritongpu_fence_insertion_pass()
|
||||
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
|
||||
pm.run(mod)
|
||||
@@ -475,10 +468,6 @@ def compile(fn, **kwargs):
|
||||
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
|
||||
context = ir.context()
|
||||
constants = kwargs.get("constants", dict())
|
||||
<<<<<<< HEAD
|
||||
num_warps = kwargs.get("num_warps", 4)
|
||||
num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else (1 if is_hip else 2))
|
||||
=======
|
||||
num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type))
|
||||
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
|
||||
num_ctas = kwargs.get("num_ctas", 1)
|
||||
@@ -487,7 +476,6 @@ def compile(fn, **kwargs):
|
||||
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
|
||||
# TODO[shuhaoj]: persistent can be decoupled with warp specialization
|
||||
enable_persistent = kwargs.get("enable_persistent", enable_warp_specialization)
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
extern_libs = kwargs.get("extern_libs", dict())
|
||||
if extern_libs is None:
|
||||
extern_libs = dict()
|
||||
@@ -509,11 +497,7 @@ def compile(fn, **kwargs):
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
<<<<<<< HEAD
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size), num_stages, arch))
|
||||
=======
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
|
||||
if is_cuda:
|
||||
@@ -584,11 +568,8 @@ def compile(fn, **kwargs):
|
||||
InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
|
||||
else:
|
||||
metadata = {"num_warps": num_warps,
|
||||
<<<<<<< HEAD
|
||||
"warp_size": warp_size,
|
||||
=======
|
||||
"num_ctas": num_ctas,
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
"num_stages": num_stages,
|
||||
"enable_warp_specialization": enable_warp_specialization,
|
||||
"enable_persistent": enable_persistent,
|
||||
@@ -705,11 +686,8 @@ class CompiledKernel:
|
||||
# initialize metadata
|
||||
self.shared = metadata["shared"]
|
||||
self.num_warps = metadata["num_warps"]
|
||||
<<<<<<< HEAD
|
||||
self.warp_size = metadata["warp_size"]
|
||||
=======
|
||||
self.num_ctas = metadata["num_ctas"]
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
self.num_stages = metadata["num_stages"]
|
||||
self.clusterDims = metadata["clusterDims"]
|
||||
if "tensormaps_info" in metadata:
|
||||
|
||||
@@ -104,156 +104,145 @@ def generate_launcher(constants, signature, ids):
|
||||
|
||||
# generate glue code
|
||||
if is_hip():
|
||||
src = f"""
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <Python.h>
|
||||
#include <stdio.h>
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
|
||||
src = f"""
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <Python.h>
|
||||
#include <stdbool.h>
|
||||
#include <dlfcn.h>
|
||||
|
||||
static inline void gpuAssert(hipError_t code, const char *file, int line)
|
||||
{{
|
||||
if (code != HIP_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [HIP]: ";
|
||||
const char* str = hipGetErrorString(code);
|
||||
char err[1024] = {{0}};
|
||||
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
static inline void gpuAssert(hipError_t code, const char *file, int line)
|
||||
{{
|
||||
if (code != HIP_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [HIP]: ";
|
||||
const char* str = hipGetErrorString(code);
|
||||
char err[1024] = {{0}};
|
||||
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
}}
|
||||
|
||||
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
||||
// printf("_launch hip kernel\\n");
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
|
||||
if (gridX*gridY*gridZ > 0) {{
|
||||
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0));
|
||||
}}
|
||||
}}
|
||||
|
||||
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
typedef struct _DevicePtrInfo {{
|
||||
hipDeviceptr_t dev_ptr;
|
||||
bool valid;
|
||||
}} DevicePtrInfo;
|
||||
|
||||
static int getWarpSize(hipStream_t stream)
|
||||
{{
|
||||
int device_id = hipGetStreamDeviceId(stream);
|
||||
gpuAssert(device_id >= 0 ? hipSuccess : hipErrorInvalidDevice, __FILE__, __LINE__);
|
||||
hipDeviceProp_t prop;
|
||||
HIP_CHECK(hipGetDeviceProperties(&prop, device_id));
|
||||
return prop.warpSize;
|
||||
}}
|
||||
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||
if (gridX*gridY*gridZ > 0) {{
|
||||
int warp_size = getWarpSize(stream);
|
||||
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, num_warps * warp_size, 1, 1, shared_memory, stream, params, 0));
|
||||
}}
|
||||
}}
|
||||
|
||||
typedef struct _DevicePtrInfo {{
|
||||
hipDeviceptr_t dev_ptr;
|
||||
bool valid;
|
||||
}} DevicePtrInfo;
|
||||
|
||||
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
||||
DevicePtrInfo ptr_info;
|
||||
ptr_info.dev_ptr = 0;
|
||||
ptr_info.valid = true;
|
||||
|
||||
if (PyLong_Check(obj)) {{
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
if (obj == Py_None) {{
|
||||
// valid nullptr
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||
|
||||
if (ptr) {{
|
||||
PyObject *empty_tuple = PyTuple_New(0);
|
||||
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||
Py_DECREF(empty_tuple);
|
||||
Py_DECREF(ptr);
|
||||
|
||||
if (!PyLong_Check(ret)) {{
|
||||
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||
ptr_info.valid = false;
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
||||
|
||||
if (!ptr_info.dev_ptr)
|
||||
return ptr_info;
|
||||
|
||||
uint64_t dev_ptr;
|
||||
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
||||
if (status == hipErrorInvalidValue) {{
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
||||
ptr_info.valid = false;
|
||||
}}
|
||||
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
||||
DevicePtrInfo ptr_info;
|
||||
ptr_info.dev_ptr = 0;
|
||||
ptr_info.valid = true;
|
||||
if (PyLong_Check(obj)) {{
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
|
||||
return ptr_info;
|
||||
}}
|
||||
if (obj == Py_None) {{
|
||||
// valid nullptr
|
||||
return ptr_info;
|
||||
}}
|
||||
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||
if(ptr){{
|
||||
PyObject *empty_tuple = PyTuple_New(0);
|
||||
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||
Py_DECREF(empty_tuple);
|
||||
Py_DECREF(ptr);
|
||||
if (!PyLong_Check(ret)) {{
|
||||
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||
ptr_info.valid = false;
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
int num_warps;
|
||||
int shared_memory;
|
||||
PyObject *launch_enter_hook = NULL;
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
// return None
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
||||
if(!ptr_info.dev_ptr)
|
||||
return ptr_info;
|
||||
uint64_t dev_ptr;
|
||||
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
||||
if (status == hipErrorInvalidValue) {{
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
||||
ptr_info.valid = false;
|
||||
}}
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
||||
Py_DECREF(ret);
|
||||
return ptr_info;
|
||||
}}
|
||||
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {{
|
||||
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||
{{NULL, NULL, 0, NULL}} // sentinel
|
||||
}};
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
// printf("launch\\n");
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
int num_warps;
|
||||
int num_ctas;
|
||||
int clusterDimX;
|
||||
int clusterDimY;
|
||||
int clusterDimZ;
|
||||
int shared_memory;
|
||||
PyObject *launch_enter_hook = NULL;
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
static struct PyModuleDef ModuleDef = {{
|
||||
PyModuleDef_HEAD_INIT,
|
||||
\"__triton_launcher\",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
}};
|
||||
if (launch_enter_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_enter_hook, args);
|
||||
}}
|
||||
|
||||
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if(m == NULL) {{
|
||||
return NULL;
|
||||
}}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
|
||||
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
}}
|
||||
|
||||
if(PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
// return None
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {{
|
||||
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||
{{NULL, NULL, 0, NULL}} // sentinel
|
||||
}};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {{
|
||||
PyModuleDef_HEAD_INIT,
|
||||
\"__triton_launcher\",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
}};
|
||||
|
||||
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if(m == NULL) {{
|
||||
return NULL;
|
||||
}}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
else:
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
|
||||
@@ -279,12 +268,6 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
|
||||
|
||||
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
|
||||
<<<<<<< HEAD
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||
if(gridX*gridY*gridZ > 0){{
|
||||
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, num_warps * 32, 1, 1, shared_memory, stream, params, 0));
|
||||
=======
|
||||
typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);
|
||||
|
||||
static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
|
||||
@@ -336,7 +319,6 @@ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas
|
||||
}}
|
||||
CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
|
||||
}}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
}}
|
||||
}}
|
||||
|
||||
|
||||
@@ -1324,13 +1324,8 @@ def dot(lhs: tl.tensor,
|
||||
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
|
||||
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
<<<<<<< HEAD
|
||||
assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or (lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()), f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
|
||||
=======
|
||||
|
||||
assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.arch)
|
||||
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
|
||||
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
|
||||
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
|
||||
|
||||
@@ -363,17 +363,10 @@ class JITFunction(KernelInterface[T]):
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
|
||||
src = f"""
|
||||
<<<<<<< HEAD
|
||||
|
||||
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel
|
||||
sig_key = {sig_keys},
|
||||
=======
|
||||
import triton
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
|
||||
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
|
||||
assert num_ctas > 0
|
||||
|
||||
@@ -90,11 +90,7 @@ def _fwd_kernel(
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
|
||||
if IS_CAUSAL:
|
||||
qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
<<<<<<< HEAD
|
||||
qk += tl.dot(q, k)
|
||||
=======
|
||||
qk += tl.dot(q, k, out_dtype=tl.float16)
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
# -- compute scaling constant ---
|
||||
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
|
||||
alpha = tl.math.exp2(m_i - m_i_new)
|
||||
@@ -168,17 +164,11 @@ def _bwd_kernel(
|
||||
V += off_z * stride_vz + off_h * stride_vh
|
||||
DO += off_z * stride_qz + off_h * stride_qh
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
<<<<<<< HEAD
|
||||
DK += off_z * stride_qz + off_h * stride_qh
|
||||
DV += off_z * stride_qz + off_h * stride_qh
|
||||
# See fwd pass above for explanation.
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
for start_n in range(0, num_block):
|
||||
=======
|
||||
DK += off_z * stride_kz + off_h * stride_kh
|
||||
DV += off_z * stride_vz + off_h * stride_vh
|
||||
# See fwd pass above for explanation.
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
for start_n in range(0, num_block_kv):
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
if CAUSAL:
|
||||
lo = tl.math.max(start_n * BLOCK_M - P_SEQ, 0)
|
||||
else:
|
||||
@@ -466,7 +456,6 @@ class _attention(torch.autograd.Function):
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
BLOCK_M = 128
|
||||
<<<<<<< HEAD
|
||||
if torch.version.hip is None:
|
||||
BLOCK_N = 64 if Lk <= 64 else 32
|
||||
num_stages = 4 if Lk <= 64 else 3
|
||||
@@ -480,14 +469,6 @@ class _attention(torch.autograd.Function):
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
|
||||
|
||||
=======
|
||||
BLOCK_N = 64 if Lk <= 64 else 32
|
||||
num_stages = 4 if Lk <= 64 else 3
|
||||
num_warps = 4
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L,
|
||||
@@ -507,10 +488,7 @@ class _attention(torch.autograd.Function):
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = Lk
|
||||
ctx.causal = causal
|
||||
<<<<<<< HEAD
|
||||
ctx.split_kernel = split_kernel
|
||||
=======
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
ctx.P_SEQ = P_SEQ
|
||||
return o
|
||||
|
||||
@@ -538,28 +516,8 @@ class _attention(torch.autograd.Function):
|
||||
block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do,
|
||||
<<<<<<< HEAD
|
||||
do_scaled, delta,
|
||||
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
=======
|
||||
delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do,
|
||||
dq, dk, dv,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2], ctx.P_SEQ,
|
||||
ctx.grid[0], triton.cdiv(k.shape[2], BLOCK),
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
CAUSAL=ctx.causal,
|
||||
num_stages=1,
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
)
|
||||
if not ctx.split_kernel:
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
@@ -611,7 +569,6 @@ class _attention(torch.autograd.Function):
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
|
||||
[(4, 48, 1024, 64, 128),
|
||||
(4, 48, 2048, 64, 128),
|
||||
@@ -621,16 +578,10 @@ attention = _attention.apply
|
||||
])
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
def test_op_fwd(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
|
||||
=======
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ', [(6, 9, 1024, 64, 128)])
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
<<<<<<< HEAD
|
||||
sm_scale = q.shape[-1] ** (-0.5)
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
@@ -663,12 +614,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
|
||||
=======
|
||||
sm_scale = 0.5
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
|
||||
if causal:
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
set -o xtrace
|
||||
|
||||
alias drun='sudo docker run -it --rm --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined'
|
||||
DRUN='sudo docker run -it --rm --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined'
|
||||
|
||||
# DEVICES="--gpus all"
|
||||
DEVICES="--device=/dev/kfd --device=/dev/dri"
|
||||
@@ -21,7 +21,7 @@ CONTAINER_NAME=triton
|
||||
# start new container
|
||||
docker stop $CONTAINER_NAME
|
||||
docker rm $CONTAINER_NAME
|
||||
CONTAINER_ID=$(drun -d -w $WORK_DIR --name $CONTAINER_NAME $MEMORY $VOLUMES $DEVICES $IMAGE_NAME)
|
||||
CONTAINER_ID=$($DRUN -d -w $WORK_DIR --name $CONTAINER_NAME $MEMORY $VOLUMES $DEVICES $IMAGE_NAME)
|
||||
echo "CONTAINER_ID: $CONTAINER_ID"
|
||||
# docker cp . $CONTAINER_ID:$WORK_DIR
|
||||
# docker exec $CONTAINER_ID bash -c "bash scripts/amd/run.sh"
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
|
||||
|
||||
tt.func @ops() {
|
||||
<<<<<<< HEAD
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}}
|
||||
=======
|
||||
// CHECK: module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
|
||||
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
|
||||
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
|
||||
@@ -37,17 +33,10 @@ tt.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// Test if the total number of threadsPerWarp is 64
|
||||
// Test if the total number of warps is 2
|
||||
<<<<<<< HEAD
|
||||
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}}
|
||||
=======
|
||||
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
// CHECK: module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
%c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
|
||||
%c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
|
||||
%c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
<<<<<<< HEAD
|
||||
// RUN: not triton-opt %s -split-input-file --convert-triton-gpu-to-llvm --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s
|
||||
=======
|
||||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" | FileCheck %s
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
// RUN: not triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=rocdl" --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s
|
||||
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
|
||||
// Here the 128 comes from the 4 in module attribute multiples 32
|
||||
<<<<<<< HEAD
|
||||
// PTX: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]} {{.*}}
|
||||
=======
|
||||
// CHECK: nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
// PTX: nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]
|
||||
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
// CHECK: llvm.return
|
||||
tt.return
|
||||
@@ -711,11 +703,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_program_id
|
||||
tt.func @basic_program_id() {
|
||||
<<<<<<< HEAD
|
||||
// PTX: nvvm.read.ptx.sreg.ctaid.x : i32
|
||||
=======
|
||||
// CHECK: llvm.inline_asm asm_dialect = att operand_attrs = [] "mov.u32 $0, %ctaid.x;", "=r" : () -> i32
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
%0 = tt.get_program_id x : i32
|
||||
tt.return
|
||||
}
|
||||
@@ -788,15 +776,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
|
||||
// -----
|
||||
|
||||
<<<<<<< HEAD
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
// PTX-LABEL: basic_async_wait
|
||||
// This test is disabled for GCN target, because it is PTX specific
|
||||
// GCN-NOT: basic_async_wait
|
||||
=======
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_async_wait
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
tt.func @basic_async_wait() {
|
||||
// PTX: cp.async.wait_group 0x4
|
||||
triton_gpu.async_wait {num = 4: i32}
|
||||
@@ -947,14 +930,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
%tensor = triton_gpu.alloc_tensor : tensor<2x16x32xf32, #A>
|
||||
%index = arith.constant 1 : i32
|
||||
|
||||
<<<<<<< HEAD
|
||||
// This test is PTX specific, GCN targets decompose async operations into oridinary load/stores.
|
||||
// TODO: Fix AMD compilation.
|
||||
// last operation (commit_group) is still emitted by AMD pipeline,
|
||||
// It is left to catch changes in AMD compilation pipeline.
|
||||
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||
// PTX: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||
// PTX: llvm.inline_asm
|
||||
@@ -978,18 +960,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
// GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32>
|
||||
// GCN-COUNT-4: llvm.store {{.*}} : !llvm.ptr<vector<1xf32>, 3>
|
||||
// GCN: llvm.inline_asm {{.*}}cp.async.commit_group
|
||||
=======
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: cp.async.commit_group
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr<f32>, #AL> -> tensor<2x16x32xf32, #A>
|
||||
triton_gpu.async_commit_group
|
||||
tt.return
|
||||
@@ -1224,7 +1194,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
|
||||
#mma0 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
|
||||
<<<<<<< HEAD
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// PTX-LABEL: convert_dot
|
||||
// This test is not relevant to GCN target, because it is PTX specific
|
||||
@@ -1232,20 +1201,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
||||
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: ldmatrix.sync.aligned.m8n8.x4
|
||||
// PTX: ldmatrix.sync.aligned.m8n8.x4
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: ldmatrix.sync.aligned.m8n8.x4
|
||||
=======
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_dot
|
||||
tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
|
||||
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
||||
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK: ldmatrix.sync.aligned.m8n8.x4
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
%AA_DOT = triton_gpu.convert_layout %AA : (tensor<16x16xf16, #shared0>) -> tensor<16x16xf16, #dot_operand_a>
|
||||
%BB_DOT = triton_gpu.convert_layout %BB : (tensor<16x16xf16, #shared0>) -> tensor<16x16xf16, #dot_operand_b>
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
|
||||
@@ -1271,7 +1229,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
|
||||
|
||||
// -----
|
||||
|
||||
<<<<<<< HEAD
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0]}>
|
||||
#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false}>
|
||||
@@ -1312,19 +1269,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// PTX: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// PTX-LABEL: convert_layout_mmav2_block
|
||||
// This test is not relevant to GCN target, because it is PTX specific
|
||||
=======
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_mmav2_block
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
// PTX: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// PTX-LABEL: convert_layout_mmav2_block
|
||||
// This test is not relevant to GCN target, because it is PTX specific
|
||||
tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
||||
// PTX: llvm.store
|
||||
// PTX-SAME: !llvm.ptr<vector<2xf32>, 3>
|
||||
@@ -1340,20 +1290,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
|
||||
|
||||
// -----
|
||||
|
||||
<<<<<<< HEAD
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// PTX: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// PTX-LABEL: convert_layout_mmav1_block
|
||||
// This test is not relevant to GCN target, because it is PTX specific
|
||||
=======
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_mmav1_block
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
// PTX: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// PTX-LABEL: convert_layout_mmav1_block
|
||||
// This test is not relevant to GCN target, because it is PTX specific
|
||||
tt.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
|
||||
// PTX: llvm.store
|
||||
// PTX-SAME: !llvm.ptr<vector<2xf32>, 3>
|
||||
@@ -1438,14 +1380,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
|
||||
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
|
||||
<<<<<<< HEAD
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
// PTX-LABEL: matmul_kernel_dot_operand_layout
|
||||
// This test is disabled for GCN target, because it is PTX specific
|
||||
// This test is not relevant to GCN target, because it is PTX specific
|
||||
=======
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
|
||||
@@ -1465,7 +1403,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
|
||||
// -----
|
||||
|
||||
<<<<<<< HEAD
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
|
||||
@@ -1494,16 +1431,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0]}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// PTX-LABEL: matmul884_kernel_dot_operand_layout
|
||||
// This test is not relevant to GCN target, because it is PTX specific
|
||||
=======
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
@@ -1511,7 +1438,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
// PTX-LABEL: matmul884_kernel_dot_operand_layout
|
||||
// This test is not relevant to GCN target, because it is PTX specific
|
||||
tt.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma>
|
||||
@@ -1557,14 +1485,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}>
|
||||
<<<<<<< HEAD
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
// PTX-LABEL: matmul_tf32dot
|
||||
// This test is not relevant to GCN target, because it is PTX specific
|
||||
=======
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: matmul_tf32dot
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
|
||||
@@ -1617,20 +1540,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32_scalar
|
||||
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
|
||||
<<<<<<< HEAD
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
|
||||
// PTX: llvm.icmp "eq"
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.f32
|
||||
// PTX-SAME: @$3 atom.global.gpu.relaxed.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (!tt.ptr<f32>, f32, i1) -> f32
|
||||
=======
|
||||
// CHECK: llvm.icmp "eq"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
@@ -1699,18 +1614,9 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
%blockidx = tt.get_program_id x: i32
|
||||
%blockidy = tt.get_program_id y: i32
|
||||
%blockidz = tt.get_program_id z : i32
|
||||
<<<<<<< HEAD
|
||||
// PTX: nvvm.read.ptx.sreg.ctaid.x
|
||||
// PTX: nvvm.read.ptx.sreg.ctaid.y
|
||||
// PTX: nvvm.read.ptx.sreg.ctaid.z
|
||||
// GCN: rocdl.workgroup.id.x
|
||||
// GCN: rocdl.workgroup.id.y
|
||||
// GCN: rocdl.workgroup.id.z
|
||||
=======
|
||||
// CHECK: clusterid.x
|
||||
// CHECK: clusterid.y
|
||||
// CHECK: clusterid.z
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
%v0 = arith.addi %blockidx, %blockidy : i32
|
||||
%v1 = arith.addi %v0, %blockidz : i32
|
||||
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
|
||||
@@ -1747,15 +1653,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
<<<<<<< HEAD
|
||||
// PTX: nvvm.read.ptx.sreg.nctaid.x
|
||||
// PTX: nvvm.read.ptx.sreg.nctaid.y
|
||||
// PTX: nvvm.read.ptx.sreg.nctaid.z
|
||||
// GCN: rocdl.grid.dim.x
|
||||
// GCN: rocdl.grid.dim.y
|
||||
// GCN: rocdl.grid.dim.z
|
||||
=======
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
%blockdimx = tt.get_num_programs {axis=0:i32} : i32
|
||||
%blockdimy = tt.get_num_programs {axis=1:i32} : i32
|
||||
%blockdimz = tt.get_num_programs {axis=2:i32} : i32
|
||||
@@ -1887,14 +1784,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
}
|
||||
|
||||
// -----
|
||||
<<<<<<< HEAD
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
=======
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
// CHECK-LABEL: test_s8_to_bf16_conversion
|
||||
tt.func @test_s8_to_bf16_conversion(%in: tensor<32xi8, #blocked>) {
|
||||
// We can't vectorize if we only process
|
||||
@@ -1907,12 +1799,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
|
||||
}
|
||||
|
||||
// -----
|
||||
<<<<<<< HEAD
|
||||
|
||||
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1]}>
|
||||
=======
|
||||
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
#dot = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: test_s8_to_bf16_vectorized_conversion
|
||||
@@ -1935,25 +1823,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
|
||||
|
||||
// -----
|
||||
|
||||
<<<<<<< HEAD
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f16
|
||||
tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: f16 {tt.difisibility = 16 : i32}) {
|
||||
%c1_i1 = arith.constant 1 : i1
|
||||
%1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x32xi32, #blocked>
|
||||
%3 = tt.broadcast %2 : (tensor<1x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<32x32x!tt.ptr<f16>, #blocked>
|
||||
%5 = tt.addptr %4, %3 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
|
||||
%6 = tt.splat %arg1 : (f16) -> tensor<32x32xf16, #blocked>
|
||||
%7 = tt.splat %c1_i1 : (i1) -> tensor<32x32xi1, #blocked>
|
||||
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.noftz.f16x2
|
||||
// GCN-COUNT-8: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f16, 1>, f16
|
||||
%8 = "tt.atomic_rmw"(%5, %6, %7) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
|
||||
=======
|
||||
// CHECK-LABEL: sum_reduction
|
||||
// CHECK: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32
|
||||
// CHECK: nvvm.redux.sync add %{{.*}}, %[[M]]
|
||||
@@ -1986,14 +1855,12 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
|
||||
%13 = tt.splat %arg1 : (!tt.ptr<i32, 1>) -> tensor<1x!tt.ptr<i32, 1>, #blocked1>
|
||||
%14 = tt.addptr %13, %0 : tensor<1x!tt.ptr<i32, 1>, #blocked1>, tensor<1xi32, #blocked1>
|
||||
tt.store %14, %12 {cache = 1 : i32, evict = 1 : i32} : tensor<1xi32, #blocked1>
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#slice = #triton_gpu.slice<{dim = 1, parent = #blocked}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} {
|
||||
@@ -2047,4 +1914,26 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f16
|
||||
tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: f16 {tt.difisibility = 16 : i32}) {
|
||||
%c1_i1 = arith.constant 1 : i1
|
||||
%1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x32xi32, #blocked>
|
||||
%3 = tt.broadcast %2 : (tensor<1x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<32x32x!tt.ptr<f16>, #blocked>
|
||||
%5 = tt.addptr %4, %3 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
|
||||
%6 = tt.splat %arg1 : (f16) -> tensor<32x32xf16, #blocked>
|
||||
%7 = tt.splat %c1_i1 : (i1) -> tensor<32x32xi1, #blocked>
|
||||
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.noftz.f16x2
|
||||
// GCN-COUNT-8: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f16, 1>, f16
|
||||
%8 = "tt.atomic_rmw"(%5, %6, %7) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,16 +8,9 @@
|
||||
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
<<<<<<< HEAD
|
||||
|
||||
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
|
||||
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
|
||||
=======
|
||||
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32, 1>, [[row_layout]]>
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]>
|
||||
// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]>
|
||||
// CHECK: [[load_val:%.*]] = tt.load [[load_ptr]], [[load_mask]], [[load_other]] {{.*}} : tensor<64x64xf32, [[row_layout]]>
|
||||
|
||||
@@ -3,16 +3,15 @@ add_triton_ut(
|
||||
SRCS PTXAsmFormatTest.cpp
|
||||
LIBS TritonGPUToLLVM
|
||||
)
|
||||
<<<<<<< HEAD
|
||||
|
||||
add_triton_ut(
|
||||
NAME TestGcnAsmFormat
|
||||
SRCS GcnAsmFormatTest.cpp
|
||||
LIBS TritonGPUToLLVM
|
||||
=======
|
||||
)
|
||||
|
||||
add_triton_ut(
|
||||
NAME TestEmitIndices
|
||||
SRCS EmitIndicesTest.cpp DumpLayout.cpp
|
||||
LIBS TritonGPUIR TritonNvidiaGPUIR ${dialect_libs} ${conversion_libs}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user