Initial commit to resolve merge conflicts

rename tl.float8e4 to tl.float8e4nv to align with upstream

ROCM IFU: Fix python arch issues

ROCM IFU: Fix kernel launcher

ROCM IFU: Fix merge conflicts

fix debug build

Set correct threadsPerCTA
This commit is contained in:
Jason Furmanek
2023-09-12 20:43:59 +00:00
parent 74fd8e9754
commit e5d7bb4fae
36 changed files with 414 additions and 1005 deletions

9
.gitignore vendored
View File

@@ -25,14 +25,5 @@ venv.bak/
.idea
cmake-build-*
<<<<<<< HEAD
# cache dumps
triton_cache*
log_*
#
python/triton/third_party/cuda/bin/ptxas
=======
# Third-party binaries
ptxas
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272

View File

@@ -200,6 +200,10 @@ include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
set(ROCM_LIBRARIES
${CMAKE_CURRENT_SOURCE_DIR}/lib/rocm/libhsa-runtime64.so
)
# link_directories(${LLVM_LIBRARY_DIR})
add_subdirectory(include)
add_subdirectory(lib)
@@ -239,10 +243,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
MLIRIR
)
set(ROCM_LIBRARIES
${CMAKE_CURRENT_SOURCE_DIR}/lib/rocm/libhsa-runtime64.so
)
if(WIN32)
target_link_libraries(triton PRIVATE ${ROCM_LIBRARIES} ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS}
${TRITON_LIBRARIES}

View File

@@ -124,20 +124,14 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
}
llvm::LLVMContext llvmContext;
<<<<<<< HEAD
mlir::triton::gpu::TMAMetadataTy tmaInfos;
#ifdef USE_ROCM
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
SMArch.getValue(), true /*isRocm*/);
SMArch.getValue(), tmaInfos, Target::ROCDL);
#else
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
SMArch.getValue(), false /*isRocm*/);
SMArch.getValue(), tmaInfos, Target::Default);
#endif
=======
mlir::triton::gpu::TMAMetadataTy tmaInfos;
auto llvmir = translateTritonGPUToLLVMIR(
&llvmContext, *module, SMArch.getValue(), tmaInfos, Target::Default);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}

View File

@@ -21,18 +21,8 @@ enum Target { NVVM, ROCDL, Default = NVVM };
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
<<<<<<< HEAD
#ifdef USE_ROCM
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
bool isROCM = true);
#else
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
bool isROCM = false);
#endif
=======
createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
} // namespace triton
} // namespace mlir

View File

@@ -147,11 +147,11 @@ compared to 1*64 when the hasLeadingOffset is false.
int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
int maxPhase = SIMDWidth / perPhase;
return $_get(context, vecSize, perPhase, maxPhase, order);
return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
} else {
// Do not swizzle in case k dimension is not innermost.
// In this case accesses will go in different banks even without swizzling.
return $_get(context, 1, 1, 1, order);
return get(context, 1, 1, 1, order, CTALayout);
}
}
#endif
@@ -185,20 +185,12 @@ compared to 1*64 when the hasLeadingOffset is false.
// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
<<<<<<< HEAD
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getKWidth());
=======
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
<<<<<<< HEAD
return $_get(context, 1, 1, 1, order);
=======
return get(context, 1, 1, 1, order, CTALayout);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand

View File

@@ -52,7 +52,6 @@ def TritonGPU_Dialect : Dialect {
}
return threadsPerWarp.cast<IntegerAttr>().getInt();
}
<<<<<<< HEAD
static int getSharedSize(ModuleOp mod) {
Attribute sharedAttr = mod->getDiscardableAttr("triton_gpu.shared");
if(!sharedAttr) {
@@ -61,8 +60,6 @@ def TritonGPU_Dialect : Dialect {
return sharedAttr.cast<IntegerAttr>().getInt();
}
=======
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}];
let useDefaultAttributePrinterParser = 1;

View File

@@ -356,9 +356,6 @@ bool supportMMA(triton::DotOp op, int version) {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
<<<<<<< HEAD
=======
if (version == 3) {
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
return false;
@@ -374,7 +371,6 @@ bool supportMMA(triton::DotOp op, int version) {
return false;
}
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
if (aElemTy.isF32() && bElemTy.isF32()) {
return (op.getAllowTF32() && version == 2) || version == 3;
}
@@ -446,7 +442,6 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
!srcTy.getElementType().isF32();
}
<<<<<<< HEAD
#ifdef USE_ROCM
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto srcLayout = srcTy.getEncoding();
@@ -464,7 +459,7 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}
#endif
=======
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto src = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
auto dst = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
@@ -475,7 +470,6 @@ bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
srcElemsPerThread == dstElemsPerThread;
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.

View File

@@ -1,3 +1,6 @@
add_library(rocm_libraries SHARED IMPORTED )
set_target_properties(rocm_libraries PROPERTIES IMPORTED_LOCATION ${ROCM_LIBRARIES})
add_mlir_conversion_library(TritonGPUToLLVM
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp
@@ -62,4 +65,5 @@ add_mlir_conversion_library(TritonGPUToLLVM
TritonGPUTransforms
TritonNvidiaGPUTransforms
NVGPUIR
rocm_libraries
)

View File

@@ -253,7 +253,7 @@ private:
}
#ifdef USE_ROCM
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type);
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type, false);
SmallVector<SmallVector<unsigned>> offsets;
assert(rank == 2);
SmallVector<Value> multiDimOffset(rank);

View File

@@ -573,11 +573,7 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
<<<<<<< HEAD
auto numRep = encoding.getMMAv2Rep(tensorTy.getShape(), bitwidth);
=======
auto numRep = encoding.getMMAv2Rep(shapePerCTA, bitwidth);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
int kWidth = encoding.getKWidth();
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();

View File

@@ -25,13 +25,11 @@ LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
<<<<<<< HEAD
#ifdef USE_ROCM
LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
#endif
=======
LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Value thread);
@@ -41,7 +39,6 @@ LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Value thread);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -72,13 +69,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
return convertMMA1688(op, adaptor, getTypeConverter(), rewriter);
if (mmaLayout.isAmpere())
return convertMMA16816(op, adaptor, getTypeConverter(), rewriter);
<<<<<<< HEAD
=======
if (mmaLayout.isHopper())
return convertWGMMA(op, adaptor, getTypeConverter(), rewriter,
getThreadId(rewriter, loc));
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
llvm::report_fatal_error(
"Unsupported MMA kind found when converting DotOp to LLVM.");
}

View File

