Fix merge conflicts

This commit is contained in:
Jason Furmanek
2023-09-01 04:01:32 +00:00
parent 3eaeb89d18
commit df5c263a19
28 changed files with 127 additions and 1235 deletions

View File

@@ -8,15 +8,11 @@ namespace triton {
std::unique_ptr<Pass> createCombineOpsPass();
<<<<<<< HEAD
std::unique_ptr<Pass> createRewriteTensorPointerPass(int computeCapability = 80,
bool isROCM = false);
=======
std::unique_ptr<Pass> createReorderBroadcastPass();
std::unique_ptr<Pass>
createRewriteTensorPointerPass(int computeCapability = 80);
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
createRewriteTensorPointerPass(int computeCapability = 80,
bool isROCM = false);
} // namespace triton

View File

@@ -31,13 +31,10 @@ SmallVector<unsigned> getElemsPerThread(Type type);
// getThreadsPerWarpWithUniqueData.
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
<<<<<<< HEAD
unsigned getWarpSize(Attribute layout);
=======
// Returns the number of warps per CTA that may have access to replicated
// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData.
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
SmallVector<unsigned> getSizePerThread(Attribute layout);

View File

@@ -77,8 +77,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
<<<<<<< HEAD
"Type":$eltTy), [{
"unsigned":$typeWidthInBit), [{
#ifdef USE_ROCM
// ---- begin GFX908/GFX90A ----
@@ -93,9 +92,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
const int SIMDWidth = 16;
// number of inner dimension rows per one pattern repeat
int typeBitWidth = eltTy.getIntOrFloatBitWidth();
int innerDimLength = shape[order[0]];
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeBitWidth;
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
// Note: the following settings is customized to avoid
@@ -111,7 +109,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// 4. TODO: what about f64?
//
// maxPhase is set to SIMDWidth / perPhase
int vecSize = (eltTy.isF16() ? 64 : 32 ) / typeBitWidth;
int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
int maxPhase = SIMDWidth / perPhase;
return $_get(context, vecSize, perPhase, maxPhase, order);
@@ -122,9 +120,6 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
}
}
#endif
=======
"unsigned":$typeWidthInBit), [{
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
if(!mmaEnc)
@@ -154,11 +149,11 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getMMAv2kWidth());
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getKWidth());
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getMMAv2kWidth()};
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
if ((32 / typeWidthInBit != dotOpEnc.getMMAv2kWidth()) && order[0] == inner)
if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
return $_get(context, 1, 1, 1, order);
// --- handle A operand ---

View File