@@ -10,15 +10,14 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));
Value a0 = bitcast(fp16x2Vec0, i32_ty);
Value a1 = bitcast(fp16x2Vec1, i32_ty);
@@ -58,20 +57,19 @@ const std::string Fp16_to_Fp8E5M2 =
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
a0 = bitcast(a0, i32_ty);
Value a1 = undef(fp8x4VecTy);
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
a1 = bitcast(a1, i32_ty);
auto fp16x2VecTy = vec_ty(f16_ty, 2);
@@ -94,21 +92,20 @@ const std::string Fp8E5M2_to_Fp16 = "{ \n"
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
a0 = bitcast(a0, i32_ty);
Value a1 = undef(fp8x4VecTy);
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
a1 = bitcast(a1, i32_ty);
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
@@ -155,15 +152,14 @@ const std::string Fp8E5M2_to_Bf16 =
#ifdef USE_ROCM
static SmallVector<Value>
Bf16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
auto bf16x2VecTy = vec_ty(i16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[1], i32_val(1));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[2], i32_val(0));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[3], i32_val(1));
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
@@ -276,21 +272,20 @@ const std::string Bf16_to_Fp8E5M2 =
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
a0 = bitcast(a0, i32_ty);
Value a1 = undef(fp8x4VecTy);
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
a1 = bitcast(a1, i32_ty);
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
@@ -325,22 +320,20 @@ const std::string Fp8E4M3B15_to_Fp16 =
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"shl.b32 $1, b1, 7; \n"
"} \n";
<<<<<<< HEAD
#endif
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));
Value fp16x2VecMin = i32_val(0xBF80BF80);
Value fp16x2VecMax = i32_val(0x3F803F80);
@@ -374,29 +367,6 @@ Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
const std::string Fp16_to_Fp8E4M3B15 =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
".reg .b32 min_val, max_val; \n"
"mov.b32 min_val, 0xBF80BF80; \n"
"mov.b32 max_val, 0x3F803F80; \n"
"max.f16x2 $1, $1, min_val; \n"
"min.f16x2 $1, $1, max_val; \n"
"max.f16x2 $2, $2, min_val; \n"
"min.f16x2 $2, $2, max_val; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"}";
#endif
=======
const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
std::string ret;
ret += "{ \n"
@@ -431,7 +401,7 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
"}";
return ret;
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
#endif
/* ----- FP8E4M3B15X4 ------ */
// NOTE: NOT USED RIGHT NOW
@@ -446,8 +416,7 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
return {};
}
#else
@@ -472,8 +441,7 @@ const std::string Fp8E4M3B15x4_to_Fp16 =
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
return {};
}
#else
@@ -488,7 +456,6 @@ const std::string Fp16_to_Fp8E4M3B15x4 =
"}";
#endif
<<<<<<< HEAD
/* ----- FP8E4M3 ------ */
// Note: when handled by software, this format
// does not handle denormals and has
@@ -498,21 +465,20 @@ const std::string Fp16_to_Fp8E4M3B15x4 =
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E4M3_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
a0 = bitcast(a0, i32_ty);
Value a1 = undef(fp8x4VecTy);
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
a1 = bitcast(a1, i32_ty);
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
@@ -559,16 +525,15 @@ const std::string Fp8E4M3_to_Fp16 =
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
@@ -618,21 +583,20 @@ const std::string Fp16_to_Fp8E4M3 =
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E4M3_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
a0 = bitcast(a0, i32_ty);
Value a1 = undef(fp8x4VecTy);
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
a1 = bitcast(a1, i32_ty);
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
@@ -679,15 +643,14 @@ const std::string Fp8E4M3_to_Bf16 =
#ifdef USE_ROCM
static SmallVector<Value>
Bf16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const SmallVector<Value> &v) {
auto bf16x2VecTy = vec_ty(i16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[1], i32_val(1));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[2], i32_val(0));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[3], i32_val(1));
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
@@ -788,7 +751,7 @@ const std::string Bf16_to_Fp8E4M3 =
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
#endif
=======
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
"cvt.rn.f16x2.e4m3x2 $0, $1; \n"
@@ -797,7 +760,6 @@ const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
const std::string Fp16_to_Fp8E4M3Nv = "{ \n"
"cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
"}";
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
/* ----- Packed integer to BF16 ------ */
#ifndef USE_ROCM
@@ -1245,18 +1207,23 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
{{F8E4M3TyID, F16TyID}, Fp8E4M3_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
// F16 -> F8
#ifdef USE_ROCM
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
#else
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
#endif
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
// BF16 -> F8
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
};
int inVecWidthBits = 32;
int outVecWidthBits = 32;
if (srcTy.isFloat8E4M3FNUZ()) {
@@ -1274,15 +1241,9 @@ struct FpToFpOpConversion
<< "\n";
llvm_unreachable("");
}
<<<<<<< HEAD
#ifdef USE_ROCM
return srcMap.lookup(key);
#else
return makeConverterFromPtx(srcMap.lookup(key),
getTypeConverter()->convertType(srcTy),
getTypeConverter()->convertType(dstTy));
#endif
=======
if (computeCapability < 90 &&
(srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
@@ -1294,7 +1255,7 @@ struct FpToFpOpConversion
getTypeConverter()->convertType(srcTy),
getTypeConverter()->convertType(dstTy),
inVecWidthBits, outVecWidthBits);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
#endif
}
SmallVector<Value> createDestOps(triton::FpToFpOp op, OpAdaptor adaptor,
@@ -1712,9 +1673,8 @@ struct FSubOpConversion
#ifdef USE_ROCM
static SmallVector<Value>
S8_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
SmallVector<Value> inValues = {v0, v1, v2, v3};
const SmallVector<Value> &v) {
SmallVector<Value> inValues = {v[0], v[1], v[2], v[3]};
SmallVector<Value> outValues = {};
for (Value inVal : inValues) {
Value i32Val = sext(i32_ty, inVal);
@@ -1751,21 +1711,17 @@ struct SIToFPOpConversion
Type outElemTy = getElementType(op.getOut());
if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) {
#if USE_ROCM
auto outVals = S8_to_Bf16(loc, rewriter, operands[0][0], operands[1][0],
operands[2][0], operands[3][0]);
SmallVector<Value> inVals = {operands[0][0], operands[1][0],
operands[2][0], operands[3][0]};
auto outVals = S8_to_Bf16(loc, rewriter, inVals);
#else
auto cvtFunc = makeConverterFromPtx(
S8_to_Bf16, getTypeConverter()->convertType(inElemTy),
getTypeConverter()->convertType(outElemTy));
<<<<<<< HEAD
auto outVals = cvtFunc(loc, rewriter, operands[0][0], operands[1][0],
operands[2][0], operands[3][0]);
auto cvtFunc = makeConverterFromPtx(
S8_to_Bf16, getTypeConverter()->convertType(inElemTy),
getTypeConverter()->convertType(outElemTy));
SmallVector<Value> inVals = {operands[0][0], operands[1][0],
operands[2][0], operands[3][0]};
auto outVals = cvtFunc(loc, rewriter, inVals);
#endif
=======
SmallVector<Value> inVals = {operands[0][0], operands[1][0],
operands[2][0], operands[3][0]};
auto outVals = cvtFunc(loc, rewriter, inVals);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
assert(outVals.size() == 4);
return outVals;
} else if (outElemTy.isBF16()) {

View File

@@ -425,8 +425,21 @@ private:
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(acc.size());
unsigned shuffleIdx = N;
#ifdef USE_ROCM
auto srcTys = op.getInputTypes();
auto inputTy = srcTys[0].cast<RankedTensorType>();
auto inMfma =
inputTy.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
if (inMfma && inMfma.getIsTransposed()) {
//assert(sizeIntraWarps == 2);
// Adjacent threads in y dimension in transposed MFMA layout are 32
// apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
shuffleIdx = 32;
}
#endif
for (unsigned i = 0; i < acc.size(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], N);
shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx);
}
accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
}
@@ -491,11 +504,11 @@ private:
triton::ReduceOp op = helper.getOperation();
Location loc = op.getLoc();
Value threadId = getThreadId(rewriter, loc);
auto srcLayout = helper.getSrcLayout();
unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout);
Value warpSize = i32_val(wavefront_size);
Value warpId = udiv(threadId, warpSize);
Value laneId = urem(threadId, warpSize);
auto srcLayout = helper.getSrcLayout();
auto srcShape = helper.getSrcShape();
unsigned axis = op.getAxis();
auto smemShapes = helper.getScratchConfigsFast();
@@ -511,6 +524,7 @@ private:
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
#ifdef USE_ROCM
auto srcTys = op.getInputTypes();
auto inputTy = srcTys[0].cast<RankedTensorType>();
auto inMfma =
inputTy.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
@@ -532,35 +546,8 @@ private:
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
<<<<<<< HEAD
SmallVector<Value> acc = it.second;
// Reduce within warps
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(op.getNumOperands());
unsigned shuffleIdx = N;
#ifdef USE_ROCM
if (inMfma && inMfma.getIsTransposed()) {
assert(sizeIntraWarps == 2);
// Adjacent threads in y dimension in transposed MFMA layout are 32
// apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
shuffleIdx = 32;
}
#endif
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx);
}
accumulate(rewriter, *combineOp, acc, shfl, false);
}
if (isWarpSync) {
finalAccs[key] = acc;
continue;
}
=======
SmallVector<Value> &acc = it.second;
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = warpIdAxis;
Value writeOffset =
@@ -620,6 +607,9 @@ private:
icmp_eq(laneIdModSizeInterWarps, zero);
Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
auto srcLayout = helper.getSrcLayout();
unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
#if USE_ROCM
// This barrier is known to be critical for Navi 2x/3x

View File