@@ -183,13 +183,10 @@ private:
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
<<<<<<< HEAD
=======
} else if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
ScanLoweringHelper helper(scanOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();

View File

@@ -94,11 +94,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
// that case doesn't need inter-warp communication
<<<<<<< HEAD
if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
=======
if (isWarpSynchronous())
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
return {{0, 0}, {0, 0}};
/// shared memory block0

View File

@@ -401,153 +401,6 @@ inline SmallVector<Value> packI32(const SmallVector<Value> &inValues,
return outValues;
}
<<<<<<< HEAD
struct FpToFpOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
typedef std::function<SmallVector<Value>(
Location, ConversionPatternRewriter &, const Value &, const Value &,
const Value &, const Value &)>
ConvertorT;
/* ------------------ */
// FP8 -> FP16
/* ------------------ */
static SmallVector<Value>
convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const char *ptxAsm, const Value &v0, const Value &v1,
const Value &v2, const Value &v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto &ptxOp = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
ptxOp({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);
auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2x2StructTy =
struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
auto fp16x2x2Struct =
builder.launch(rewriter, loc, fp16x2x2StructTy, false);
auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct, 0);
auto fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct, 1);
return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
}
static SmallVector<Value>
convertFp8E4M3x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
#ifdef USE_ROCM
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
a0 = bitcast(a0, i32_ty);
Value a1 = undef(fp8x4VecTy);
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
a1 = bitcast(a1, i32_ty);
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
Value b1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
b0 = lshr(i32_ty, b0, i32_val(1));
b1 = lshr(i32_ty, b1, i32_val(1));
b0 = or_( i32_ty, b0, and_(i32_ty, a0, i32_val(0x80008000)) );
b1 = or_( i32_ty, b1, and_(i32_ty, a1, i32_val(0x80008000)) );
auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2Vec0 = bitcast(b0, fp16x2VecTy);
auto fp16x2Vec1 = bitcast(b1, fp16x2VecTy);
return { extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
extract_element(f16_ty, fp16x2Vec1, i32_val(1))
};
#else
auto *ptxAsm = // WARN: subnormal (0bs0000xxx) are not handled
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7060; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 1; \n" // b0 >>= 1
"shr.b32 b1, b1, 1; \n" // shift into fp16 position
"add.u32 b0, b0, 0x20002000; \n" // b0.exp += 2**4-2**3
// exponent compensate = 8
"add.u32 b1, b1, 0x20002000; \n" // b1 += 8<<10 | 8<<10<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
return convertFp8x4ToFp16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
#endif
}
static SmallVector<Value>
convertFp8E5M2x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
// exponent bias of Fp8E5M2 and Fp16 are the same
#ifdef USE_ROCM
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
a0 = bitcast(a0, i32_ty);
Value a1 = undef(fp8x4VecTy);
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
a1 = bitcast(a1, i32_ty);
auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2Vec0 = bitcast(a0, fp16x2VecTy);
auto fp16x2Vec1 = bitcast(a1, fp16x2VecTy);
return { extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
extract_element(f16_ty, fp16x2Vec1, i32_val(1))
};
#else
auto *ptxAsm = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
return convertFp8x4ToFp16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
#endif
}
=======
typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const Value &, const Value &,
const Value &, const Value &)>
@@ -555,7 +408,6 @@ typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
Type outType) {
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
ConverterT converter = [ptxAsm, inType, outType](
Location loc, ConversionPatternRewriter &rewriter,
@@ -585,109 +437,6 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
for (Value inVal : inPacked)
operands.push_back(builder.newOperand(inVal, "r"));
auto &ptxOp = *builder.create(ptxAsm);
<<<<<<< HEAD
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
ptxOp({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
auto bf16x2VecTy = vec_ty(i16_ty, 2);
auto bf16x2x2StructTy =
struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
auto bf16x2x2Struct =
builder.launch(rewriter, loc, bf16x2x2StructTy, false);
auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct, 0);
auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct, 1);
return {extract_element(i16_ty, bf16x2Vec0, i32_val(0)),
extract_element(i16_ty, bf16x2Vec0, i32_val(1)),
extract_element(i16_ty, bf16x2Vec1, i32_val(0)),
extract_element(i16_ty, bf16x2Vec1, i32_val(1))};
}
static SmallVector<Value>
convertFp8E4M3x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
#ifdef USE_ROCM
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
a0 = bitcast(a0, i32_ty);
Value a1 = undef(fp8x4VecTy);
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
a1 = bitcast(a1, i32_ty);
Value sign0 = and_(i32_ty, a0, i32_val(0x80008000));
Value sign1 = and_(i32_ty, a1, i32_val(0x80008000));
Value nosign0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
Value nosign1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
nosign0 = lshr(i32_ty, nosign0, i32_val(4));
nosign1 = lshr(i32_ty, nosign1, i32_val(4));
nosign0 = add(i32_ty, nosign0, i32_val(0x38003800));
nosign1 = add(i32_ty, nosign1, i32_val(0x38003800));
auto bf16x2VecTy = vec_ty(i16_ty, 2);
Value bf16x2Vec0 = or_(i32_ty, sign0, nosign0);
Value bf16x2Vec1 = or_(i32_ty, sign1, nosign1);
bf16x2Vec0 = bitcast(bf16x2Vec0, bf16x2VecTy);
bf16x2Vec1 = bitcast(bf16x2Vec1, bf16x2VecTy);
return { extract_element(i16_ty, bf16x2Vec0, i32_val(0)),
extract_element(i16_ty, bf16x2Vec0, i32_val(1)),
extract_element(i16_ty, bf16x2Vec1, i32_val(0)),
extract_element(i16_ty, bf16x2Vec1, i32_val(1))
};
#else
auto *ptxAsm = // WARN: subnormal (0bs0000xxx) are not handled
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7060; \n" // a1 = 0xf100f200
"and.b32 b0, a0, 0x7fff7fff; \n" // b0 = a0 & 0x7fff7fff
"and.b32 b1, a1, 0x7fff7fff; \n" // (strip sign)
"shr.b32 b0, b0, 4; \n" // b0 >>= 4
"shr.b32 b1, b1, 4; \n" // shift into fp16 position
"add.u32 b0, b0, 0x3c003c00; \n" // b0.exp += 2**7-2**3
// exponent compensate = 120
"add.u32 b1, b1, 0x3c003c00; \n" // b1 += 120<<7 | 120<<7<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
return convertFp8x4ToBf16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
#endif
};
static SmallVector<Value>
convertFp8E5M2x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = // WARN: subnormal (0bs00000xx) are not handled
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
// exponent compensate = 112
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
return convertFp8x4ToBf16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
=======
ptxOp(operands, /*onlyAttachMLIRArgs=*/true);
auto outVecTy = vec_ty(outType, outVecWidth);
SmallVector<Value> outPacked;
@@ -705,146 +454,13 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
ret.push_back(extract_element(outType, outPacked[i / outVecWidth],
i32_val(i % outVecWidth)));
return ret;
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
};
return converter;
}
<<<<<<< HEAD
/* ------------------ */
// FP16 -> FP8
/* ------------------ */
static SmallVector<Value>
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const char *ptxAsm, const Value &v0, const Value &v1,
const Value &v2, const Value &v3) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
PTXBuilder builder;
auto &ptxOp = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
ptxOp({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
}
static SmallVector<Value>
convertFp16x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
#ifdef USE_ROCM
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1));
Value a1 = shl(i32_ty, fp16x2Vec1, i32_val(1));
a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
a1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
a0 = add(i32_ty, a0, i32_val(0x00800080));
a1 = add(i32_ty, a1, i32_val(0x00800080));
Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 );
Value b1 = or_( i32_ty, and_(i32_ty, fp16x2Vec1, i32_val(0x80008000)), a1 );
auto fp8x4VecTy = vec_ty(i8_ty, 4);
b0 = bitcast(b0, fp8x4VecTy);
b1 = bitcast(b1, fp8x4VecTy);
return {extract_element(i8_ty, b0, i32_val(1)),
extract_element(i8_ty, b0, i32_val(3)),
extract_element(i8_ty, b1, i32_val(1)),
extract_element(i8_ty, b1, i32_val(3))
};
#else
auto *ptxAsm = // WARN: subnormal Fp8s are not handled
"{ \n"
".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4
"sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000
// (compensate offset)
"sub.u32 a1, $2, 0x20002000; \n" // a1 = input1 - 0x20002000
// (8 << 10 | 8 << 10 << 16)
"shl.b32 a0, a0, 1; \n" // a0 <<= 1
"shl.b32 a1, a1, 1; \n" // shift into fp8e4 position
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" // a0 &= 0x7fff7fff
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" // b0 = a0|(0x80008000&in0)
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
"prmt.b32 $0, b0, b1, 0x7531; \n" // output = b1b0
"}";
return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
#endif
}
static SmallVector<Value>
convertFp16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
#ifdef USE_ROCM
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
Value b0 = bitcast(fp16x2Vec0, i32_ty);
Value b1 = bitcast(fp16x2Vec1, i32_ty);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
b0 = bitcast(b0, fp8x4VecTy);
b1 = bitcast(b1, fp8x4VecTy);
return {extract_element(i8_ty, b0, i32_val(1)),
extract_element(i8_ty, b0, i32_val(3)),
extract_element(i8_ty, b1, i32_val(1)),
extract_element(i8_ty, b1, i32_val(3))
};
#else
auto *ptxAsm =
"{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0x7fff7fff; \n" // a0 &= 0x7fff7fff
"and.b32 a1, $2, 0x7fff7fff; \n" // (strip sign)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"lop3.b32 a0, $1, 0x80008000, a0, 0xea; \n" // a0 = a0|(0x80008000&in0)
"lop3.b32 a1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
#endif
}
=======
class MultipleOperandsRange
: public iterator_range<SmallVector<SmallVector<Value>>::iterator> {
using ContainerT = SmallVector<SmallVector<Value>>;
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
public:
using iterator_range<ContainerT::iterator>::iterator_range;
@@ -896,121 +512,6 @@ public:
if (allOperands.size() == 0)
allOperands.push_back({});
<<<<<<< HEAD
static SmallVector<Value>
convertBf16x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
#ifdef USE_ROCM
auto bf16x2VecTy = vec_ty(i16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
Value sign0 = and_(i32_ty, bf16x2Vec0, i32_val(0x80008000));
Value sign1 = and_(i32_ty, bf16x2Vec1, i32_val(0x80008000));
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value sign = undef(fp8x4VecTy);
sign0 = bitcast(sign0, fp8x4VecTy);
sign1 = bitcast(sign1, fp8x4VecTy);
sign = insert_element( fp8x4VecTy, sign, extract_element(i8_ty, sign0, i32_val(1)), i32_val(0) );
sign = insert_element( fp8x4VecTy, sign, extract_element(i8_ty, sign0, i32_val(3)), i32_val(1) );
sign = insert_element( fp8x4VecTy, sign, extract_element(i8_ty, sign1, i32_val(1)), i32_val(2) );
sign = insert_element( fp8x4VecTy, sign, extract_element(i8_ty, sign1, i32_val(3)), i32_val(3) );
sign = bitcast(sign, i32_ty);
Value nosign0 = and_(i32_ty, bf16x2Vec0, i32_val(0x7fff7fff));
Value nosign1 = and_(i32_ty, bf16x2Vec1, i32_val(0x7fff7fff));
Value nosign_0_0 = and_(i32_ty, nosign0, i32_val(0xffff0000));
nosign_0_0 = umax(i32_ty, nosign_0_0, i32_val(0x38000000));
nosign_0_0 = umin(i32_ty, nosign_0_0, i32_val(0x3ff00000));
Value nosign_0_1 = and_(i32_ty, nosign0, i32_val(0x0000ffff));
nosign_0_1 = umax(i32_ty, nosign_0_1, i32_val(0x3800));
nosign_0_1 = umin(i32_ty, nosign_0_1, i32_val(0x3ff0));
nosign0 = or_(i32_ty, nosign_0_0, nosign_0_1);
Value nosign_1_0 = and_(i32_ty, nosign1, i32_val(0xffff0000));
nosign_1_0 = umax(i32_ty, nosign_1_0, i32_val(0x38000000));
nosign_1_0 = umin(i32_ty, nosign_1_0, i32_val(0x3ff00000));
Value nosign_1_1 = and_(i32_ty, nosign1, i32_val(0x0000ffff));
nosign_1_1 = umax(i32_ty, nosign_1_1, i32_val(0x3800));
nosign_1_1 = umin(i32_ty, nosign_1_1, i32_val(0x3ff0));
nosign1 = or_(i32_ty, nosign_1_0, nosign_1_1);
nosign0 = add(i32_ty, nosign0, i32_val(0x80008));
nosign1 = add(i32_ty, nosign1, i32_val(0x80008));
nosign0 = sub(i32_ty, nosign0, i32_val(0x38003800));
nosign1 = sub(i32_ty, nosign1, i32_val(0x38003800));
nosign0 = lshr(i32_ty, nosign0, i32_val(4));
nosign1 = lshr(i32_ty, nosign1, i32_val(4));
nosign0 = bitcast(nosign0, fp8x4VecTy);
nosign1 = bitcast(nosign1, fp8x4VecTy);
Value nosign = undef(fp8x4VecTy);
nosign = insert_element( fp8x4VecTy, nosign, extract_element(i8_ty, nosign0, i32_val(0)), i32_val(0) );
nosign = insert_element( fp8x4VecTy, nosign, extract_element(i8_ty, nosign0, i32_val(2)), i32_val(1) );
nosign = insert_element( fp8x4VecTy, nosign, extract_element(i8_ty, nosign1, i32_val(0)), i32_val(2) );
nosign = insert_element( fp8x4VecTy, nosign, extract_element(i8_ty, nosign1, i32_val(2)), i32_val(3) );
nosign = bitcast(nosign, i32_ty);
Value fp8x4Vec = or_(i32_ty, nosign, sign);
fp8x4Vec = bitcast(fp8x4Vec, fp8x4VecTy);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
#else
auto *ptxAsm = // bf16 is clamped firstly to fp8 min/max
"{ \n" // bf16=fp8>>4 + 120<<7
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000
".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111
"mov.u32 fp8_min, 0x3c003c00; \n" // so bf16_min = 0x3c00
"mov.u32 fp8_max, 0x43f043f0; \n" // so bf16_max = 0x43f0
"mov.u32 rn_, 0x80008; \n" // round to nearest
"and.b32 sign0, $1, 0x80008000; \n" // sign0=in0&0x80008000
"and.b32 sign1, $2, 0x80008000; \n" // (store sign)
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
"and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0=in0&0x7fff7fff
"and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign)
// nosign = clamp(nosign, min, max)
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
"max.u32 nosign_0_0, nosign_0_0, 0x3c000000; \n"
"min.u32 nosign_0_0, nosign_0_0, 0x43f00000; \n"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
"max.u32 nosign_0_1, nosign_0_1, 0x3c00; \n"
"min.u32 nosign_0_1, nosign_0_1, 0x43f0; \n"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
"max.u32 nosign_1_0, nosign_1_0, 0x3c000000; \n"
"min.u32 nosign_1_0, nosign_1_0, 0x43f00000; \n"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
"max.u32 nosign_1_1, nosign_1_1, 0x3c00; \n"
"min.u32 nosign_1_1, nosign_1_1, 0x43f0; \n"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
"add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_
"add.u32 nosign1, nosign1, rn_; \n" // (round to nearest)
"sub.u32 nosign0, nosign0, 0x3c003c00; \n" // nosign0-=0x3c003c00
"sub.u32 nosign1, nosign1, 0x3c003c00; \n" // (compensate offset)
"shr.u32 nosign0, nosign0, 4; \n" // nosign0 >>= 4
"shr.u32 nosign1, nosign1, 4; \n" // shift into to fp8e4
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n" // nosign0 = 0x00f100f2
// nosign1 = 0x00f300f4
// nosign = 0xf3f4f1f2
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
return convertBf16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
#endif
};
=======
SmallVector<Value> resultVals;
for (auto it = allOperands.begin(), end = allOperands.end(); it != end;) {
auto curr = static_cast<const ConcreteT *>(this)->createDestOps(
@@ -1033,7 +534,6 @@ public:
Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
return success();
}
@@ -1208,110 +708,20 @@ struct FpToFpOpConversion
for (Value &v : outVals)
v = convertFp16ToFp32(loc, rewriter, v);
// Pack values
<<<<<<< HEAD
assert(outVals.size() == elems);
outVals = reorderValues(outVals, srcTensorType, dstTensorType);
outVals =
packI32(outVals, dstTensorType, rewriter, loc, getTypeConverter());
auto result = getTypeConverter()->packLLElements(loc, outVals, rewriter,
dstTensorType);
rewriter.replaceOp(op, result);
return success();
return outVals;
}
};
template <typename OP>
Value EmitDualBF16ElementwiseOp(Location loc,
ConversionPatternRewriter &rewriter,
ValueRange operands) {
auto v0 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
auto v1 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[1]);
MultipleOperandsRange operands) {
auto v0 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]);
auto v1 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][1]);
auto result = rewriter.create<OP>(loc, f32_ty, v0, v1);
return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, result);
}
template <typename SourceOp, typename ConcreteT>
class ElementwiseOpConversionBase
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ElementwiseOpConversionBase(
TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType();
Location loc = op->getLoc();
// element type
auto resultElementTy = getElementTypeOrSelf(resultTy);
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
SmallVector<Value> resultVals;
//
SmallVector<SmallVector<Value>> allOperands;
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto sub_operands = this->getTypeConverter()->unpackLLElements(
loc, operand, rewriter, argTy);
sub_operands = unpackI32(sub_operands, argTy, rewriter, loc,
this->getTypeConverter());
allOperands.resize(sub_operands.size());
for (auto v : llvm::enumerate(sub_operands))
allOperands[v.index()].push_back(v.value());
}
if (allOperands.size() == 0)
allOperands.push_back({});
for (const SmallVector<Value> &operands : allOperands) {
Value curr =
((ConcreteT *)(this))
->createDestOp(op, adaptor, rewriter, elemTy, operands, loc);
if (!bool(curr))
return failure();
resultVals.push_back(curr);
}
if (op->getNumOperands() > 0) {
auto argTy = op->getOperand(0).getType();
resultVals = reorderValues(resultVals, argTy, resultTy);
}
resultVals =
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
return success();
}
};
template <typename SourceOp, typename DestOp>
struct ElementwiseOpConversion
: public ElementwiseOpConversionBase<
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
using Base =
ElementwiseOpConversionBase<SourceOp,
ElementwiseOpConversion<SourceOp, DestOp>>;
using Base::Base;
using OpAdaptor = typename Base::OpAdaptor;
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
typeConverter, benefit) {}
// An interface to support variant DestOp builder.
DestOp createDestOp(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
return rewriter.create<DestOp>(loc, elemTy, operands,
adaptor.getAttributes().getValue());
=======
return outVals;
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
}
};
struct CmpIOpConversion
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
CmpIOpConversion> {
@@ -1458,20 +868,14 @@ struct FDivOpConversion
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
<<<<<<< HEAD
Value createDestOp(mlir::arith::DivFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
#ifdef USE_ROCM
return rewriter.create<LLVM::FDivOp>(loc, elemTy, operands[0],
operands[1]);
#else
=======
SmallVector<Value> createDestOps(mlir::arith::DivFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#ifdef USE_ROCM
return {rewriter.create<LLVM::FDivOp>(loc, elemTy, operands[0][0],
operands[0][1])};
#else
PTXBuilder ptxBuilder;
auto &fdiv = *ptxBuilder.create<PTXInstr>("div");
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
@@ -1491,12 +895,8 @@ struct FDivOpConversion
fdiv(res, lhs, rhs);
Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
<<<<<<< HEAD
return ret;
#endif
=======
return {ret};
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
}
};
@@ -1515,7 +915,7 @@ struct FMulOpConversion
auto rhsElemTy = getElementType(op.getRhs());
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
#ifdef USE_ROCM
return EmitDualBF16ElementwiseOp<LLVM::FMulOp>(loc, rewriter, operands);
return {EmitDualBF16ElementwiseOp<LLVM::FMulOp>(loc, rewriter, operands)};
#else
PTXBuilder builder;
auto ptxAsm = " { .reg .b16 c; \n"
@@ -1526,12 +926,8 @@ struct FMulOpConversion
auto lhs = builder.newOperand(operands[0][0], "h");
auto rhs = builder.newOperand(operands[0][1], "h");
fMul({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
<<<<<<< HEAD
return builder.launch(rewriter, loc, i16_ty, false);
#endif
=======
return {builder.launch(rewriter, loc, i16_ty, false)};
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
} else {
return {rewriter.create<LLVM::FMulOp>(loc, elemTy, operands[0][0],
operands[0][1])};
@@ -1554,7 +950,7 @@ struct FAddOpConversion
auto rhsElemTy = getElementType(op.getRhs());
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
#ifdef USE_ROCM
return EmitDualBF16ElementwiseOp<LLVM::FAddOp>(loc, rewriter, operands);
return {EmitDualBF16ElementwiseOp<LLVM::FAddOp>(loc, rewriter, operands)};
#else
PTXBuilder builder;
auto ptxAsm = "{ .reg .b16 c; \n"
@@ -1565,12 +961,8 @@ struct FAddOpConversion
auto lhs = builder.newOperand(operands[0][0], "h");
auto rhs = builder.newOperand(operands[0][1], "h");
fAdd({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
<<<<<<< HEAD
return builder.launch(rewriter, loc, i16_ty, false);
#endif
=======
return {builder.launch(rewriter, loc, i16_ty, false)};
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
} else {
return {rewriter.create<LLVM::FAddOp>(loc, elemTy, operands[0][0],
operands[0][1])};
@@ -1593,7 +985,7 @@ struct FSubOpConversion
auto rhsElemTy = getElementType(op.getRhs());
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
#ifdef USE_ROCM
return EmitDualBF16ElementwiseOp<LLVM::FSubOp>(loc, rewriter, operands);
return {EmitDualBF16ElementwiseOp<LLVM::FSubOp>(loc, rewriter, operands)};
#else
PTXBuilder builder;
auto ptxAsm = " { .reg .b16 c; \n"
@@ -1604,12 +996,8 @@ struct FSubOpConversion
auto lhs = builder.newOperand(operands[0][0], "h");
auto rhs = builder.newOperand(operands[0][1], "h");
fSub({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
<<<<<<< HEAD
return builder.launch(rewriter, loc, i16_ty, false);
#endif
=======
return {builder.launch(rewriter, loc, i16_ty, false)};
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
} else {
return {rewriter.create<LLVM::FSubOp>(loc, elemTy, operands[0][0],
operands[0][1])};
@@ -1735,20 +1123,16 @@ struct ExpOpConversionApprox
Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e));
#ifdef USE_ROCM
return rewriter.create<math::Exp2Op>(loc, f32_ty, prod,
adaptor.getAttributes().getValue());
return {rewriter.create<math::Exp2Op>(loc, f32_ty, prod,
adaptor.getAttributes().getValue())};
#else
PTXBuilder ptxBuilder;
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
auto output = ptxBuilder.newOperand("=f");
auto input = ptxBuilder.newOperand(prod, "f");
exp2(output, input);
<<<<<<< HEAD
return ptxBuilder.launch(rewriter, loc, f32_ty, false);
#endif
=======
return {ptxBuilder.launch(rewriter, loc, f32_ty, false)};
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
}
};

View File

@@ -357,13 +357,9 @@ private:
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
SmallVector<Value> smemBases(op.getNumOperands());
<<<<<<< HEAD
if (sizeInterWarps > 1) {
=======
bool isWarpSync = helper.isWarpSynchronous();
if (!isWarpSync) {
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
smemBases[0] = bitcast(
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
@@ -455,11 +451,7 @@ private:
accumulate(rewriter, *combineOp, acc, shfl, false);
}
<<<<<<< HEAD
if (sizeInterWarps == 1) {
=======
if (isWarpSync) {
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
finalAccs[key] = acc;
continue;
}
@@ -474,11 +466,7 @@ private:
}
}
<<<<<<< HEAD
if (sizeInterWarps == 1) {
=======
if (isWarpSync) {
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
@@ -513,12 +501,8 @@ private:
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
unsigned numThreads =
<<<<<<< HEAD
product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * wavefront_size;
=======
product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) *
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
Value readOffset = threadId;
for (unsigned round = 0; round < elemsPerThread; ++round) {

View File

@@ -487,12 +487,8 @@ struct GetNumProgramsOpConversion
#else
Value blockId =
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
<<<<<<< HEAD
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
#endif // USE_ROCM
=======
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif // USE_ROCM
return success();
}

View File

@@ -618,11 +618,8 @@ private:
computeCapability)
.contains(byteWidth)) {
return;
<<<<<<< HEAD
#endif
=======
}
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
// load
auto tmpTy =

View File

@@ -471,13 +471,8 @@ public:
void runOnOperation() override {
// Only rewrite if the hardware does not support
<<<<<<< HEAD
if (!isROCM && computeCapability >= 90)
return;
=======
// if (computeCapability >= 90)
// return;
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
// NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because
// MLIR does not support one-multiple value mapping. For example, if we use

View File

@@ -1315,7 +1315,7 @@ struct TritonGPUInferLayoutInterface
// Verify that the encodings are valid.
if (!aEncoding || !bEncoding)
return op->emitError("mismatching encoding between A and B operands");
if (aEncoding.getMMAv2kWidth() != bEncoding.getMMAv2kWidth())
if (aEncoding.getKWidth() != bEncoding.getKWidth())
return op->emitError("mismatching kWidth between A and B operands");
return success();
}

View File

@@ -431,15 +431,11 @@ public:
}
}
<<<<<<< HEAD
pushConversionForward(cvt, cvtSlices, sharedInfo, rewriter);
=======
// Call SimplifyReduceCvt instead of the general push conversion forward
if (isa<triton::ReduceOp>(cvtSlices.front()))
return failure();
pushConversionForward(cvt, cvtSlices, rewriter);
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
pushConversionForward(cvt, cvtSlices, sharedInfo, rewriter);
return success();
}
};