@@ -446,6 +446,16 @@ struct GetProgramIdOpConversion
LogicalResult
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
#ifdef USE_ROCM
Location loc = op->getLoc();
assert(op.getAxisAsInt() < 3);
Value blockId =
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
return success();
#else
// It is not easy to get the compute capability here, so we use numCTAs to
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
@@ -462,7 +472,11 @@ struct GetProgramIdOpConversion
Value programId = getSRegValue(rewriter, loc, sreg);
rewriter.replaceOp(op, programId);
return success();
#endif
}
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,
mlir::gpu::Dimension::z};
};
struct GetNumProgramsOpConversion
@@ -473,6 +487,26 @@ struct GetNumProgramsOpConversion
LogicalResult
matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
#ifdef USE_ROCM
Location loc = op->getLoc();
assert(op.getAxis() < 3);
// It seems that GridDimOp returns the number of threads (not the number of
// workgroups) in a kernel (a bug in LLVM: https://reviews.llvm.org/D156009),
// so as a workaround here, we divide by the number of threads
// per workgroup to get the number of workgroups in a kernel.
// TODO: once we merge upstream changes that include the LLVM fix, we can
// remove this workaround. The unit test added in this PR can guarantee that.
Value threadsPerGrid =
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
Value threadsPerBlock =
rewriter.create<::mlir::gpu::BlockDimOp>(loc, dims[op.getAxis()]);
Value threadNumPerGrid = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerGrid);
Value threadNumPerBlock = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerBlock);
rewriter.replaceOpWithNewOp<LLVM::UDivOp>(op, threadNumPerGrid, threadNumPerBlock);
return success();
#else
// It is not easy to get the compute capability here, so we use numCTAs to
// decide the semantic of GetNumProgramsOp. If numCTAs = 1, then
// GetNumProgramsOp is converted to "%nctaid", otherwise it is converted to
@@ -486,33 +520,17 @@ struct GetNumProgramsOpConversion
std::string sreg = numCTAs == 1 ? "%nctaid." : "%nclusterid.";
sreg.append(1, 'x' + op.getAxis()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
<<<<<<< HEAD
#ifdef USE_ROCM
// It seems that GridDimOp returns the number of threads (not the number of
// workgroups) in a kernel (a bug in LLVM: https://reviews.llvm.org/D156009),
// so as a workaround here, we divide by the number of threads
// per workgroup to get the number of workgroups in a kernel.
// TODO: once we merge upstream changes that include the LLVM fix, we can
// remove this workaround. The unit test added in this PR can guarantee that.
Value threadsPerGrid =
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
Value threadsPerBlock =
rewriter.create<::mlir::gpu::BlockDimOp>(loc, dims[op.getAxis()]);
Value threadNumPerGrid = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerGrid);
Value threadNumPerBlock = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerBlock);
rewriter.replaceOpWithNewOp<LLVM::UDivOp>(op, threadNumPerGrid, threadNumPerBlock);
#else
Value blockId =
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
#endif // USE_ROCM
=======
Value numPrograms = getSRegValue(rewriter, loc, sreg);
rewriter.replaceOp(op, numPrograms);
return success();
#endif
}
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,
mlir::gpu::Dimension::z};
};
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// TODO[goostavz]: GetThreadIdOp/GetClusterCTAIdOp is a temporary solution
// before async dialect is done. These concepts should appear in ttgpu

View File

@@ -324,12 +324,8 @@ public:
// then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y
// This means that we can use some immediate offsets for shared memory
// operations.
<<<<<<< HEAD
resElemTy = getTypeConverter()->convertType(resElemTy);
auto dstPtrTy = ptr_ty(resElemTy, 3);
=======
auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(resElemTy), 3);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
@@ -555,13 +551,8 @@ public:
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto order = triton::gpu::getOrder(layout);
<<<<<<< HEAD
auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
=======
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
Value warpSize = i32_val(32);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
SmallVector<Value> multiDimWarpId =
@@ -694,19 +685,13 @@ public:
blockedLayout, type);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
<<<<<<< HEAD
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type);
if (mmaLayout.isAmpere())
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type);
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
=======
result = emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter,
mmaLayout, type);
if (mmaLayout.isAmpere() || mmaLayout.isHopper())
result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter,
mmaLayout, type);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(type.getShape());
@@ -764,7 +749,7 @@ public:
const unsigned elemsPerThreadPerGroup = 4;
auto warpSize = getWarpSize(mfmaLayout);
assert(warpSize == 64);
auto shapePerCta = getShapePerCTA(mfmaLayout);
auto shapePerCta = getShapePerCTATile(mfmaLayout);
for (unsigned block = 0; block < numGroups; block++) {
unsigned rowOrColOffset = block * elemsPerThreadPerGroup * warpSize / 32;
for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
@@ -802,14 +787,11 @@ public:
result = emitIndicesForDistributedLayout(loc, b, blocked, type,
withCTAOffset);
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
<<<<<<< HEAD
result = emitIndicesForDistributedLayout(loc, b, mma, type);
} else if (auto mfma = layout.dyn_cast<MfmaEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, mfma, type);
=======
result =
emitIndicesForDistributedLayout(loc, b, mma, type, withCTAOffset);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
} else if (auto mfma = layout.dyn_cast<MfmaEncodingAttr>()) {
result =
emitIndicesForDistributedLayout(loc, b, mfma, type, withCTAOffset);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
result =
emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset);
@@ -848,7 +830,7 @@ private:
const BlockedEncodingAttr &blockedLayout, RankedTensorType type) const {
auto shape = type.getShape();
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(triton::gpu::getWarpSize(blocked_layout));
Value warpSize = i32_val(triton::gpu::getWarpSize(blockedLayout));
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
auto sizePerThread = blockedLayout.getSizePerThread();
@@ -1208,7 +1190,7 @@ private:
auto tensorShape = type.getShape();
SmallVector<SmallVector<unsigned>> offsets;
auto shapePerCta = getShapePerCTA(mfmaLayout);
auto shapePerCta = getShapePerCTA(mfmaLayout, tensorShape);
SmallVector<unsigned> numCTAPerDim(2);
for (unsigned d = 0; d < 2; ++d) {

View File

@@ -20,11 +20,11 @@
#include "triton/Analysis/Membar.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
<<<<<<< HEAD
=======
#ifndef USE_ROCM
#else
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#endif
#include "triton/Tools/Sys/GetPlatform.hpp"
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
#include "BarrierOpToLLVM.h"
#include "ClusterOpsToLLVM.h"
@@ -70,19 +70,13 @@ public:
: ConversionTarget(ctx) {
addLegalDialect<index::IndexDialect>();
addLegalDialect<LLVM::LLVMDialect>();
<<<<<<< HEAD
if (isROCM) {
addLegalDialect<ROCDL::ROCDLDialect>();
addLegalDialect<mlir::scf::SCFDialect>();
} else {
=======
switch (target) {
case Target::NVVM:
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
addLegalDialect<NVVM::NVVMDialect>();
break;
case Target::ROCDL:
addLegalDialect<ROCDL::ROCDLDialect>();
addLegalDialect<mlir::scf::SCFDialect>();
break;
}
addLegalOp<mlir::UnrealizedConversionCastOp>();
@@ -377,19 +371,13 @@ public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx, Target target)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
<<<<<<< HEAD
if (isROCM) {
addLegalDialect<ROCDL::ROCDLDialect>();
addLegalDialect<mlir::scf::SCFDialect>();
} else {
=======
switch (target) {
case Target::NVVM:
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
addLegalDialect<NVVM::NVVMDialect>();
break;
case Target::ROCDL:
addLegalDialect<ROCDL::ROCDLDialect>();
addLegalDialect<mlir::scf::SCFDialect>();
break;
}
addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
@@ -419,14 +407,10 @@ struct ConvertTritonGPUToLLVM
// Preprocess
decomposeFp8e4b15Convert(mod);
<<<<<<< HEAD
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
#ifdef USE_ROCM
decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp);
#endif
=======
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
#ifdef USE_ROCM
decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
#endif
decomposeBlockedToDotOperand(mod);
decomposeInsertSliceAsyncOp(mod);
decomposeMixedModeDotOp(mod);
@@ -699,8 +683,8 @@ private:
}
#ifdef USE_ROCM
void decomposeMfmaToDotOperand(ModuleOp mod, int numWarps,
int threadsPerWarp) const {
void decomposeMfmaToDotOperand(ModuleOp mod, int numWarps, int threadsPerWarp,
int numCTAs) const {
// Replace `mfma -> dot_op` with `mfma -> blocked -> dot_op`
// unless certain conditions are met
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
@@ -716,7 +700,7 @@ private:
dstType.getShape(), dstType.getElementType(),
triton::gpu::BlockedEncodingAttr::get(
mod.getContext(), srcType.getShape(), getSizePerThread(srcMfma),
getOrder(srcMfma), numWarps, threadsPerWarp));
getOrder(srcMfma), numWarps, threadsPerWarp, numCTAs));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(

View File

@@ -331,13 +331,9 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
int i);
Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
<<<<<<< HEAD
int i, Value laneId);
=======
int i);
Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
StringRef key, StringRef content);

View File

@@ -475,11 +475,7 @@ public:
void runOnOperation() override {
// Only rewrite if the hardware does not support
<<<<<<< HEAD
if (!isROCM && computeCapability >= 90)
=======
if (computeCapability >= 90)
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
return;
// NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because

View File

@@ -249,7 +249,6 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
return {};
}
} else if (parentLayout.isa<MfmaEncodingAttr>()) {
auto parentShapePerCTA = getShapePerCTA(parentLayout);
auto opIdx = dotLayout.getOpIdx();
if (opIdx == 0) {
return {4, 1};
@@ -381,7 +380,7 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
}
} else if (auto parentMfmaLayout =
parentLayout.dyn_cast<MfmaEncodingAttr>()) {
auto parentShapePerCTA = getShapePerCTA(parentLayout, tensorShape);
auto parentShapePerCTA = getShapePerCTATile(parentLayout, tensorShape);
auto opIdx = dotLayout.getOpIdx();
if (opIdx == 0) {
@@ -691,11 +690,7 @@ static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr,
bool &value, StringRef desc) {
auto boolAttr = attr.dyn_cast<BoolAttr>();
if (!boolAttr) {
<<<<<<< HEAD
parser.emitError(parser.getNameLoc(), "expected bool type in ") << desc;
=======
parser.emitError(parser.getNameLoc(), "expected an bool type in ") << desc;
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
return failure();
}
value = boolAttr.getValue();
@@ -862,11 +857,11 @@ MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
return elemsPerThread;
}
<<<<<<< HEAD
unsigned MfmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
return product<unsigned>(getElemsPerThread(shape, eltTy));
=======
}
unsigned
MmaEncodingAttr::getElemsPerThreadOfOperand(int opIdx,
ArrayRef<int64_t> shape) const {
@@ -896,7 +891,6 @@ MmaEncodingAttr::getElemsPerThreadOfOperand(int opIdx,
}
}
return res;
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}
unsigned MmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
@@ -971,7 +965,6 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
<<<<<<< HEAD
if (auto mfmaParent = getParent().dyn_cast<MfmaEncodingAttr>()) {
int warpsPerCTAM = mfmaParent.getWarpsPerCTA()[0];
int warpsPerCTAN = mfmaParent.getWarpsPerCTA()[1];
@@ -980,9 +973,7 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
auto rep = getMFMARep(shape, eltTy);
return rep[0] * rep[1];
}
=======
auto shapePerCTA = getShapePerCTA(*this, shape);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
if (auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>()) {
int warpsPerCTAM = mmaParent.getWarpsPerCTA()[0];
int warpsPerCTAN = mmaParent.getWarpsPerCTA()[1];
@@ -1247,7 +1238,7 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned nonKDim = 0;
SmallVector<unsigned, 2> warpsPerCTA;
SmallVector<unsigned> warpsPerCTA;
bool isTransposed;
for (const NamedAttribute &attr : dict) {
@@ -1438,11 +1429,7 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
printer << "<{"
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
<<<<<<< HEAD
if ((mmaParent && mmaParent.isAmpere()) || getParent().isa<MfmaEncodingAttr>())
=======
if (mmaParent && mmaParent.isAmpere())
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
printer << ", kWidth = " << getKWidth();
printer << "}>";
}

View File

@@ -75,9 +75,8 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
return ret;
}
<<<<<<< HEAD
#ifdef USE_ROCM
SmallVector<unsigned, 2> warpsPerTileMI200(triton::DotOp dotOp,
SmallVector<unsigned, 2> warpsPerTileMI200(tt::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
// TODO: needs to be updated with appropriate shapePerWarp etc.
@@ -86,7 +85,7 @@ SmallVector<unsigned, 2> warpsPerTileMI200(triton::DotOp dotOp,
};
auto slices = mlir::getSlice(dotOp, filter);
for (Operation *op : slices)
if (isa<triton::DotOp>(op) && (op != dotOp))
if (isa<tt::DotOp>(op) && (op != dotOp))
return {(unsigned)numWarps, 1};
SmallVector<int64_t, 2> tensorShape = {shape[0], shape[1]};
@@ -104,7 +103,18 @@ SmallVector<unsigned, 2> warpsPerTileMI200(triton::DotOp dotOp,
ret[0] *= 2;
} else
ret[1] *= 2;
=======
} else {
ret[1] *= 2;
}
} while (true);
if (ret[1] * shapePerWarp[1] > tensorShape[1]) {
return {ret[1], ret[0]};
}
return ret;
}
SmallVector<unsigned, 2>
warpsPerTileV3(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
const SmallVector<unsigned, 3> &instrShape) {
@@ -122,17 +132,10 @@ warpsPerTileV3(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
break;
if (shape[0] > shapePerWarp[0] * ret[0]) {
ret[0] *= 2;
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
} else {
ret[1] *= 2;
}
} while (true);
<<<<<<< HEAD
if (ret[1] * shapePerWarp[1] > tensorShape[1]) {
return {ret[1], ret[0]};
}
return ret;
}
@@ -140,15 +143,15 @@ class BlockedToMFMA : public mlir::RewritePattern {
int mfmaVersion;
public:
BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context), mfmaVersion(mfmaVersion) {}
: mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context), mfmaVersion(mfmaVersion) {}
bool isChainDot(triton::DotOp &dotOp) const {
bool isChainDot(tt::DotOp &dotOp) const {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, filter);
for (Operation *op : slices) {
if (isa<triton::DotOp>(op) && (op != dotOp))
if (isa<tt::DotOp>(op) && (op != dotOp))
return true;
}
return false;
@@ -158,7 +161,7 @@ public:
/// @param dot target dot operation
/// @param mfmaVersion
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
std::pair<int64_t, int64_t> chooseMfmaDimensions(triton::DotOp dot, int mfmaVersion) const {
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot, int mfmaVersion) const {
int64_t nonKDim = 32;
// number of matrix elements along k dim per one MFMA instruction
int64_t kDim = -1;
@@ -183,11 +186,11 @@ public:
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
auto dotOp = cast<tt::DotOp>(op);
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (!oldRetType.getEncoding() ||
!oldRetType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>())
!oldRetType.getEncoding().isa<ttg::BlockedEncodingAttr>())
return failure();
if (!supportMFMA(dotOp))
@@ -196,7 +199,7 @@ public:
// get MFMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
// operands
Value a = dotOp.getA();
@@ -205,14 +208,14 @@ public:
auto oldBType = b.getType().cast<RankedTensorType>();
auto ctx = oldAType.getContext();
triton::gpu::MfmaEncodingAttr mfmaEnc;
ttg::MfmaEncodingAttr mfmaEnc;
auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp, mfmaVersion);
auto warpsPerTile = warpsPerTileMI200(dotOp, retShape, numWarps);
bool isTransposed = isChainDot(dotOp);
mfmaEnc = triton::gpu::MfmaEncodingAttr::get(
mfmaEnc = ttg::MfmaEncodingAttr::get(
oldRetType.getContext(), nonKDim, warpsPerTile, isTransposed);
auto newRetType =
@@ -220,44 +223,39 @@ public:
// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
auto oldAOrder = oldAType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.cast<ttg::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.cast<ttg::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.cast<ttg::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.cast<ttg::BlockedEncodingAttr>()
.getOrder();
// kWidth is the number of consecutive elements per instruction per thread
auto kWidth = kDim / 2;
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth));
ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth));
a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<tt::DotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
return success();
}
};
#endif
=======
return ret;
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding

View File