View File

@@ -91,9 +91,8 @@ public:
return;
op->moveAfter(argOp);
});
<<<<<<< HEAD
// Move `dot` operand so that conversions to opIdx=0 happens before
// conversions to opIdx=1
// Move `dot` operand so that conversions to opIdx=1 happens after
// conversions to opIdx=0
#ifdef USE_ROCM
// Skip this reordering for ROCm backend since it will sink shared->dot
// conversion for Q tensor in flash attention into the main loop. This
@@ -101,10 +100,6 @@ public:
// iteration.
return;
#endif
=======
// Move `dot` operand so that conversions to opIdx=1 happens after
// conversions to opIdx=0
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
m.walk([&](triton::gpu::ConvertLayoutOp op) {
auto dstType = op.getResult().getType().cast<RankedTensorType>();
auto dstEncoding =

View File

@@ -16,11 +16,8 @@
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
<<<<<<< HEAD
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
=======
#include "triton/Target/LLVMIR/Passes.h"
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/IR/CallingConv.h"
#include "llvm/ADT/APInt.h"
@@ -363,15 +360,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
// Simplify the IR
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
<<<<<<< HEAD
#ifdef USE_ROCM
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(createConvertControlFlowToLLVMPass());
#endif
=======
if (!::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"))
pm.addPass(mlir::createLLVMDIScopePass());
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
if (failed(pm.run(module))) {
llvm::errs() << "Pass execution failed";

View File

@@ -1052,126 +1052,108 @@ void init_triton_ir(py::module &&m) {
.def("create_shl",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
<<<<<<< HEAD
auto loc = self.getUnknownLoc();
#ifdef USE_ROCM
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
auto constValue = self.create<mlir::arith::ConstantIntOp>(
loc, typeWidth, elementType);
typeWidth, elementType);
auto zeroConst =
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
self.create<mlir::arith::ConstantIntOp>(0, elementType);
if (lhs.getType().isIntOrIndex()) {
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShLIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShLIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, zeroConst));
cmpValue, shiftValue, zeroConst));
} else {
auto splatValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), constValue);
lhs.getType(), constValue);
auto zeroValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), zeroConst);
lhs.getType(), zeroConst);
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShLIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShLIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, zeroValue));
cmpValue, shiftValue, zeroValue));
}
#else
return mlir::Value(
self.create<mlir::arith::ShLIOp>(loc, lhs, rhs));
#endif
=======
return mlir::Value(self.create<mlir::arith::ShLIOp>(lhs, rhs));
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
})
.def("create_lshr",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
<<<<<<< HEAD
auto loc = self.getUnknownLoc();
#ifdef USE_ROCM
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
auto constValue = self.create<mlir::arith::ConstantIntOp>(
loc, typeWidth, elementType);
typeWidth, elementType);
auto zeroConst =
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
self.create<mlir::arith::ConstantIntOp>(0, elementType);
if (lhs.getType().isIntOrIndex()) {
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShRUIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, zeroConst));
cmpValue, shiftValue, zeroConst));
} else {
auto splatValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), constValue);
lhs.getType(), constValue);
auto zeroValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), zeroConst);
lhs.getType(), zeroConst);
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShRUIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, zeroValue));
cmpValue, shiftValue, zeroValue));
}
#else
return mlir::Value(
self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs));
#endif
=======
return mlir::Value(self.create<mlir::arith::ShRUIOp>(lhs, rhs));
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
})
.def("create_ashr",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
<<<<<<< HEAD
auto loc = self.getUnknownLoc();
#ifdef USE_ROCM
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
auto constValue = self.create<mlir::arith::ConstantIntOp>(
loc, typeWidth, elementType);
typeWidth, elementType);
auto zeroConst =
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
self.create<mlir::arith::ConstantIntOp>(0, elementType);
uint64_t ones_val = 0xFFFFFFFFFFFFFFFF;
auto onesConst =
self.create<mlir::arith::ConstantIntOp>(loc, ones_val, elementType);
self.create<mlir::arith::ConstantIntOp>(ones_val, elementType);
if (lhs.getType().isIntOrIndex()) {
auto negativeCmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::slt, lhs, zeroConst);
mlir::arith::CmpIPredicate::slt, lhs, zeroConst);
auto otherValue = mlir::Value(self.create<mlir::arith::SelectOp>(
loc, negativeCmpValue, onesConst, zeroConst));
negativeCmpValue, onesConst, zeroConst));
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShRSIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, otherValue));
cmpValue, shiftValue, otherValue));
} else {
auto splatValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), constValue);
lhs.getType(), constValue);
auto zeroValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), zeroConst);
lhs.getType(), zeroConst);
auto onesValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), onesConst);
lhs.getType(), onesConst);
auto negativeCmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::slt, lhs, zeroValue);
mlir::arith::CmpIPredicate::slt, lhs, zeroValue);
auto otherValue = mlir::Value(self.create<mlir::arith::SelectOp>(
loc, negativeCmpValue, onesValue, zeroValue));
negativeCmpValue, onesValue, zeroValue));
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShRSIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, otherValue));
cmpValue, shiftValue, otherValue));
}
#else
return mlir::Value(
self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs));
#endif
=======
return mlir::Value(self.create<mlir::arith::ShRSIOp>(lhs, rhs));
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
})
// AddPtr (similar to GEP)
.def("create_addptr",

View File

@@ -14,14 +14,9 @@ from typing import Any, Tuple
from .._C.libtriton.triton import (add_external_libs, compile_ptx_to_cubin,
get_shared_memory_size, ir,
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
<<<<<<< HEAD
translate_triton_gpu_to_llvmir, get_arch_info,
get_warp_size)
from ..common.backend import get_backend
=======
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, path_to_ptxas
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources

View File

@@ -294,15 +294,12 @@ class _attention(torch.autograd.Function):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
<<<<<<< HEAD
if torch.version.hip is not None:
BLOCK = 64
BLOCK_M = 64
BLOCK_N = 64
else:
BLOCK = 128
=======
BLOCK_M = 128
BLOCK_N = 64
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
BLOCK_M = 128
BLOCK_N = 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv

View File

@@ -25,11 +25,7 @@ class OutOfResources(Exception):
class Autotuner(KernelInterface):
<<<<<<< HEAD
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None):
=======
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -62,12 +58,9 @@ class Autotuner(KernelInterface):
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
<<<<<<< HEAD
self.verbose = verbose
=======
self.warmup = warmup
self.rep = rep
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
self.verbose = verbose
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
@@ -187,11 +180,7 @@ class Config:
return ', '.join(res)
<<<<<<< HEAD
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False):
=======
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25, rep=100):
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -222,21 +211,15 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
<<<<<<< HEAD
:param verbose: a boolean that controls whether the best_config for each key is printed
:type verbose: bool
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by)
=======
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
:type warmup: int
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
:type rep: int
:param verbose: a boolean that controls whether the best_config for each key is printed
:type verbose: bool
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, warmup, rep)
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep)
return decorator