@@ -283,12 +283,6 @@ static bool isLayoutAnchor(Operation *op) {
return false;
}
<<<<<<< HEAD
// this may generate unsupported conversions in the LLVM codegen
if (newEncoding.isa<triton::gpu::MmaEncodingAttr>() ||
newEncoding.isa<triton::gpu::MfmaEncodingAttr>()) {
return failure();
=======
void LayoutPropagation::initAnchorLayout() {
funcOp.walk([&](Operation *op) {
if (isLayoutAnchor(op)) {
@@ -304,7 +298,6 @@ void LayoutPropagation::initAnchorLayout() {
layouts.insert({result, tensorType.getEncoding()});
}
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}
});
}
@@ -498,32 +491,6 @@ void LayoutPropagation::map(Value old, Value newV) {
newV;
}
<<<<<<< HEAD
// op(cvt(arg_0), arg_1, ..., arg_n)
// -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n)))
void pushConversionForward(triton::gpu::ConvertLayoutOp cvt,
SetVector<Operation *> &cvtSlices,
PatternSharedInfo &sharedInfo,
mlir::PatternRewriter &rewriter) {
auto srcEncoding =
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
auto dstEncoding =
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
IRMapping mapping;
auto op = cvtSlices.front();
for (Value arg : op->getOperands()) {
if (arg.getDefiningOp() == cvt)
mapping.map(arg, cvt.getOperand());
else {
auto oldType = arg.getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
oldType.getShape(), oldType.getElementType(), srcEncoding);
auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(arg.getLoc(),
newType, arg);
if (Operation *argOp = arg.getDefiningOp())
cvtI->moveAfter(argOp);
mapping.map(arg, cvtI);
=======
Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
Value rewrittenValue;
@@ -538,7 +505,6 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
rewrittenValue = value;
else
rewrittenValue = rewriteMapping[{value, encodingPicked}];
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}
assert(rewrittenValue);
if (rewrittenValue.getType().cast<RankedTensorType>().getEncoding() ==
@@ -572,18 +538,7 @@ Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,
origType.getElementType(), encoding);
newOp->getResult(i).setType(newType);
}
<<<<<<< HEAD
auto *newOp = cloneWithInferType(rewriter, op, mapping);
auto newType = newOp->getResult(0).getType().cast<RankedTensorType>();
auto newCvtType = RankedTensorType::get(
newType.getShape(), newType.getElementType(), dstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
newOp->getLoc(), newCvtType, newOp->getResult(0));
sharedInfo.cvtsPushedForwardMap[newCvt] = newCvt->getOperand(0).getDefiningOp();
rewriter.replaceOp(op, newCvt->getResults());
=======
return newOp;
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}
Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) {
@@ -649,16 +604,6 @@ Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) {
returnTypes.push_back(newType);
}
<<<<<<< HEAD
//
class RematerializeForward : public mlir::RewritePattern {
PatternSharedInfo &sharedInfo;
public:
explicit RematerializeForward(mlir::MLIRContext *context, PatternSharedInfo &sharedInfo)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context), sharedInfo(sharedInfo) {}
=======
auto newWhileOp =
rewriter.create<scf::WhileOp>(whileOp.getLoc(), returnTypes, operands);
SmallVector<Type> argsTypesBefore;
@@ -670,7 +615,6 @@ public:
rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore,
bbArgLocsBefore);
rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
for (int i = 0; i < whileOp.getNumRegions(); ++i) {
newWhileOp->getRegion(i).front().getOperations().splice(
@@ -741,10 +685,6 @@ void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) {
}
}
<<<<<<< HEAD
pushConversionForward(cvt, cvtSlices, sharedInfo, rewriter);
return success();
=======
void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) {
scf::WhileOp whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) {
@@ -755,58 +695,9 @@ void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) {
continue;
Value newOperand = getValueAs(operand.get(), tensorType.getEncoding());
conditionOp->setOperand(operand.getOperandNumber(), newOperand);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}
}
<<<<<<< HEAD
// Layout conversions are expensive. They require going through
// shared memory, which is orders of magnitude slower than
// other non-i/o operations in the dialect.
// It therefore makes sense to remove them whenever possible,
// even if it means rematerializing all values whose definitions
// are reachable from it without passing through any memory operation.
class RematerializeBackward : public mlir::RewritePattern {
PatternSharedInfo &sharedInfo;
public:
explicit RematerializeBackward(mlir::MLIRContext *context, PatternSharedInfo &sharedInfo)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
3, context), sharedInfo(sharedInfo) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *cvt,
mlir::PatternRewriter &rewriter) const override {
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(cvt))
return mlir::failure();
auto it = sharedInfo.cvtsPushedForwardMap.find(cvt);
if (it != sharedInfo.cvtsPushedForwardMap.end() &&
it->second == cvt->getOperand(0).getDefiningOp())
return mlir::failure();
// we don't touch block arguments
Operation *op = cvt->getOperand(0).getDefiningOp();
if (!op)
return mlir::failure();
// we don't want to rematerialize any conversion to/from shared
if (triton::gpu::isSharedEncoding(cvt->getResults()[0]) ||
triton::gpu::isSharedEncoding(cvt->getOperand(0)))
return mlir::failure();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return mlir::failure();
// DFS
SetVector<Operation *> processed;
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
if (simulateBackwardRematerialization(cvt, processed, layout, toConvert,
targetType.getEncoding()) > 0)
return mlir::failure();
=======
void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) {
OpBuilder rewriter(reduceOp);
Attribute srcEncoding;
@@ -875,7 +766,6 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
assert(0 && "unexpected op in rewrite");
return nullptr;
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
static bool canBeRemat(Operation *op) {
if (isa<triton::LoadOp, triton::StoreOp>(op))
@@ -898,42 +788,6 @@ static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(loop);
<<<<<<< HEAD
class MoveConvertOutOfLoop : public mlir::RewritePattern {
PatternSharedInfo &sharedInfo;
public:
explicit MoveConvertOutOfLoop(mlir::MLIRContext *context,
PatternSharedInfo &sharedInfo)
: mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context),
sharedInfo(sharedInfo) {}
SmallVector<Value, 4>
rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
size_t i, RankedTensorType newType,
triton::gpu::ConvertLayoutOp origConversion) const {
// Rewrite init argument
Type origType = forOp.getInitArgs()[i].getType();
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
// Clone for loop
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInitArgs);
newForOp->moveBefore(forOp);
rewriter.setInsertionPointToStart(newForOp.getBody());
IRMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(origConversion.getResult(), newForOp.getRegionIterArgs()[i]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
for (Operation &op : forOp.getBody()->without_terminator()) {
if (&op == (Operation *)(&origConversion))
continue;
Operation *newOp = rewriter.clone(op, mapping);
=======
// Create a new loop before the existing one, with the extra operands.
rewriter.setInsertionPoint(loop);
auto operands = llvm::to_vector<4>(loop.getIterOperands());
@@ -966,36 +820,10 @@ static void rewriteSlice(SetVector<Value> &slice,
opsToRewrite.insert(v.cast<BlockArgument>().getOwner()->getParentOp());
// We also need to rewrite the yield op.
opsToRewrite.insert(v.cast<BlockArgument>().getOwner()->getTerminator());
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}
}
opsToRewrite = multiRootTopologicalSort(opsToRewrite);
<<<<<<< HEAD
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto forOp = cast<scf::ForOp>(op);
auto iterArgs = forOp.getRegionIterArgs();
for (const auto &iterArg : llvm::enumerate(iterArgs)) {
// skip non-tensor types
if (!iterArg.value().getType().isa<RankedTensorType>())
continue;
SmallVector<Operation *> cvts;
if (canMoveOutOfLoop(iterArg.value(), cvts).failed())
continue;
// check
for (auto *op : cvts) {
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
auto it = sharedInfo.cvtsPushedForwardMap.find(cvt);
if (it != sharedInfo.cvtsPushedForwardMap.end())
return mlir::failure();
auto targetType = op->getResultTypes()[0].cast<RankedTensorType>();
auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(),
targetType, cvt);
rewriter.replaceOp(forOp, newFor);
return success();
=======
SmallVector<Operation *> deadLoops;
OpBuilder builder(slice.begin()->getContext());
for (Operation *op : opsToRewrite) {
@@ -1032,7 +860,6 @@ static void rewriteSlice(SetVector<Value> &slice,
if (slice.count(operand) == 0)
continue;
yieldOperands.push_back(mapping.lookup(operand));
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}
builder.create<scf::YieldOp>(op->getLoc(), yieldOperands);
op->erase();
@@ -1214,19 +1041,6 @@ public:
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
<<<<<<< HEAD
mlir::RewritePatternSet patterns(context);
PatternSharedInfo sharedInfo;
patterns.add<SimplifyConversion>(context);
patterns.add<SimplifyReduceCvt>(context);
patterns.add<RematerializeBackward>(context, sharedInfo);
patterns.add<RematerializeForward>(context, sharedInfo);
patterns.add<MoveConvertOutOfLoop>(context, sharedInfo);
patterns.add<MoveConvertOutOfIf>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<ConvertDotConvert>(context);
=======
// 1. Propagate layout forward starting from "anchor" ops.
m.walk([](triton::FuncOp funcOp) {
LayoutPropagation layoutPropagation(funcOp);
@@ -1249,7 +1063,6 @@ public:
// 3. For the converts that are left, try to hoist them above casts that
// generate larger types, in order to reduce the cost of the convert op.
hoistConvert(m);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
mlir::RewritePatternSet decomposePatterns(context);
decomposePatterns.add<DecomposeDotOperand>(context);

View File

@@ -401,12 +401,13 @@ void LoopPipeliner::createBufferTypes() {
ty.getShape().end());
Type eType = ty.getElementType();
auto blockedEnc = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();
auto CTALayout = ttg::getCTALayout(ty.getEncoding());
// unsigned bitWidth = dotOpEnc.getMMAv2kWidth()
// ? 32 / dotOpEnc.getMMAv2kWidth()
// : ty.getElementType().getIntOrFloatBitWidth();
auto sharedEnc =
ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(),
ttg::getOrder(ty.getEncoding()), eType);
ttg::getOrder(ty.getEncoding()), CTALayout, eType);
loadsBufferType[loadOp] = RankedTensorType::get(bufferShape, eType, sharedEnc);
}
}

View File

@@ -1,5 +1,6 @@
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
@@ -58,11 +59,7 @@ struct NVVMMetadata {
// Add the nvvm related metadata to LLVM IR.
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
<<<<<<< HEAD
bool isROCM, const int threadsPerCTA) {
=======
Target target) {
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
Target target, const int threadsPerCTA) {
auto *module = func->getParent();
auto &ctx = func->getContext();
@@ -92,18 +89,8 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
}
if (metadata.isKernel) {
<<<<<<< HEAD
if (isROCM) {
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
func->addFnAttr("amdgpu-flat-work-group-size",
"1, " + std::to_string(threadsPerCTA));
func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
} else {
=======
switch (target) {
case Target::NVVM: {
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
llvm::Metadata *mdArgs[] = {
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
llvm::ValueAsMetadata::get(
@@ -113,7 +100,10 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
} break;
case Target::ROCDL: {
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
func->addFnAttr("amdgpu-flat-work-group-size",
"1, " + std::to_string(threadsPerCTA));
func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
} break;
}
}
@@ -341,11 +331,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
for (auto &func : llvmModule->functions()) {
auto it = nvvmMetadata.find(func.getName());
if (it != nvvmMetadata.end())
<<<<<<< HEAD
amendLLVMFunc(&func, it->second, isROCM, threadsPerCTA);
=======
amendLLVMFunc(&func, it->second, target);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
amendLLVMFunc(&func, it->second, target, threadsPerCTA);
}
return llvmModule;
@@ -379,7 +365,9 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(
createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, target}));
#ifndef USE_ROCM
pm.addPass(createConvertNVGPUToLLVMPass());
#endif
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
// Simplify the IR

View File

@@ -36,11 +36,8 @@
#include "triton/Target/HSACO/HSACOTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
<<<<<<< HEAD
#include "triton/Target/HSACO/HSACOTranslation.h"
=======
#include "triton/Target/PTX/TmaMetadata.h"
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
#include "triton/Tools/Sys/GetEnv.hpp"
#include "triton/Tools/Sys/GetPlatform.hpp"

View File

@@ -1005,7 +1005,7 @@ def deserialize_fp8(np_data, in_dtype):
# return np_data
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
"""
@@ -1056,9 +1056,9 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
[32, 32, 128],
[128, 128, 64],
[64, 128, 128]]
for ab_type in [[tl.float8e4, tl.float16],
for ab_type in [[tl.float8e4nv, tl.float16],
[tl.float8e5, tl.float16],
[tl.float16, tl.float8e4],
[tl.float16, tl.float8e4nv],
[tl.float16, tl.float8e5]]
for out_dtype in [torch.float16, torch.float32]
])

View File

@@ -53,14 +53,8 @@ def compile_fn(config, device_type, cc):
def test_compile_in_subproc() -> None:
<<<<<<< HEAD
cc, device_type = get_device_type()
config = instance_descriptor(tuple(range(4)), ())
=======
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(4)), (), (), ())
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
multiprocessing.set_start_method('fork')
proc = multiprocessing.Process(
@@ -92,14 +86,8 @@ def compile_fn_dot(config, device_type, cc):
def test_compile_in_forked_subproc() -> None:
reset_tmp_dir()
<<<<<<< HEAD
cc, device_type = get_device_type()
config = instance_descriptor(tuple(range(1)), ())
=======
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(1)), (), (), ())
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
assert multiprocessing.get_start_method() == 'fork'
proc = multiprocessing.Process(

View File

@@ -73,17 +73,13 @@ def optimize_ttir(mod, arch):
return mod
<<<<<<< HEAD
def ttir_to_ttgir(mod, num_warps, warpsize):
def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize)
=======
def ttir_to_ttgir(mod, num_warps, num_ctas, arch):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, arch)
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
if is_hip():
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, 0)
else:
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, arch)
pm.run(mod)
return mod
@@ -109,26 +105,16 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
if optimize_epilogue:
pm.add_tritongpu_optimize_epilogue_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
<<<<<<< HEAD
if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
pm.add_tritongpu_stream_pipeline_pass()
pm.add_canonicalizer_pass()
else:
pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_tritongpu_prefetch_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
if num_stages != 0:
pm.add_tritongpu_reorder_instructions_pass()
=======
ws_enabled = False
# `num_warps` does not mean the total number of warps of a CTA when
# warp specialization is enabled.
# It is the responsibility of the compiler to figure out the exact
# `num_warps` to use.
# TODO: support the case where `num_warps` from the user is not 4.
if arch // 10 >= 9 and enable_warp_specialization and num_warps == 4:
if _is_cuda(arch) and arch // 10 >= 9 and enable_warp_specialization and num_warps == 4:
pm.add_tritongpu_ws_feasibility_checking_pass(arch)
pm.run(mod)
ws_enabled = ir.is_ws_supported(mod)
@@ -142,20 +128,27 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
pm.add_tritongpu_wsmaterialization_pass(arch)
pm.add_cse_pass()
else:
pm.add_tritongpu_pipeline_pass(
num_stages, num_warps, num_ctas, arch)
pm.add_tritongpu_materialize_load_store_pass(num_warps, arch)
if arch // 10 <= 8:
if is_hip():
pm.add_tritongpu_pipeline_pass(
num_stages, num_warps, num_ctas, 0)
else:
pm.add_tritongpu_pipeline_pass(
num_stages, num_warps, num_ctas, arch)
if is_hip():
pm.add_tritongpu_materialize_load_store_pass(num_warps, 0)
else:
pm.add_tritongpu_materialize_load_store_pass(num_warps, arch)
if _is_cuda(arch) and arch // 10 <= 8:
pm.add_tritongpu_prefetch_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
pm.add_tritongpu_reorder_instructions_pass()
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
if num_stages != 0:
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
if arch // 10 >= 9:
if _is_cuda(arch) and arch // 10 >= 9:
pm.add_tritongpu_fence_insertion_pass()
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
pm.run(mod)
@@ -475,10 +468,6 @@ def compile(fn, **kwargs):
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
context = ir.context()
constants = kwargs.get("constants", dict())
<<<<<<< HEAD
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else (1 if is_hip else 2))
=======
num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type))
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
num_ctas = kwargs.get("num_ctas", 1)
@@ -487,7 +476,6 @@ def compile(fn, **kwargs):
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
# TODO[shuhaoj]: persistent can be decoupled with warp specialization
enable_persistent = kwargs.get("enable_persistent", enable_warp_specialization)
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
extern_libs = kwargs.get("extern_libs", dict())
if extern_libs is None:
extern_libs = dict()
@@ -509,11 +497,7 @@ def compile(fn, **kwargs):
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
<<<<<<< HEAD
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size), num_stages, arch))
=======
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
if is_cuda:
@@ -584,11 +568,8 @@ def compile(fn, **kwargs):
InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
else:
metadata = {"num_warps": num_warps,
<<<<<<< HEAD
"warp_size": warp_size,
=======
"num_ctas": num_ctas,
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
"num_stages": num_stages,
"enable_warp_specialization": enable_warp_specialization,
"enable_persistent": enable_persistent,
@@ -705,11 +686,8 @@ class CompiledKernel:
# initialize metadata
self.shared = metadata["shared"]
self.num_warps = metadata["num_warps"]
<<<<<<< HEAD
self.warp_size = metadata["warp_size"]
=======
self.num_ctas = metadata["num_ctas"]
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
self.num_stages = metadata["num_stages"]
self.clusterDims = metadata["clusterDims"]
if "tensormaps_info" in metadata:

View File

@@ -104,156 +104,145 @@ def generate_launcher(constants, signature, ids):
# generate glue code
if is_hip():
src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>
#include <stdio.h>
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>
#include <stdbool.h>
#include <dlfcn.h>
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
if (code != HIP_SUCCESS)
{{
const char* prefix = "Triton Error [HIP]: ";
const char* str = hipGetErrorString(code);
char err[1024] = {{0}};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
PyErr_SetString(PyExc_RuntimeError, err);
}}
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
if (code != HIP_SUCCESS)
{{
const char* prefix = "Triton Error [HIP]: ";
const char* str = hipGetErrorString(code);
char err[1024] = {{0}};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
// printf("_launch hip kernel\\n");
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
if (gridX*gridY*gridZ > 0) {{
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0));
}}
}}
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
typedef struct _DevicePtrInfo {{
hipDeviceptr_t dev_ptr;
bool valid;
}} DevicePtrInfo;
static int getWarpSize(hipStream_t stream)
{{
int device_id = hipGetStreamDeviceId(stream);
gpuAssert(device_id >= 0 ? hipSuccess : hipErrorInvalidDevice, __FILE__, __LINE__);
hipDeviceProp_t prop;
HIP_CHECK(hipGetDeviceProperties(&prop, device_id));
return prop.warpSize;
}}
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if (gridX*gridY*gridZ > 0) {{
int warp_size = getWarpSize(stream);
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, num_warps * warp_size, 1, 1, shared_memory, stream, params, 0));
}}
}}
typedef struct _DevicePtrInfo {{
hipDeviceptr_t dev_ptr;
bool valid;
}} DevicePtrInfo;
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
DevicePtrInfo ptr_info;
ptr_info.dev_ptr = 0;
ptr_info.valid = true;
if (PyLong_Check(obj)) {{
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
return ptr_info;
}}
if (obj == Py_None) {{
// valid nullptr
return ptr_info;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if (ptr) {{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
ptr_info.valid = false;
return ptr_info;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
if (!ptr_info.dev_ptr)
return ptr_info;
uint64_t dev_ptr;
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
if (status == hipErrorInvalidValue) {{
PyErr_Format(PyExc_ValueError,
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
ptr_info.valid = false;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
DevicePtrInfo ptr_info;
ptr_info.dev_ptr = 0;
ptr_info.valid = true;
if (PyLong_Check(obj)) {{
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
return ptr_info;
}}
if (obj == Py_None) {{
// valid nullptr
return ptr_info;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
ptr_info.valid = false;
return ptr_info;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
int num_warps;
int shared_memory;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
return NULL;
}}
if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{
return NULL;
}}
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
Py_BEGIN_ALLOW_THREADS;
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
Py_END_ALLOW_THREADS;
if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{
return NULL;
}}
// return None
Py_INCREF(Py_None);
return Py_None;
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
if(!ptr_info.dev_ptr)
return ptr_info;
uint64_t dev_ptr;
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
if (status == hipErrorInvalidValue) {{
PyErr_Format(PyExc_ValueError,
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
ptr_info.valid = false;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
Py_DECREF(ret);
return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return ptr_info;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static PyObject* launch(PyObject* self, PyObject* args) {{
// printf("launch\\n");
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
int num_warps;
int num_ctas;
int clusterDimX;
int clusterDimY;
int clusterDimZ;
int shared_memory;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
return NULL;
}}
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
if (launch_enter_hook != Py_None) {{
PyObject_CallObject(launch_enter_hook, args);
}}
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
if (launch_exit_hook != Py_None) {{
PyObject_CallObject(launch_exit_hook, args);
}}
if(PyErr_Occurred()) {{
return NULL;
}}
// return None
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
else:
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
@@ -279,12 +268,6 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
<<<<<<< HEAD
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if(gridX*gridY*gridZ > 0){{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, num_warps * 32, 1, 1, shared_memory, stream, params, 0));
=======
typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);
static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
@@ -336,7 +319,6 @@ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas
}}
CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
}}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}}
}}

View File

@@ -1324,13 +1324,8 @@ def dot(lhs: tl.tensor,
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
assert lhs.type.is_block() and rhs.type.is_block()
<<<<<<< HEAD
assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or (lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()), f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
=======
assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.arch)
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"

View File

@@ -363,17 +363,10 @@ class JITFunction(KernelInterface[T]):
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
src = f"""
<<<<<<< HEAD
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel
sig_key = {sig_keys},
=======
import triton
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
assert num_ctas > 0

View File

@@ -90,11 +90,7 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
if IS_CAUSAL:
qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
<<<<<<< HEAD
qk += tl.dot(q, k)
=======
qk += tl.dot(q, k, out_dtype=tl.float16)
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
@@ -168,17 +164,11 @@ def _bwd_kernel(
V += off_z * stride_vz + off_h * stride_vh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
<<<<<<< HEAD
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
# See fwd pass above for explanation.
qk_scale = sm_scale * 1.44269504
for start_n in range(0, num_block):
=======
DK += off_z * stride_kz + off_h * stride_kh
DV += off_z * stride_vz + off_h * stride_vh
# See fwd pass above for explanation.
qk_scale = sm_scale * 1.44269504
for start_n in range(0, num_block_kv):
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
if CAUSAL:
lo = tl.math.max(start_n * BLOCK_M - P_SEQ, 0)
else:
@@ -466,7 +456,6 @@ class _attention(torch.autograd.Function):
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
BLOCK_M = 128
<<<<<<< HEAD
if torch.version.hip is None:
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
@@ -480,14 +469,6 @@ class _attention(torch.autograd.Function):
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
=======
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
num_warps = 4
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
_fwd_kernel[grid](
q, k, v, sm_scale,
L,
@@ -507,10 +488,7 @@ class _attention(torch.autograd.Function):
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
ctx.causal = causal
<<<<<<< HEAD
ctx.split_kernel = split_kernel
=======
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
ctx.P_SEQ = P_SEQ
return o
@@ -538,28 +516,8 @@ class _attention(torch.autograd.Function):
block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do,
<<<<<<< HEAD
do_scaled, delta,
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
=======
delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do,
dq, dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2], ctx.P_SEQ,
ctx.grid[0], triton.cdiv(k.shape[2], BLOCK),
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
CAUSAL=ctx.causal,
num_stages=1,
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
)
if not ctx.split_kernel:
_bwd_kernel[(ctx.grid[1],)](
@@ -611,7 +569,6 @@ class _attention(torch.autograd.Function):
attention = _attention.apply
<<<<<<< HEAD
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
[(4, 48, 1024, 64, 128),
(4, 48, 2048, 64, 128),
@@ -621,16 +578,10 @@ attention = _attention.apply
])
@pytest.mark.parametrize('causal', [False, True])
def test_op_fwd(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
=======
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ', [(6, 9, 1024, 64, 128)])
@pytest.mark.parametrize('causal', [False, True])
def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
<<<<<<< HEAD
sm_scale = q.shape[-1] ** (-0.5)
dout = torch.randn_like(q)
# reference implementation
@@ -663,12 +614,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
=======
sm_scale = 0.5
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
p[:, :, M == 0] = float("-inf")

View File

@@ -1,6 +1,6 @@
set -o xtrace
alias drun='sudo docker run -it --rm --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined'
DRUN='sudo docker run -it --rm --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined'
# DEVICES="--gpus all"
DEVICES="--device=/dev/kfd --device=/dev/dri"
@@ -21,7 +21,7 @@ CONTAINER_NAME=triton
# start new container
docker stop $CONTAINER_NAME
docker rm $CONTAINER_NAME
CONTAINER_ID=$(drun -d -w $WORK_DIR --name $CONTAINER_NAME $MEMORY $VOLUMES $DEVICES $IMAGE_NAME)
CONTAINER_ID=$($DRUN -d -w $WORK_DIR --name $CONTAINER_NAME $MEMORY $VOLUMES $DEVICES $IMAGE_NAME)
echo "CONTAINER_ID: $CONTAINER_ID"
# docker cp . $CONTAINER_ID:$WORK_DIR
# docker exec $CONTAINER_ID bash -c "bash scripts/amd/run.sh"

View File

@@ -1,11 +1,7 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
tt.func @ops() {
<<<<<<< HEAD
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}}
=======
// CHECK: module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
@@ -37,17 +33,10 @@ tt.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// Test if the total number of threadsPerWarp is 64
// Test if the total number of warps is 2
<<<<<<< HEAD
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}}
=======
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
%c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
%c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
%c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>

View File

@@ -1,17 +1,9 @@
<<<<<<< HEAD
// RUN: not triton-opt %s -split-input-file --convert-triton-gpu-to-llvm --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s
=======
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" | FileCheck %s
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// RUN: not triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=rocdl" --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
<<<<<<< HEAD
// PTX: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]} {{.*}}
=======
// CHECK: nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// PTX: nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
tt.return
@@ -711,11 +703,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
tt.func @basic_program_id() {
<<<<<<< HEAD
// PTX: nvvm.read.ptx.sreg.ctaid.x : i32
=======
// CHECK: llvm.inline_asm asm_dialect = att operand_attrs = [] "mov.u32 $0, %ctaid.x;", "=r" : () -> i32
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
%0 = tt.get_program_id x : i32
tt.return
}
@@ -788,15 +776,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// -----
<<<<<<< HEAD
module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// PTX-LABEL: basic_async_wait
// This test is disabled for GCN target, because it is PTX specific
// GCN-NOT: basic_async_wait
=======
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_async_wait
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
tt.func @basic_async_wait() {
// PTX: cp.async.wait_group 0x4
triton_gpu.async_wait {num = 4: i32}
@@ -947,14 +930,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%tensor = triton_gpu.alloc_tensor : tensor<2x16x32xf32, #A>
%index = arith.constant 1 : i32
<<<<<<< HEAD
// This test is PTX specific; GCN targets decompose async operations into ordinary loads/stores.
// TODO: Fix AMD compilation.
// The last operation (commit_group) is still emitted by the AMD pipeline.
// It is left in to catch changes in the AMD compilation pipeline.
// PTX: llvm.inline_asm
// PTX-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// PTX: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// PTX: llvm.inline_asm
// PTX-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// PTX: llvm.inline_asm
@@ -978,18 +960,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32>
// GCN-COUNT-4: llvm.store {{.*}} : !llvm.ptr<vector<1xf32>, 3>
// GCN: llvm.inline_asm {{.*}}cp.async.commit_group
=======
// CHECK: llvm.inline_asm
// CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.commit_group
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr<f32>, #AL> -> tensor<2x16x32xf32, #A>
triton_gpu.async_commit_group
tt.return
@@ -1224,7 +1194,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
#mma0 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
<<<<<<< HEAD
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// PTX-LABEL: convert_dot
// This test is not relevant to GCN target, because it is PTX specific
@@ -1232,20 +1201,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
// PTX: llvm.inline_asm
// PTX-SAME: ldmatrix.sync.aligned.m8n8.x4
// PTX: ldmatrix.sync.aligned.m8n8.x4
// PTX: llvm.inline_asm
// PTX-SAME: ldmatrix.sync.aligned.m8n8.x4
=======
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot
tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
// CHECK: llvm.inline_asm
// CHECK: ldmatrix.sync.aligned.m8n8.x4
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
%AA_DOT = triton_gpu.convert_layout %AA : (tensor<16x16xf16, #shared0>) -> tensor<16x16xf16, #dot_operand_a>
%BB_DOT = triton_gpu.convert_layout %BB : (tensor<16x16xf16, #shared0>) -> tensor<16x16xf16, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -1271,7 +1229,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// -----
<<<<<<< HEAD
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0]}>
#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false}>
@@ -1312,19 +1269,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// PTX: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// PTX-LABEL: convert_layout_mmav2_block
// This test is not relevant to GCN target, because it is PTX specific
=======
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav2_block
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// PTX: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// PTX-LABEL: convert_layout_mmav2_block
// This test is not relevant to GCN target, because it is PTX specific
tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
// PTX: llvm.store
// PTX-SAME: !llvm.ptr<vector<2xf32>, 3>
@@ -1340,20 +1290,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// -----
<<<<<<< HEAD
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// PTX: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// PTX-LABEL: convert_layout_mmav1_block
// This test is not relevant to GCN target, because it is PTX specific
=======
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav1_block
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// PTX: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// PTX-LABEL: convert_layout_mmav1_block
// This test is not relevant to GCN target, because it is PTX specific
tt.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
// PTX: llvm.store
// PTX-SAME: !llvm.ptr<vector<2xf32>, 3>
@@ -1438,14 +1380,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
<<<<<<< HEAD
module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// PTX-LABEL: matmul_kernel_dot_operand_layout
// This test is disabled for GCN target, because it is PTX specific
// This test is not relevant to GCN target, because it is PTX specific
=======
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
@@ -1465,7 +1403,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// -----
<<<<<<< HEAD
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
@@ -1494,16 +1431,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// PTX-LABEL: matmul884_kernel_dot_operand_layout
// This test is not relevant to GCN target, because it is PTX specific
=======
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
@@ -1511,7 +1438,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// PTX-LABEL: matmul884_kernel_dot_operand_layout
// This test is not relevant to GCN target, because it is PTX specific
tt.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma>
@@ -1557,14 +1485,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}>
<<<<<<< HEAD
module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// PTX-LABEL: matmul_tf32dot
// This test is not relevant to GCN target, because it is PTX specific
=======
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_tf32dot
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
@@ -1617,20 +1540,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32_scalar
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
<<<<<<< HEAD
// GCN-NOT: llvm.inline_asm
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
// PTX: llvm.icmp "eq"
// PTX: llvm.inline_asm
// PTX: llvm.inline_asm
// PTX-SAME: @$3 atom.global.gpu.add.f32
// PTX-SAME: @$3 atom.global.gpu.relaxed.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (!tt.ptr<f32>, f32, i1) -> f32
=======
// CHECK: llvm.icmp "eq"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
tt.return
}
}
@@ -1699,18 +1614,9 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
%blockidx = tt.get_program_id x: i32
%blockidy = tt.get_program_id y: i32
%blockidz = tt.get_program_id z : i32
<<<<<<< HEAD
// PTX: nvvm.read.ptx.sreg.ctaid.x
// PTX: nvvm.read.ptx.sreg.ctaid.y
// PTX: nvvm.read.ptx.sreg.ctaid.z
// GCN: rocdl.workgroup.id.x
// GCN: rocdl.workgroup.id.y
// GCN: rocdl.workgroup.id.z
=======
// CHECK: clusterid.x
// CHECK: clusterid.y
// CHECK: clusterid.z
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
%v0 = arith.addi %blockidx, %blockidy : i32
%v1 = arith.addi %v0, %blockidz : i32
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
@@ -1747,15 +1653,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
<<<<<<< HEAD
// PTX: nvvm.read.ptx.sreg.nctaid.x
// PTX: nvvm.read.ptx.sreg.nctaid.y
// PTX: nvvm.read.ptx.sreg.nctaid.z
// GCN: rocdl.grid.dim.x
// GCN: rocdl.grid.dim.y
// GCN: rocdl.grid.dim.z
=======
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
%blockdimx = tt.get_num_programs {axis=0:i32} : i32
%blockdimy = tt.get_num_programs {axis=1:i32} : i32
%blockdimz = tt.get_num_programs {axis=2:i32} : i32
@@ -1887,14 +1784,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
}
// -----
<<<<<<< HEAD
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
=======
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// CHECK-LABEL: test_s8_to_bf16_conversion
tt.func @test_s8_to_bf16_conversion(%in: tensor<32xi8, #blocked>) {
// We can't vectorize if we only process
@@ -1907,12 +1799,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
}
// -----
<<<<<<< HEAD
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1]}>
=======
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
#dot = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_s8_to_bf16_vectorized_conversion
@@ -1935,25 +1823,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// -----
<<<<<<< HEAD
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: atomic_add_f16
tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: f16 {tt.difisibility = 16 : i32}) {
%c1_i1 = arith.constant 1 : i1
%1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x32xi32, #blocked>
%3 = tt.broadcast %2 : (tensor<1x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
%4 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<32x32x!tt.ptr<f16>, #blocked>
%5 = tt.addptr %4, %3 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
%6 = tt.splat %arg1 : (f16) -> tensor<32x32xf16, #blocked>
%7 = tt.splat %c1_i1 : (i1) -> tensor<32x32xi1, #blocked>
// PTX: llvm.inline_asm
// PTX-SAME: @$3 atom.global.gpu.add.noftz.f16x2
// GCN-COUNT-8: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f16, 1>, f16
%8 = "tt.atomic_rmw"(%5, %6, %7) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
=======
// CHECK-LABEL: sum_reduction
// CHECK: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: nvvm.redux.sync add %{{.*}}, %[[M]]
@@ -1986,14 +1855,12 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
%13 = tt.splat %arg1 : (!tt.ptr<i32, 1>) -> tensor<1x!tt.ptr<i32, 1>, #blocked1>
%14 = tt.addptr %13, %0 : tensor<1x!tt.ptr<i32, 1>, #blocked1>, tensor<1xi32, #blocked1>
tt.store %14, %12 {cache = 1 : i32, evict = 1 : i32} : tensor<1xi32, #blocked1>
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
tt.return
}
}
// -----
<<<<<<< HEAD
=======
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#slice = #triton_gpu.slice<{dim = 1, parent = #blocked}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} {
@@ -2047,4 +1914,26 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
tt.return
}
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// -----
// Lowering test for a tensor-wide f16 atomic add (atomic_rmw_op = 5, checked
// below as an fadd on the GCN path).
// The blocked layout is rank-2, so its CTA fields (CTAsPerCGA, CTASplitNum,
// CTAOrder) are rank-2 as well, matching the other 2-D blocked layouts in
// this file.
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} {
  // CHECK-LABEL: atomic_add_f16
  // NOTE(review): "tt.difisibility" on %arg1 looks like a misspelling of
  // "tt.divisibility"; kept unchanged here, confirm the intent before renaming.
  tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: f16 {tt.difisibility = 16 : i32}) {
    %c1_i1 = arith.constant 1 : i1
    // Build a 32x32 tensor of f16 pointers: %arg0 splatted, offset by a
    // 0..31 range that is expanded along axis 0 and broadcast to every row.
    %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x32xi32, #blocked>
    %3 = tt.broadcast %2 : (tensor<1x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
    %4 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<32x32x!tt.ptr<f16>, #blocked>
    %5 = tt.addptr %4, %3 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
    // Value operand and an all-true mask, both splatted to the pointer shape.
    %6 = tt.splat %arg1 : (f16) -> tensor<32x32xf16, #blocked>
    %7 = tt.splat %c1_i1 : (i1) -> tensor<32x32xi1, #blocked>
    // Expected lowering: a predicated packed f16x2 atomic in inline asm on the
    // PTX path, and eight scalar f16 fadd atomicrmw ops on the GCN path.
    // PTX: llvm.inline_asm
    // PTX-SAME: @$3 atom.global.gpu.add.noftz.f16x2
    // GCN-COUNT-8: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f16, 1>, f16
    %8 = "tt.atomic_rmw"(%5, %6, %7) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
    tt.return
  }
}

View File

@@ -8,16 +8,9 @@
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
<<<<<<< HEAD
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
=======
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32, 1>, [[row_layout]]>
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]>
// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]>
// CHECK: [[load_val:%.*]] = tt.load [[load_ptr]], [[load_mask]], [[load_other]] {{.*}} : tensor<64x64xf32, [[row_layout]]>

View File

@@ -3,16 +3,15 @@ add_triton_ut(
SRCS PTXAsmFormatTest.cpp
LIBS TritonGPUToLLVM
)
<<<<<<< HEAD
add_triton_ut(
NAME TestGcnAsmFormat
SRCS GcnAsmFormatTest.cpp
LIBS TritonGPUToLLVM
=======
)
add_triton_ut(
NAME TestEmitIndices
SRCS EmitIndicesTest.cpp DumpLayout.cpp
LIBS TritonGPUIR TritonNvidiaGPUIR ${dialect_libs} ${conversion_libs}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
)