View File

@@ -325,12 +325,8 @@ class JITFunction(KernelInterface[T]):
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
src = f"""
<<<<<<< HEAD
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
=======
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
from ..compiler import compile, CompiledKernel
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}

View File

@@ -1,125 +0,0 @@
import argparse
import sys
from .._C.libtriton.triton import ir
# import triton.compiler.compiler as tc
from ..compiler.compiler import (get_amdgpu_arch_fulldetails, llir_to_amdgcn_and_hsaco,
llir_to_ptx, optimize_ttgir, optimize_ttir,
ttgir_to_llir, ttir_to_ttgir, CUDA_DEFAULT_WARP_SIZE)
if __name__ == '__main__':
# valid source and target formats
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx', 'amdgcn']
# set up the argument parser
# TODO: conditional requirements
parser = argparse.ArgumentParser()
parser.add_argument('src', help="Source file to compile")
parser.add_argument('--target', required=True,
help="Target format, one of: " + ', '.join(VALID_FORMATS))
parser.add_argument('--sm', type=int, help="Compute capability to compile for")
parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")
parser.add_argument('--gfx', type=str, help="AMDGPU target to compile for")
parser.add_argument('--triple', type=str, help="target triple, for example: amdgcn-amd-amdhsa")
parser.add_argument('--features', type=str, help="target features, for example: +sramecc,-xnack")
parser.add_argument('--num_warps', type=int, help="number of warps to compile ttgir for")
# parse the args
args = parser.parse_args()
# TODO: clean-up and re-use triton.compiler primitive functions
# check for validity of format arguments
if args.target not in VALID_FORMATS:
print("Invalid target format: " + args.target)
sys.exit(0)
# parse source file to MLIR module
context = ir.context()
module = ir.parse_mlir_module(args.src, context)
module.context = context
# optimizer triton-ir
module = optimize_ttir(module, arch=args.sm)
if args.target == 'triton-ir':
print(module.str())
sys.exit(0)
if not args.num_warps:
args.num_warps = 4
# llvm-ir -> amdgcn
if args.target == 'amdgcn':
# auto detect available architecture and features
# if nothing detected, set with default values
arch_details = get_amdgpu_arch_fulldetails()
if not arch_details:
arch_name = ""
arch_triple = "amdgcn-amd-amdhsa"
arch_features = ""
arch_warpsize = 64
else:
arch_triple, arch_name, arch_features, arch_warpsize = arch_details
# stop processing if architecture name is not automatically detected and is not set manually
if not args.gfx and not arch_name:
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
# rewrite default and automatically detected values with manually provided data
if args.gfx:
arch_name = args.gfx
if args.triple:
arch_triple = args.triple
if args.features:
arch_features = args.features
# triton-ir -> triton-gpu-ir
# use compute_capability == 80
module = ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=arch_warpsize) # num_stages=3, compute_capability=80)
module = optimize_ttgir(module, num_stages=3, arch=args.gfx)
# triton-gpu-ir -> llvm-ir
# use compute_capability == 80
module = ttgir_to_llir(module, extern_libs=None, arch=args.gfx)
# llvm-ir -> amdgcn asm, hsaco binary
module, hsaco_path = llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)
print(hsaco_path)
print(module)
sys.exit(0)
# set arch depending on platform
if args.gfx:
arch = args.gfx
elif args.sm:
arch = args.sm
else:
raise argparse.ArgumentError(None, "Must specify --sm or --gfx for ttgir compilation")
# triton-ir -> triton-gpu-ir
module = ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=CUDA_DEFAULT_WARP_SIZE)
module = optimize_ttgir(module, num_stages=3, arch=arch)
if args.target == 'triton-gpu-ir':
print(module.str())
sys.exit(0)
# triton-gpu-ir -> llvm-ir
module = ttgir_to_llir(module, extern_libs=None, arch=arch)
if args.target == 'llvm-ir':
print(module)
sys.exit(0)
# llvm-ir -> ptx
if args.target == 'ptx':
if not args.sm:
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
if not args.ptx_version:
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
module = llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version)
# llvm-ir -> amdgcn
if args.target == 'amdgcn':
if not args.gfx:
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
module, hsaco_path = llir_to_amdgcn_and_hsaco(module, args.gfx)
print(module)

View File

@@ -39,16 +39,10 @@ def _fwd_kernel(
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
<<<<<<< HEAD
q_offset = off_hz * stride_qh
kv_offset = off_hz * stride_kh
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
=======
qvk_offset = off_hz * stride_qh
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
@@ -56,26 +50,16 @@ def _fwd_kernel(
order=(1, 0)
)
K_block_ptr = tl.make_block_ptr(
<<<<<<< HEAD
base=K + kv_offset,
shape=(BLOCK_DMODEL, N_CTX + P_SEQ),
=======
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
<<<<<<< HEAD
base=V + kv_offset,
shape=(N_CTX + P_SEQ, BLOCK_DMODEL),
=======
base=V + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
@@ -97,11 +81,7 @@ def _fwd_kernel(
q = (q * qk_scale).to(tl.float16)
# loop over k, v and update accumulator
lo = 0
<<<<<<< HEAD
hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ
=======
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
k = tl.load(K_block_ptr)
@@ -109,11 +89,7 @@ def _fwd_kernel(
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
if IS_CAUSAL:
<<<<<<< HEAD
qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
=======
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
qk += tl.dot(q, k)
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
@@ -135,11 +111,7 @@ def _fwd_kernel(
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
O_block_ptr = tl.make_block_ptr(
<<<<<<< HEAD
base=Out + q_offset,
=======
base=Out + qvk_offset,
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
@@ -152,11 +124,7 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out, DO,
<<<<<<< HEAD
NewDO, Delta,
=======
Delta,
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -164,12 +132,10 @@ def _bwd_preprocess(
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
<<<<<<< HEAD
=======
# compute
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@@ -233,22 +199,13 @@ def _bwd_kernel(
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
<<<<<<< HEAD
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
if CAUSAL:
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
else:
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk * qk_scale - l_i[:, None])
=======
if CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
else:
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= qk_scale
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i[:, None])
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
@@ -492,18 +449,13 @@ empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
<<<<<<< HEAD
def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
=======
def forward(ctx, q, k, v, causal, sm_scale):
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
BLOCK_M = 128
<<<<<<< HEAD
if torch.version.hip is None:
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
@@ -514,11 +466,6 @@ class _attention(torch.autograd.Function):
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
=======
BLOCK_N = 64
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
@@ -529,48 +476,36 @@ class _attention(torch.autograd.Function):
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
<<<<<<< HEAD
q.shape[0], q.shape[1], q.shape[2], P_SEQ,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
num_warps=num_warps,
num_stages=num_stages)
=======
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
num_warps=num_warps,
num_stages=4)
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
ctx.save_for_backward(q, k, v, o, L)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
ctx.causal = causal
<<<<<<< HEAD
ctx.split_kernel = split_kernel
ctx.P_SEQ = P_SEQ
=======
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
return o
@staticmethod
def backward(ctx, do):
<<<<<<< HEAD
BLOCK = 64
q, k, v, o, l = ctx.saved_tensors
=======
BLOCK = 128
# configuration is not supported
assert(not (ctx.split_kernel and not ctx.causal))
if torch.version.hip is not None:
BLOCK = 64
else:
BLOCK = 128
q, k, v, o, L = ctx.saved_tensors
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
do = do.contiguous()
dq = torch.zeros_like(q)
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
<<<<<<< HEAD
delta = torch.empty_like(L)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
# Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
# If the two are the same, we don't need this but the bwd pass block size
# is smaller than the fwd so we need this scaling to ensure we loop over all
@@ -588,8 +523,7 @@ class _attention(torch.autograd.Function):
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l,
delta,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
@@ -597,15 +531,16 @@ class _attention(torch.autograd.Function):
block_scale * ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
CAUSAL=ctx.causal,
num_stages=1,
)
else :
dq = torch.zeros_like(q)
_bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
q, k, v, ctx.sm_scale,
o, do_scaled,
dk, dv,
l,
delta,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
@@ -618,8 +553,7 @@ class _attention(torch.autograd.Function):
q, k, v, ctx.sm_scale,
o, do_scaled,
dq,
l,
delta,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
@@ -630,36 +564,10 @@ class _attention(torch.autograd.Function):
)
# print(h.asm["ttgir"])
return dq, dk, dv, None, None, None
=======
delta = torch.empty_like(L)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do,
delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do,
dq, dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
CAUSAL=ctx.causal,
num_stages=1,
)
return dq, dk, dv, None, None
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
[(4, 48, 1024, 64, 128),
(4, 48, 2048, 64, 128),
@@ -702,16 +610,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = q.shape[-1] ** (-0.5)
split_kernel = True
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
@@ -724,13 +622,8 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
@@ -771,11 +664,7 @@ configs = [triton.testing.Benchmark(
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal}
) for mode in ['fwd', 'bwd'] for causal in [True, False]]
@triton.testing.perf_report(configs)
@@ -793,11 +682,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 1.3
fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)

View File

@@ -1638,25 +1638,6 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
}
// -----
#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
@@ -1721,6 +1702,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
}
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_s8_to_bf16_conversion
@@ -1735,6 +1717,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
}
// -----
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1]}>
#dot = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -1746,7 +1729,31 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.inline_asm
// CHECK-NOT: llvm.inline_asm
%out = arith.sitofp %in : tensor<16x16xi8, #mma> to tensor<16x16xbf16, #mma>
tt.return
}
}
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: atomic_add_f16
tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: f16 {tt.difisibility = 16 : i32}) {
%c1_i1 = arith.constant 1 : i1
%1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x32xi32, #blocked>
%3 = tt.broadcast %2 : (tensor<1x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
%4 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<32x32x!tt.ptr<f16>, #blocked>
%5 = tt.addptr %4, %3 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
%6 = tt.splat %arg1 : (f16) -> tensor<32x32xf16, #blocked>
%7 = tt.splat %c1_i1 : (i1) -> tensor<32x32xi1, #blocked>
// PTX: llvm.inline_asm
// PTX-SAME: @$3 atom.global.gpu.add.noftz.f16x2
// GCN-COUNT-8: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f16, 1>, f16
%8 = "tt.atomic_rmw"(%5, %6, %7) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
tt.return
}
}
// -----

View File

@@ -1,27 +0,0 @@
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx906 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx908 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx90a --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
// == LLVM IR check begin ==
// CHECK-LABEL: {{^}}test_float16_load:
// CHECK: global_load_dword
// CHECK: global_load_dword
// CHECK: global_store_dword
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
tt.func public @test_float16_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked>
%1 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x!tt.ptr<f16>, #blocked>
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf16, #blocked>
%4 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<128x!tt.ptr<f16>, #blocked>
%5 = tt.addptr %4, %0 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked>
%6 = tt.load %5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf16, #blocked>
%7 = arith.addf %3, %6 : tensor<128xf16, #blocked>
%8 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x!tt.ptr<f16>, #blocked>
%9 = tt.addptr %8, %0 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked>
tt.store %9, %7 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf16, #blocked>
tt.return
}
}

View File

@@ -1,28 +0,0 @@
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx906 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx908 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx90a --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
// == LLVM IR check begin ==
// CHECK-LABEL: {{^}}test_int16_load:
// CHECK: global_load_dword
// CHECK: global_load_dword
// CHECK: global_store_dword
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
tt.func public @test_int16_load(%arg0: !tt.ptr<i16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i16> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked>
%1 = tt.splat %arg1 : (!tt.ptr<i16>) -> tensor<128x!tt.ptr<i16>, #blocked>
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<i16>, #blocked>, tensor<128xi32, #blocked>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xi16, #blocked>
%4 = tt.splat %arg2 : (!tt.ptr<i16>) -> tensor<128x!tt.ptr<i16>, #blocked>
%5 = tt.addptr %4, %0 : tensor<128x!tt.ptr<i16>, #blocked>, tensor<128xi32, #blocked>
%6 = tt.load %5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xi16, #blocked>
%7 = arith.addi %3, %6 : tensor<128xi16, #blocked>
%8 = tt.splat %arg0 : (!tt.ptr<i16>) -> tensor<128x!tt.ptr<i16>, #blocked>
%9 = tt.addptr %8, %0 : tensor<128x!tt.ptr<i16>, #blocked>, tensor<128xi32, #blocked>
tt.store %9, %7 {cache = 1 : i32, evict = 1 : i32} : tensor<128xi16, #blocked>
tt.return
}
}

View File

@@ -1,14 +0,0 @@
// RUN: %PYTHON -m triton.tools.aot %s --target=amdgcn --gfx=gfx906 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
// == LLVM IR check begin ==
// CHECK-LABEL: {{^}}test_empty_kernel:
// CHECK-NEXT: s_endpgm
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
tt.return
}
}

View File

@@ -1,43 +0,0 @@
// RUN: export ROCM_PATH=/opt/rocm
// RUN: HSACO_PATH=$(%PYTHON -m triton.tools.aot %s --target=amdgcn --gfx=gfx906 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | head -n 1)
// RUN: llvm-readobj -a "${HSACO_PATH}" | FileCheck %s
// CHECK: Format: elf64-amdgpu
// TODO: Arch: unknown
// CHECK: AddressSize: 64bit
// CHECK: ElfHeader {
// CHECK-NEXT: Ident {
// CHECK-NEXT: Magic: (7F 45 4C 46)
// CHECK-NEXT: Class: 64-bit (0x2)
// CHECK-NEXT: DataEncoding: LittleEndian (0x1)
// CHECK-NEXT: FileVersion: 1
// CHECK-NEXT: OS/ABI: AMDGPU_HSA (0x40)
// CHECK-NEXT: ABIVersion: 2
// CHECK-NEXT: Unused: (00 00 00 00 00 00 00)
// CHECK-NEXT: }
// CHECK-NEXT: Type: SharedObject (0x3)
// CHECK-NEXT: Machine: EM_AMDGPU (0xE0)
// CHECK-NEXT: Version: 1
// CHECK: Name: test_empty_kernel
// CHECK: Size: 4
// CHECK: Binding: Global
// CHECK: Type: Function
// CHECK: Section: .text
// CHECK: Type: NT_AMDGPU_METADATA (AMDGPU Metadata)
// CHECK: .group_segment_fixed_size: 0
// CHECK-NEXT: .kernarg_segment_align: 8
// CHECK-NEXT: .kernarg_segment_size: 16
// CHECK-NEXT: .max_flat_workgroup_size: 256
// CHECK-NEXT: .name: test_empty_kernel
// CHECK-NEXT: .private_segment_fixed_size: 0
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
tt.return
}
}

View File

@@ -1,16 +0,0 @@
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --gfx=90a | FileCheck %s
// == LLVM IR check begin ==
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
// CHECK: define amdgpu_kernel void @test_empty_kernel
// XHECK: !nvvm.annotations
// XHECK: !{ptr @test_empty_kernel, !"maxntidx", i32 128}
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
tt.return
}
}

View File

@@ -1,22 +0,0 @@
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --gfx=90a | FileCheck %s
// == LLVM IR check begin ==
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
// CHECK: define void @test_func
// CHECK: define amdgpu_kernel void @test_kernel
// CHECK: tail call void @test_func
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @test_func(%lb : index, %A : !tt.ptr<f16>) attributes { noinline = true } {
%0 = arith.constant 1.0 : f16
tt.store %A, %0 : f16
tt.return
}
tt.func @test_kernel(%lb : index, %A : !tt.ptr<f16>) {
tt.call @test_func(%lb, %A) : (index, !tt.ptr<f16>) -> ()
tt.return
}
}