mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Fix merge conflicts
This commit is contained in:
@@ -8,15 +8,11 @@ namespace triton {
|
||||
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
|
||||
<<<<<<< HEAD
|
||||
std::unique_ptr<Pass> createRewriteTensorPointerPass(int computeCapability = 80,
|
||||
bool isROCM = false);
|
||||
=======
|
||||
std::unique_ptr<Pass> createReorderBroadcastPass();
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createRewriteTensorPointerPass(int computeCapability = 80);
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
createRewriteTensorPointerPass(int computeCapability = 80,
|
||||
bool isROCM = false);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
|
||||
@@ -31,13 +31,10 @@ SmallVector<unsigned> getElemsPerThread(Type type);
|
||||
// getThreadsPerWarpWithUniqueData.
|
||||
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
|
||||
|
||||
<<<<<<< HEAD
|
||||
unsigned getWarpSize(Attribute layout);
|
||||
|
||||
=======
|
||||
// Returns the number of warps per CTA that may have access to replicated
|
||||
// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData.
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getSizePerThread(Attribute layout);
|
||||
|
||||
@@ -77,8 +77,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
<<<<<<< HEAD
|
||||
"Type":$eltTy), [{
|
||||
"unsigned":$typeWidthInBit), [{
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// ---- begin GFX908/GFX90A ----
|
||||
@@ -93,9 +92,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
const int SIMDWidth = 16;
|
||||
|
||||
// number of inner dimension rows per one pattern repeat
|
||||
int typeBitWidth = eltTy.getIntOrFloatBitWidth();
|
||||
int innerDimLength = shape[order[0]];
|
||||
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeBitWidth;
|
||||
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
|
||||
|
||||
int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
|
||||
// Note: the following settings is customized to avoid
|
||||
@@ -111,7 +109,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
// 4. TODO: what about f64?
|
||||
//
|
||||
// maxPhase is set to SIMDWidth / perPhase
|
||||
int vecSize = (eltTy.isF16() ? 64 : 32 ) / typeBitWidth;
|
||||
int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
|
||||
int maxPhase = SIMDWidth / perPhase;
|
||||
|
||||
return $_get(context, vecSize, perPhase, maxPhase, order);
|
||||
@@ -122,9 +120,6 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
}
|
||||
}
|
||||
#endif
|
||||
=======
|
||||
"unsigned":$typeWidthInBit), [{
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
|
||||
if(!mmaEnc)
|
||||
@@ -154,11 +149,11 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
|
||||
// ---- begin Ampere ----
|
||||
if (mmaEnc.isAmpere()) {
|
||||
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getMMAv2kWidth());
|
||||
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getKWidth());
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getMMAv2kWidth()};
|
||||
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if ((32 / typeWidthInBit != dotOpEnc.getMMAv2kWidth()) && order[0] == inner)
|
||||
if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
|
||||
// --- handle A operand ---
|
||||
|
||||
@@ -183,13 +183,10 @@ private:
|
||||
ReduceOpHelper helper(reduceOp);
|
||||
unsigned bytes = helper.getScratchSizeInBytes();
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
} else if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
|
||||
ScanLoweringHelper helper(scanOp);
|
||||
unsigned bytes = helper.getScratchSizeInBytes();
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
|
||||
auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();
|
||||
|
||||
@@ -94,11 +94,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
|
||||
// that case doesn't need inter-warp communication
|
||||
<<<<<<< HEAD
|
||||
if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
|
||||
=======
|
||||
if (isWarpSynchronous())
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
return {{0, 0}, {0, 0}};
|
||||
|
||||
/// shared memory block0
|
||||
|
||||
@@ -401,153 +401,6 @@ inline SmallVector<Value> packI32(const SmallVector<Value> &inValues,
|
||||
return outValues;
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
struct FpToFpOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
typedef std::function<SmallVector<Value>(
|
||||
Location, ConversionPatternRewriter &, const Value &, const Value &,
|
||||
const Value &, const Value &)>
|
||||
ConvertorT;
|
||||
/* ------------------ */
|
||||
// FP8 -> FP16
|
||||
/* ------------------ */
|
||||
|
||||
static SmallVector<Value>
|
||||
convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const char *ptxAsm, const Value &v0, const Value &v1,
|
||||
const Value &v2, const Value &v3) {
|
||||
auto ctx = rewriter.getContext();
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value fp8x4Vec = undef(fp8x4VecTy);
|
||||
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
|
||||
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
|
||||
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
|
||||
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
|
||||
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
|
||||
|
||||
PTXBuilder builder;
|
||||
auto &ptxOp = *builder.create(ptxAsm);
|
||||
|
||||
auto *o0 = builder.newOperand("=r");
|
||||
auto *o1 = builder.newOperand("=r");
|
||||
auto *i = builder.newOperand(fp8x4Vec, "r");
|
||||
ptxOp({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);
|
||||
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
auto fp16x2x2StructTy =
|
||||
struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
|
||||
auto fp16x2x2Struct =
|
||||
builder.launch(rewriter, loc, fp16x2x2StructTy, false);
|
||||
auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct, 0);
|
||||
auto fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct, 1);
|
||||
return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
|
||||
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
|
||||
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
|
||||
extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
convertFp8E4M3x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
#ifdef USE_ROCM
|
||||
auto ctx = rewriter.getContext();
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value a0 = undef(fp8x4VecTy);
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
|
||||
a0 = bitcast(a0, i32_ty);
|
||||
|
||||
Value a1 = undef(fp8x4VecTy);
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
|
||||
a1 = bitcast(a1, i32_ty);
|
||||
|
||||
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
|
||||
Value b1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
|
||||
|
||||
b0 = lshr(i32_ty, b0, i32_val(1));
|
||||
b1 = lshr(i32_ty, b1, i32_val(1));
|
||||
|
||||
b0 = or_( i32_ty, b0, and_(i32_ty, a0, i32_val(0x80008000)) );
|
||||
b1 = or_( i32_ty, b1, and_(i32_ty, a1, i32_val(0x80008000)) );
|
||||
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
auto fp16x2Vec0 = bitcast(b0, fp16x2VecTy);
|
||||
auto fp16x2Vec1 = bitcast(b1, fp16x2VecTy);
|
||||
|
||||
return { extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
|
||||
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
|
||||
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
|
||||
extract_element(f16_ty, fp16x2Vec1, i32_val(1))
|
||||
};
|
||||
#else
|
||||
auto *ptxAsm = // WARN: subnormal (0bs0000xxx) are not handled
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
|
||||
"prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7060; \n" // a1 = 0xf100f200
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 1; \n" // b0 >>= 1
|
||||
"shr.b32 b1, b1, 1; \n" // shift into fp16 position
|
||||
"add.u32 b0, b0, 0x20002000; \n" // b0.exp += 2**4-2**3
|
||||
// exponent compensate = 8
|
||||
"add.u32 b1, b1, 0x20002000; \n" // b1 += 8<<10 | 8<<10<<16
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
return convertFp8x4ToFp16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
#endif
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
convertFp8E5M2x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
// exponent bias of Fp8E5M2 and Fp16 are the same
|
||||
#ifdef USE_ROCM
|
||||
auto ctx = rewriter.getContext();
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value a0 = undef(fp8x4VecTy);
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
|
||||
a0 = bitcast(a0, i32_ty);
|
||||
|
||||
Value a1 = undef(fp8x4VecTy);
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
|
||||
a1 = bitcast(a1, i32_ty);
|
||||
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
auto fp16x2Vec0 = bitcast(a0, fp16x2VecTy);
|
||||
auto fp16x2Vec1 = bitcast(a1, fp16x2VecTy);
|
||||
|
||||
return { extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
|
||||
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
|
||||
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
|
||||
extract_element(f16_ty, fp16x2Vec1, i32_val(1))
|
||||
};
|
||||
#else
|
||||
auto *ptxAsm = "{ \n"
|
||||
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
|
||||
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
|
||||
"}";
|
||||
return convertFp8x4ToFp16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
#endif
|
||||
}
|
||||
=======
|
||||
typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
|
||||
const Value &, const Value &,
|
||||
const Value &, const Value &)>
|
||||
@@ -555,7 +408,6 @@ typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
|
||||
|
||||
static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
|
||||
Type outType) {
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
ConverterT converter = [ptxAsm, inType, outType](
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
@@ -585,109 +437,6 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
|
||||
for (Value inVal : inPacked)
|
||||
operands.push_back(builder.newOperand(inVal, "r"));
|
||||
auto &ptxOp = *builder.create(ptxAsm);
|
||||
<<<<<<< HEAD
|
||||
|
||||
auto *o0 = builder.newOperand("=r");
|
||||
auto *o1 = builder.newOperand("=r");
|
||||
auto *i = builder.newOperand(fp8x4Vec, "r");
|
||||
ptxOp({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
|
||||
|
||||
auto bf16x2VecTy = vec_ty(i16_ty, 2);
|
||||
auto bf16x2x2StructTy =
|
||||
struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
|
||||
auto bf16x2x2Struct =
|
||||
builder.launch(rewriter, loc, bf16x2x2StructTy, false);
|
||||
auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct, 0);
|
||||
auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct, 1);
|
||||
return {extract_element(i16_ty, bf16x2Vec0, i32_val(0)),
|
||||
extract_element(i16_ty, bf16x2Vec0, i32_val(1)),
|
||||
extract_element(i16_ty, bf16x2Vec1, i32_val(0)),
|
||||
extract_element(i16_ty, bf16x2Vec1, i32_val(1))};
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
convertFp8E4M3x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
#ifdef USE_ROCM
|
||||
auto ctx = rewriter.getContext();
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value a0 = undef(fp8x4VecTy);
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1));
|
||||
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
|
||||
a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3));
|
||||
a0 = bitcast(a0, i32_ty);
|
||||
|
||||
Value a1 = undef(fp8x4VecTy);
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1));
|
||||
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
|
||||
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
|
||||
a1 = bitcast(a1, i32_ty);
|
||||
|
||||
Value sign0 = and_(i32_ty, a0, i32_val(0x80008000));
|
||||
Value sign1 = and_(i32_ty, a1, i32_val(0x80008000));
|
||||
Value nosign0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
|
||||
Value nosign1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
|
||||
|
||||
nosign0 = lshr(i32_ty, nosign0, i32_val(4));
|
||||
nosign1 = lshr(i32_ty, nosign1, i32_val(4));
|
||||
nosign0 = add(i32_ty, nosign0, i32_val(0x38003800));
|
||||
nosign1 = add(i32_ty, nosign1, i32_val(0x38003800));
|
||||
|
||||
auto bf16x2VecTy = vec_ty(i16_ty, 2);
|
||||
Value bf16x2Vec0 = or_(i32_ty, sign0, nosign0);
|
||||
Value bf16x2Vec1 = or_(i32_ty, sign1, nosign1);
|
||||
bf16x2Vec0 = bitcast(bf16x2Vec0, bf16x2VecTy);
|
||||
bf16x2Vec1 = bitcast(bf16x2Vec1, bf16x2VecTy);
|
||||
|
||||
return { extract_element(i16_ty, bf16x2Vec0, i32_val(0)),
|
||||
extract_element(i16_ty, bf16x2Vec0, i32_val(1)),
|
||||
extract_element(i16_ty, bf16x2Vec1, i32_val(0)),
|
||||
extract_element(i16_ty, bf16x2Vec1, i32_val(1))
|
||||
};
|
||||
#else
|
||||
auto *ptxAsm = // WARN: subnormal (0bs0000xxx) are not handled
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
|
||||
"prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7060; \n" // a1 = 0xf100f200
|
||||
"and.b32 b0, a0, 0x7fff7fff; \n" // b0 = a0 & 0x7fff7fff
|
||||
"and.b32 b1, a1, 0x7fff7fff; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 4; \n" // b0 >>= 4
|
||||
"shr.b32 b1, b1, 4; \n" // shift into fp16 position
|
||||
"add.u32 b0, b0, 0x3c003c00; \n" // b0.exp += 2**7-2**3
|
||||
// exponent compensate = 120
|
||||
"add.u32 b1, b1, 0x3c003c00; \n" // b1 += 120<<7 | 120<<7<<16
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
return convertFp8x4ToBf16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
#endif
|
||||
};
|
||||
|
||||
static SmallVector<Value>
|
||||
convertFp8E5M2x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
auto *ptxAsm = // WARN: subnormal (0bs00000xx) are not handled
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
|
||||
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
|
||||
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
|
||||
// exponent compensate = 112
|
||||
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
return convertFp8x4ToBf16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
=======
|
||||
ptxOp(operands, /*onlyAttachMLIRArgs=*/true);
|
||||
auto outVecTy = vec_ty(outType, outVecWidth);
|
||||
SmallVector<Value> outPacked;
|
||||
@@ -705,146 +454,13 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
|
||||
ret.push_back(extract_element(outType, outPacked[i / outVecWidth],
|
||||
i32_val(i % outVecWidth)));
|
||||
return ret;
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
};
|
||||
return converter;
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
/* ------------------ */
|
||||
// FP16 -> FP8
|
||||
/* ------------------ */
|
||||
|
||||
static SmallVector<Value>
|
||||
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const char *ptxAsm, const Value &v0, const Value &v1,
|
||||
const Value &v2, const Value &v3) {
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
Value fp16x2Vec0 = undef(fp16x2VecTy);
|
||||
Value fp16x2Vec1 = undef(fp16x2VecTy);
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
|
||||
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
|
||||
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
|
||||
|
||||
PTXBuilder builder;
|
||||
auto &ptxOp = *builder.create(ptxAsm);
|
||||
|
||||
auto *o = builder.newOperand("=r");
|
||||
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
|
||||
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
|
||||
ptxOp({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
|
||||
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
|
||||
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
|
||||
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
|
||||
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
|
||||
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
convertFp16x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
#ifdef USE_ROCM
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
Value fp16x2Vec0 = undef(fp16x2VecTy);
|
||||
Value fp16x2Vec1 = undef(fp16x2VecTy);
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
|
||||
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
|
||||
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
|
||||
|
||||
Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1));
|
||||
Value a1 = shl(i32_ty, fp16x2Vec1, i32_val(1));
|
||||
a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
|
||||
a1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
|
||||
a0 = add(i32_ty, a0, i32_val(0x00800080));
|
||||
a1 = add(i32_ty, a1, i32_val(0x00800080));
|
||||
Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 );
|
||||
Value b1 = or_( i32_ty, and_(i32_ty, fp16x2Vec1, i32_val(0x80008000)), a1 );
|
||||
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
b0 = bitcast(b0, fp8x4VecTy);
|
||||
b1 = bitcast(b1, fp8x4VecTy);
|
||||
|
||||
return {extract_element(i8_ty, b0, i32_val(1)),
|
||||
extract_element(i8_ty, b0, i32_val(3)),
|
||||
extract_element(i8_ty, b1, i32_val(1)),
|
||||
extract_element(i8_ty, b1, i32_val(3))
|
||||
};
|
||||
|
||||
#else
|
||||
auto *ptxAsm = // WARN: subnormal Fp8s are not handled
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4
|
||||
"sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000
|
||||
// (compensate offset)
|
||||
"sub.u32 a1, $2, 0x20002000; \n" // a1 = input1 - 0x20002000
|
||||
// (8 << 10 | 8 << 10 << 16)
|
||||
"shl.b32 a0, a0, 1; \n" // a0 <<= 1
|
||||
"shl.b32 a1, a1, 1; \n" // shift into fp8e4 position
|
||||
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" // a0 &= 0x7fff7fff
|
||||
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
|
||||
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
|
||||
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" // b0 = a0|(0x80008000&in0)
|
||||
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
|
||||
"prmt.b32 $0, b0, b1, 0x7531; \n" // output = b1b0
|
||||
"}";
|
||||
return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
#endif
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
convertFp16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
#ifdef USE_ROCM
|
||||
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
||||
Value fp16x2Vec0 = undef(fp16x2VecTy);
|
||||
Value fp16x2Vec1 = undef(fp16x2VecTy);
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
|
||||
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
|
||||
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
|
||||
Value b0 = bitcast(fp16x2Vec0, i32_ty);
|
||||
Value b1 = bitcast(fp16x2Vec1, i32_ty);
|
||||
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
b0 = bitcast(b0, fp8x4VecTy);
|
||||
b1 = bitcast(b1, fp8x4VecTy);
|
||||
|
||||
return {extract_element(i8_ty, b0, i32_val(1)),
|
||||
extract_element(i8_ty, b0, i32_val(3)),
|
||||
extract_element(i8_ty, b1, i32_val(1)),
|
||||
extract_element(i8_ty, b1, i32_val(3))
|
||||
};
|
||||
#else
|
||||
auto *ptxAsm =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"and.b32 a0, $1, 0x7fff7fff; \n" // a0 &= 0x7fff7fff
|
||||
"and.b32 a1, $2, 0x7fff7fff; \n" // (strip sign)
|
||||
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
|
||||
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
|
||||
"lop3.b32 a0, $1, 0x80008000, a0, 0xea; \n" // a0 = a0|(0x80008000&in0)
|
||||
"lop3.b32 a1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
|
||||
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
|
||||
"}";
|
||||
return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
#endif
|
||||
}
|
||||
=======
|
||||
class MultipleOperandsRange
|
||||
: public iterator_range<SmallVector<SmallVector<Value>>::iterator> {
|
||||
using ContainerT = SmallVector<SmallVector<Value>>;
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
public:
|
||||
using iterator_range<ContainerT::iterator>::iterator_range;
|
||||
@@ -896,121 +512,6 @@ public:
|
||||
if (allOperands.size() == 0)
|
||||
allOperands.push_back({});
|
||||
|
||||
<<<<<<< HEAD
|
||||
static SmallVector<Value>
|
||||
convertBf16x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
#ifdef USE_ROCM
|
||||
auto bf16x2VecTy = vec_ty(i16_ty, 2);
|
||||
Value bf16x2Vec0 = undef(bf16x2VecTy);
|
||||
Value bf16x2Vec1 = undef(bf16x2VecTy);
|
||||
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
|
||||
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
|
||||
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
|
||||
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
|
||||
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
|
||||
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
|
||||
|
||||
Value sign0 = and_(i32_ty, bf16x2Vec0, i32_val(0x80008000));
|
||||
Value sign1 = and_(i32_ty, bf16x2Vec1, i32_val(0x80008000));
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value sign = undef(fp8x4VecTy);
|
||||
sign0 = bitcast(sign0, fp8x4VecTy);
|
||||
sign1 = bitcast(sign1, fp8x4VecTy);
|
||||
sign = insert_element( fp8x4VecTy, sign, extract_element(i8_ty, sign0, i32_val(1)), i32_val(0) );
|
||||
sign = insert_element( fp8x4VecTy, sign, extract_element(i8_ty, sign0, i32_val(3)), i32_val(1) );
|
||||
sign = insert_element( fp8x4VecTy, sign, extract_element(i8_ty, sign1, i32_val(1)), i32_val(2) );
|
||||
sign = insert_element( fp8x4VecTy, sign, extract_element(i8_ty, sign1, i32_val(3)), i32_val(3) );
|
||||
sign = bitcast(sign, i32_ty);
|
||||
|
||||
Value nosign0 = and_(i32_ty, bf16x2Vec0, i32_val(0x7fff7fff));
|
||||
Value nosign1 = and_(i32_ty, bf16x2Vec1, i32_val(0x7fff7fff));
|
||||
|
||||
Value nosign_0_0 = and_(i32_ty, nosign0, i32_val(0xffff0000));
|
||||
nosign_0_0 = umax(i32_ty, nosign_0_0, i32_val(0x38000000));
|
||||
nosign_0_0 = umin(i32_ty, nosign_0_0, i32_val(0x3ff00000));
|
||||
Value nosign_0_1 = and_(i32_ty, nosign0, i32_val(0x0000ffff));
|
||||
nosign_0_1 = umax(i32_ty, nosign_0_1, i32_val(0x3800));
|
||||
nosign_0_1 = umin(i32_ty, nosign_0_1, i32_val(0x3ff0));
|
||||
nosign0 = or_(i32_ty, nosign_0_0, nosign_0_1);
|
||||
|
||||
Value nosign_1_0 = and_(i32_ty, nosign1, i32_val(0xffff0000));
|
||||
nosign_1_0 = umax(i32_ty, nosign_1_0, i32_val(0x38000000));
|
||||
nosign_1_0 = umin(i32_ty, nosign_1_0, i32_val(0x3ff00000));
|
||||
Value nosign_1_1 = and_(i32_ty, nosign1, i32_val(0x0000ffff));
|
||||
nosign_1_1 = umax(i32_ty, nosign_1_1, i32_val(0x3800));
|
||||
nosign_1_1 = umin(i32_ty, nosign_1_1, i32_val(0x3ff0));
|
||||
nosign1 = or_(i32_ty, nosign_1_0, nosign_1_1);
|
||||
|
||||
nosign0 = add(i32_ty, nosign0, i32_val(0x80008));
|
||||
nosign1 = add(i32_ty, nosign1, i32_val(0x80008));
|
||||
nosign0 = sub(i32_ty, nosign0, i32_val(0x38003800));
|
||||
nosign1 = sub(i32_ty, nosign1, i32_val(0x38003800));
|
||||
nosign0 = lshr(i32_ty, nosign0, i32_val(4));
|
||||
nosign1 = lshr(i32_ty, nosign1, i32_val(4));
|
||||
|
||||
nosign0 = bitcast(nosign0, fp8x4VecTy);
|
||||
nosign1 = bitcast(nosign1, fp8x4VecTy);
|
||||
Value nosign = undef(fp8x4VecTy);
|
||||
nosign = insert_element( fp8x4VecTy, nosign, extract_element(i8_ty, nosign0, i32_val(0)), i32_val(0) );
|
||||
nosign = insert_element( fp8x4VecTy, nosign, extract_element(i8_ty, nosign0, i32_val(2)), i32_val(1) );
|
||||
nosign = insert_element( fp8x4VecTy, nosign, extract_element(i8_ty, nosign1, i32_val(0)), i32_val(2) );
|
||||
nosign = insert_element( fp8x4VecTy, nosign, extract_element(i8_ty, nosign1, i32_val(2)), i32_val(3) );
|
||||
nosign = bitcast(nosign, i32_ty);
|
||||
|
||||
Value fp8x4Vec = or_(i32_ty, nosign, sign);
|
||||
fp8x4Vec = bitcast(fp8x4Vec, fp8x4VecTy);
|
||||
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
|
||||
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
|
||||
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
|
||||
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
|
||||
#else
|
||||
auto *ptxAsm = // bf16 is clamped firstly to fp8 min/max
|
||||
"{ \n" // bf16=fp8>>4 + 120<<7
|
||||
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000
|
||||
".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111
|
||||
"mov.u32 fp8_min, 0x3c003c00; \n" // so bf16_min = 0x3c00
|
||||
"mov.u32 fp8_max, 0x43f043f0; \n" // so bf16_max = 0x43f0
|
||||
"mov.u32 rn_, 0x80008; \n" // round to nearest
|
||||
"and.b32 sign0, $1, 0x80008000; \n" // sign0=in0&0x80008000
|
||||
"and.b32 sign1, $2, 0x80008000; \n" // (store sign)
|
||||
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
|
||||
"and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0=in0&0x7fff7fff
|
||||
"and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign)
|
||||
|
||||
// nosign = clamp(nosign, min, max)
|
||||
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
|
||||
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
|
||||
"max.u32 nosign_0_0, nosign_0_0, 0x3c000000; \n"
|
||||
"min.u32 nosign_0_0, nosign_0_0, 0x43f00000; \n"
|
||||
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
|
||||
"max.u32 nosign_0_1, nosign_0_1, 0x3c00; \n"
|
||||
"min.u32 nosign_0_1, nosign_0_1, 0x43f0; \n"
|
||||
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
|
||||
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
|
||||
"max.u32 nosign_1_0, nosign_1_0, 0x3c000000; \n"
|
||||
"min.u32 nosign_1_0, nosign_1_0, 0x43f00000; \n"
|
||||
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
|
||||
"max.u32 nosign_1_1, nosign_1_1, 0x3c00; \n"
|
||||
"min.u32 nosign_1_1, nosign_1_1, 0x43f0; \n"
|
||||
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
|
||||
|
||||
"add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_
|
||||
"add.u32 nosign1, nosign1, rn_; \n" // (round to nearest)
|
||||
"sub.u32 nosign0, nosign0, 0x3c003c00; \n" // nosign0-=0x3c003c00
|
||||
"sub.u32 nosign1, nosign1, 0x3c003c00; \n" // (compensate offset)
|
||||
"shr.u32 nosign0, nosign0, 4; \n" // nosign0 >>= 4
|
||||
"shr.u32 nosign1, nosign1, 4; \n" // shift into to fp8e4
|
||||
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n" // nosign0 = 0x00f100f2
|
||||
// nosign1 = 0x00f300f4
|
||||
// nosign = 0xf3f4f1f2
|
||||
"or.b32 $0, nosign, sign; \n" // restore sign
|
||||
"}";
|
||||
return convertBf16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
#endif
|
||||
};
|
||||
=======
|
||||
SmallVector<Value> resultVals;
|
||||
for (auto it = allOperands.begin(), end = allOperands.end(); it != end;) {
|
||||
auto curr = static_cast<const ConcreteT *>(this)->createDestOps(
|
||||
@@ -1033,7 +534,6 @@ public:
|
||||
Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -1208,110 +708,20 @@ struct FpToFpOpConversion
|
||||
for (Value &v : outVals)
|
||||
v = convertFp16ToFp32(loc, rewriter, v);
|
||||
// Pack values
|
||||
<<<<<<< HEAD
|
||||
assert(outVals.size() == elems);
|
||||
outVals = reorderValues(outVals, srcTensorType, dstTensorType);
|
||||
outVals =
|
||||
packI32(outVals, dstTensorType, rewriter, loc, getTypeConverter());
|
||||
auto result = getTypeConverter()->packLLElements(loc, outVals, rewriter,
|
||||
dstTensorType);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
return outVals;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OP>
|
||||
Value EmitDualBF16ElementwiseOp(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
ValueRange operands) {
|
||||
auto v0 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
|
||||
auto v1 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[1]);
|
||||
MultipleOperandsRange operands) {
|
||||
auto v0 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]);
|
||||
auto v1 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][1]);
|
||||
auto result = rewriter.create<OP>(loc, f32_ty, v0, v1);
|
||||
return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, result);
|
||||
}
|
||||
|
||||
template <typename SourceOp, typename ConcreteT>
|
||||
class ElementwiseOpConversionBase
|
||||
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
|
||||
explicit ElementwiseOpConversionBase(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto resultTy = op.getType();
|
||||
Location loc = op->getLoc();
|
||||
// element type
|
||||
auto resultElementTy = getElementTypeOrSelf(resultTy);
|
||||
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
|
||||
SmallVector<Value> resultVals;
|
||||
//
|
||||
SmallVector<SmallVector<Value>> allOperands;
|
||||
for (auto operand : adaptor.getOperands()) {
|
||||
auto argTy = op->getOperand(0).getType();
|
||||
auto sub_operands = this->getTypeConverter()->unpackLLElements(
|
||||
loc, operand, rewriter, argTy);
|
||||
sub_operands = unpackI32(sub_operands, argTy, rewriter, loc,
|
||||
this->getTypeConverter());
|
||||
allOperands.resize(sub_operands.size());
|
||||
for (auto v : llvm::enumerate(sub_operands))
|
||||
allOperands[v.index()].push_back(v.value());
|
||||
}
|
||||
if (allOperands.size() == 0)
|
||||
allOperands.push_back({});
|
||||
for (const SmallVector<Value> &operands : allOperands) {
|
||||
Value curr =
|
||||
((ConcreteT *)(this))
|
||||
->createDestOp(op, adaptor, rewriter, elemTy, operands, loc);
|
||||
if (!bool(curr))
|
||||
return failure();
|
||||
resultVals.push_back(curr);
|
||||
}
|
||||
if (op->getNumOperands() > 0) {
|
||||
auto argTy = op->getOperand(0).getType();
|
||||
resultVals = reorderValues(resultVals, argTy, resultTy);
|
||||
}
|
||||
resultVals =
|
||||
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
|
||||
Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp, typename DestOp>
|
||||
struct ElementwiseOpConversion
|
||||
: public ElementwiseOpConversionBase<
|
||||
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<SourceOp,
|
||||
ElementwiseOpConversion<SourceOp, DestOp>>;
|
||||
using Base::Base;
|
||||
using OpAdaptor = typename Base::OpAdaptor;
|
||||
|
||||
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
|
||||
typeConverter, benefit) {}
|
||||
|
||||
// An interface to support variant DestOp builder.
|
||||
DestOp createDestOp(SourceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
return rewriter.create<DestOp>(loc, elemTy, operands,
|
||||
adaptor.getAttributes().getValue());
|
||||
=======
|
||||
return outVals;
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
}
|
||||
};
|
||||
|
||||
struct CmpIOpConversion
|
||||
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
|
||||
CmpIOpConversion> {
|
||||
@@ -1458,20 +868,14 @@ struct FDivOpConversion
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
<<<<<<< HEAD
|
||||
Value createDestOp(mlir::arith::DivFOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
#ifdef USE_ROCM
|
||||
return rewriter.create<LLVM::FDivOp>(loc, elemTy, operands[0],
|
||||
operands[1]);
|
||||
#else
|
||||
=======
|
||||
SmallVector<Value> createDestOps(mlir::arith::DivFOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, MultipleOperandsRange operands,
|
||||
Location loc) const {
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#ifdef USE_ROCM
|
||||
return {rewriter.create<LLVM::FDivOp>(loc, elemTy, operands[0][0],
|
||||
operands[0][1])};
|
||||
#else
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &fdiv = *ptxBuilder.create<PTXInstr>("div");
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
@@ -1491,12 +895,8 @@ struct FDivOpConversion
|
||||
fdiv(res, lhs, rhs);
|
||||
|
||||
Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
|
||||
<<<<<<< HEAD
|
||||
return ret;
|
||||
#endif
|
||||
=======
|
||||
return {ret};
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1515,7 +915,7 @@ struct FMulOpConversion
|
||||
auto rhsElemTy = getElementType(op.getRhs());
|
||||
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
|
||||
#ifdef USE_ROCM
|
||||
return EmitDualBF16ElementwiseOp<LLVM::FMulOp>(loc, rewriter, operands);
|
||||
return {EmitDualBF16ElementwiseOp<LLVM::FMulOp>(loc, rewriter, operands)};
|
||||
#else
|
||||
PTXBuilder builder;
|
||||
auto ptxAsm = " { .reg .b16 c; \n"
|
||||
@@ -1526,12 +926,8 @@ struct FMulOpConversion
|
||||
auto lhs = builder.newOperand(operands[0][0], "h");
|
||||
auto rhs = builder.newOperand(operands[0][1], "h");
|
||||
fMul({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
|
||||
<<<<<<< HEAD
|
||||
return builder.launch(rewriter, loc, i16_ty, false);
|
||||
#endif
|
||||
=======
|
||||
return {builder.launch(rewriter, loc, i16_ty, false)};
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
} else {
|
||||
return {rewriter.create<LLVM::FMulOp>(loc, elemTy, operands[0][0],
|
||||
operands[0][1])};
|
||||
@@ -1554,7 +950,7 @@ struct FAddOpConversion
|
||||
auto rhsElemTy = getElementType(op.getRhs());
|
||||
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
|
||||
#ifdef USE_ROCM
|
||||
return EmitDualBF16ElementwiseOp<LLVM::FAddOp>(loc, rewriter, operands);
|
||||
return {EmitDualBF16ElementwiseOp<LLVM::FAddOp>(loc, rewriter, operands)};
|
||||
#else
|
||||
PTXBuilder builder;
|
||||
auto ptxAsm = "{ .reg .b16 c; \n"
|
||||
@@ -1565,12 +961,8 @@ struct FAddOpConversion
|
||||
auto lhs = builder.newOperand(operands[0][0], "h");
|
||||
auto rhs = builder.newOperand(operands[0][1], "h");
|
||||
fAdd({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
|
||||
<<<<<<< HEAD
|
||||
return builder.launch(rewriter, loc, i16_ty, false);
|
||||
#endif
|
||||
=======
|
||||
return {builder.launch(rewriter, loc, i16_ty, false)};
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
} else {
|
||||
return {rewriter.create<LLVM::FAddOp>(loc, elemTy, operands[0][0],
|
||||
operands[0][1])};
|
||||
@@ -1593,7 +985,7 @@ struct FSubOpConversion
|
||||
auto rhsElemTy = getElementType(op.getRhs());
|
||||
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
|
||||
#ifdef USE_ROCM
|
||||
return EmitDualBF16ElementwiseOp<LLVM::FSubOp>(loc, rewriter, operands);
|
||||
return {EmitDualBF16ElementwiseOp<LLVM::FSubOp>(loc, rewriter, operands)};
|
||||
#else
|
||||
PTXBuilder builder;
|
||||
auto ptxAsm = " { .reg .b16 c; \n"
|
||||
@@ -1604,12 +996,8 @@ struct FSubOpConversion
|
||||
auto lhs = builder.newOperand(operands[0][0], "h");
|
||||
auto rhs = builder.newOperand(operands[0][1], "h");
|
||||
fSub({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
|
||||
<<<<<<< HEAD
|
||||
return builder.launch(rewriter, loc, i16_ty, false);
|
||||
#endif
|
||||
=======
|
||||
return {builder.launch(rewriter, loc, i16_ty, false)};
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
} else {
|
||||
return {rewriter.create<LLVM::FSubOp>(loc, elemTy, operands[0][0],
|
||||
operands[0][1])};
|
||||
@@ -1735,20 +1123,16 @@ struct ExpOpConversionApprox
|
||||
Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e));
|
||||
|
||||
#ifdef USE_ROCM
|
||||
return rewriter.create<math::Exp2Op>(loc, f32_ty, prod,
|
||||
adaptor.getAttributes().getValue());
|
||||
return {rewriter.create<math::Exp2Op>(loc, f32_ty, prod,
|
||||
adaptor.getAttributes().getValue())};
|
||||
#else
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
|
||||
auto output = ptxBuilder.newOperand("=f");
|
||||
auto input = ptxBuilder.newOperand(prod, "f");
|
||||
exp2(output, input);
|
||||
<<<<<<< HEAD
|
||||
return ptxBuilder.launch(rewriter, loc, f32_ty, false);
|
||||
#endif
|
||||
=======
|
||||
return {ptxBuilder.launch(rewriter, loc, f32_ty, false)};
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -357,13 +357,9 @@ private:
|
||||
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
|
||||
|
||||
SmallVector<Value> smemBases(op.getNumOperands());
|
||||
<<<<<<< HEAD
|
||||
if (sizeInterWarps > 1) {
|
||||
=======
|
||||
bool isWarpSync = helper.isWarpSynchronous();
|
||||
|
||||
if (!isWarpSync) {
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
smemBases[0] = bitcast(
|
||||
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
|
||||
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
|
||||
@@ -455,11 +451,7 @@ private:
|
||||
accumulate(rewriter, *combineOp, acc, shfl, false);
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
if (sizeInterWarps == 1) {
|
||||
=======
|
||||
if (isWarpSync) {
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
finalAccs[key] = acc;
|
||||
continue;
|
||||
}
|
||||
@@ -474,11 +466,7 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
if (sizeInterWarps == 1) {
|
||||
=======
|
||||
if (isWarpSync) {
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
@@ -513,12 +501,8 @@ private:
|
||||
|
||||
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
|
||||
unsigned numThreads =
|
||||
<<<<<<< HEAD
|
||||
product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * wavefront_size;
|
||||
=======
|
||||
product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) *
|
||||
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
|
||||
Value readOffset = threadId;
|
||||
for (unsigned round = 0; round < elemsPerThread; ++round) {
|
||||
|
||||
@@ -487,12 +487,8 @@ struct GetNumProgramsOpConversion
|
||||
#else
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
|
||||
<<<<<<< HEAD
|
||||
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
|
||||
#endif // USE_ROCM
|
||||
=======
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif // USE_ROCM
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -618,11 +618,8 @@ private:
|
||||
computeCapability)
|
||||
.contains(byteWidth)) {
|
||||
return;
|
||||
<<<<<<< HEAD
|
||||
#endif
|
||||
=======
|
||||
}
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
|
||||
// load
|
||||
auto tmpTy =
|
||||
|
||||
@@ -471,13 +471,8 @@ public:
|
||||
|
||||
void runOnOperation() override {
|
||||
// Only rewrite if the hardware does not support
|
||||
<<<<<<< HEAD
|
||||
if (!isROCM && computeCapability >= 90)
|
||||
return;
|
||||
=======
|
||||
// if (computeCapability >= 90)
|
||||
// return;
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
// NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because
|
||||
// MLIR does not support one-multiple value mapping. For example, if we use
|
||||
|
||||
@@ -1315,7 +1315,7 @@ struct TritonGPUInferLayoutInterface
|
||||
// Verify that the encodings are valid.
|
||||
if (!aEncoding || !bEncoding)
|
||||
return op->emitError("mismatching encoding between A and B operands");
|
||||
if (aEncoding.getMMAv2kWidth() != bEncoding.getMMAv2kWidth())
|
||||
if (aEncoding.getKWidth() != bEncoding.getKWidth())
|
||||
return op->emitError("mismatching kWidth between A and B operands");
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -431,15 +431,11 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
pushConversionForward(cvt, cvtSlices, sharedInfo, rewriter);
|
||||
=======
|
||||
// Call SimplifyReduceCvt instead of the general push conversion forward
|
||||
if (isa<triton::ReduceOp>(cvtSlices.front()))
|
||||
return failure();
|
||||
|
||||
pushConversionForward(cvt, cvtSlices, rewriter);
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
pushConversionForward(cvt, cvtSlices, sharedInfo, rewriter);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -91,9 +91,8 @@ public:
|
||||
return;
|
||||
op->moveAfter(argOp);
|
||||
});
|
||||
<<<<<<< HEAD
|
||||
// Move `dot` operand so that conversions to opIdx=0 happens before
|
||||
// conversions to opIdx=1
|
||||
// Move `dot` operand so that conversions to opIdx=1 happens after
|
||||
// conversions to opIdx=0
|
||||
#ifdef USE_ROCM
|
||||
// Skip this reordering for ROCm backend since it will sink shared->dot
|
||||
// conversion for Q tensor in flash attention into the main loop. This
|
||||
@@ -101,10 +100,6 @@ public:
|
||||
// iteration.
|
||||
return;
|
||||
#endif
|
||||
=======
|
||||
// Move `dot` operand so that conversions to opIdx=1 happens after
|
||||
// conversions to opIdx=0
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
m.walk([&](triton::gpu::ConvertLayoutOp op) {
|
||||
auto dstType = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto dstEncoding =
|
||||
|
||||
@@ -16,11 +16,8 @@
|
||||
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
<<<<<<< HEAD
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
=======
|
||||
#include "triton/Target/LLVMIR/Passes.h"
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include "llvm/IR/CallingConv.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
@@ -363,15 +360,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
// Simplify the IR
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
pm.addPass(mlir::createSymbolDCEPass());
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
pm.addPass(mlir::createConvertSCFToCFPass());
|
||||
pm.addPass(createConvertControlFlowToLLVMPass());
|
||||
#endif
|
||||
=======
|
||||
if (!::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"))
|
||||
pm.addPass(mlir::createLLVMDIScopePass());
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
if (failed(pm.run(module))) {
|
||||
llvm::errs() << "Pass execution failed";
|
||||
|
||||
@@ -1052,126 +1052,108 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_shl",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
<<<<<<< HEAD
|
||||
auto loc = self.getUnknownLoc();
|
||||
#ifdef USE_ROCM
|
||||
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
|
||||
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
|
||||
auto constValue = self.create<mlir::arith::ConstantIntOp>(
|
||||
loc, typeWidth, elementType);
|
||||
typeWidth, elementType);
|
||||
auto zeroConst =
|
||||
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
|
||||
self.create<mlir::arith::ConstantIntOp>(0, elementType);
|
||||
if (lhs.getType().isIntOrIndex()) {
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShLIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShLIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, zeroConst));
|
||||
cmpValue, shiftValue, zeroConst));
|
||||
} else {
|
||||
auto splatValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), constValue);
|
||||
lhs.getType(), constValue);
|
||||
auto zeroValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), zeroConst);
|
||||
lhs.getType(), zeroConst);
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShLIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShLIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, zeroValue));
|
||||
cmpValue, shiftValue, zeroValue));
|
||||
}
|
||||
#else
|
||||
return mlir::Value(
|
||||
self.create<mlir::arith::ShLIOp>(loc, lhs, rhs));
|
||||
#endif
|
||||
=======
|
||||
return mlir::Value(self.create<mlir::arith::ShLIOp>(lhs, rhs));
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
})
|
||||
.def("create_lshr",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
<<<<<<< HEAD
|
||||
auto loc = self.getUnknownLoc();
|
||||
#ifdef USE_ROCM
|
||||
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
|
||||
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
|
||||
auto constValue = self.create<mlir::arith::ConstantIntOp>(
|
||||
loc, typeWidth, elementType);
|
||||
typeWidth, elementType);
|
||||
auto zeroConst =
|
||||
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
|
||||
self.create<mlir::arith::ConstantIntOp>(0, elementType);
|
||||
if (lhs.getType().isIntOrIndex()) {
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRUIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, zeroConst));
|
||||
cmpValue, shiftValue, zeroConst));
|
||||
} else {
|
||||
auto splatValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), constValue);
|
||||
lhs.getType(), constValue);
|
||||
auto zeroValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), zeroConst);
|
||||
lhs.getType(), zeroConst);
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRUIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, zeroValue));
|
||||
cmpValue, shiftValue, zeroValue));
|
||||
}
|
||||
#else
|
||||
return mlir::Value(
|
||||
self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs));
|
||||
#endif
|
||||
=======
|
||||
return mlir::Value(self.create<mlir::arith::ShRUIOp>(lhs, rhs));
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
})
|
||||
.def("create_ashr",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
<<<<<<< HEAD
|
||||
auto loc = self.getUnknownLoc();
|
||||
#ifdef USE_ROCM
|
||||
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
|
||||
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
|
||||
auto constValue = self.create<mlir::arith::ConstantIntOp>(
|
||||
loc, typeWidth, elementType);
|
||||
typeWidth, elementType);
|
||||
auto zeroConst =
|
||||
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
|
||||
self.create<mlir::arith::ConstantIntOp>(0, elementType);
|
||||
uint64_t ones_val = 0xFFFFFFFFFFFFFFFF;
|
||||
auto onesConst =
|
||||
self.create<mlir::arith::ConstantIntOp>(loc, ones_val, elementType);
|
||||
self.create<mlir::arith::ConstantIntOp>(ones_val, elementType);
|
||||
if (lhs.getType().isIntOrIndex()) {
|
||||
auto negativeCmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::slt, lhs, zeroConst);
|
||||
mlir::arith::CmpIPredicate::slt, lhs, zeroConst);
|
||||
auto otherValue = mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, negativeCmpValue, onesConst, zeroConst));
|
||||
negativeCmpValue, onesConst, zeroConst));
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRSIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, otherValue));
|
||||
cmpValue, shiftValue, otherValue));
|
||||
} else {
|
||||
auto splatValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), constValue);
|
||||
lhs.getType(), constValue);
|
||||
auto zeroValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), zeroConst);
|
||||
lhs.getType(), zeroConst);
|
||||
auto onesValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), onesConst);
|
||||
lhs.getType(), onesConst);
|
||||
auto negativeCmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::slt, lhs, zeroValue);
|
||||
mlir::arith::CmpIPredicate::slt, lhs, zeroValue);
|
||||
auto otherValue = mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, negativeCmpValue, onesValue, zeroValue));
|
||||
negativeCmpValue, onesValue, zeroValue));
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRSIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, otherValue));
|
||||
cmpValue, shiftValue, otherValue));
|
||||
}
|
||||
#else
|
||||
return mlir::Value(
|
||||
self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs));
|
||||
#endif
|
||||
=======
|
||||
return mlir::Value(self.create<mlir::arith::ShRSIOp>(lhs, rhs));
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
})
|
||||
// AddPtr (similar to GEP)
|
||||
.def("create_addptr",
|
||||
|
||||
@@ -14,14 +14,9 @@ from typing import Any, Tuple
|
||||
from .._C.libtriton.triton import (add_external_libs, compile_ptx_to_cubin,
|
||||
get_shared_memory_size, ir,
|
||||
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
|
||||
<<<<<<< HEAD
|
||||
translate_triton_gpu_to_llvmir, get_arch_info,
|
||||
get_warp_size)
|
||||
from ..common.backend import get_backend
|
||||
=======
|
||||
translate_triton_gpu_to_llvmir)
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
# from ..runtime import driver, jit, JITFunction
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
|
||||
@@ -294,15 +294,12 @@ class _attention(torch.autograd.Function):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
|
||||
<<<<<<< HEAD
|
||||
if torch.version.hip is not None:
|
||||
BLOCK = 64
|
||||
BLOCK_M = 64
|
||||
BLOCK_N = 64
|
||||
else:
|
||||
BLOCK = 128
|
||||
=======
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
|
||||
@@ -25,11 +25,7 @@ class OutOfResources(Exception):
|
||||
|
||||
|
||||
class Autotuner(KernelInterface):
|
||||
<<<<<<< HEAD
|
||||
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None):
|
||||
=======
|
||||
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
|
||||
'''
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
@@ -62,12 +58,9 @@ class Autotuner(KernelInterface):
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.early_config_prune = early_config_prune
|
||||
self.fn = fn
|
||||
<<<<<<< HEAD
|
||||
self.verbose = verbose
|
||||
=======
|
||||
self.warmup = warmup
|
||||
self.rep = rep
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
self.verbose = verbose
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
@@ -187,11 +180,7 @@ class Config:
|
||||
return ', '.join(res)
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False):
|
||||
=======
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25, rep=100):
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
@@ -222,21 +211,15 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
|
||||
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
<<<<<<< HEAD
|
||||
:param verbose: a boolean that controls whether the best_config for each key is printed
|
||||
:type verbose: bool
|
||||
"""
|
||||
def decorator(fn):
|
||||
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by)
|
||||
=======
|
||||
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
|
||||
:type warmup: int
|
||||
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
|
||||
:type rep: int
|
||||
:param verbose: a boolean that controls whether the best_config for each key is printed
|
||||
:type verbose: bool
|
||||
"""
|
||||
def decorator(fn):
|
||||
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, warmup, rep)
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -325,12 +325,8 @@ class JITFunction(KernelInterface[T]):
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
|
||||
src = f"""
|
||||
<<<<<<< HEAD
|
||||
|
||||
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
=======
|
||||
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
from ..compiler import compile, CompiledKernel
|
||||
sig_key = {sig_keys},
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from .._C.libtriton.triton import ir
|
||||
# import triton.compiler.compiler as tc
|
||||
from ..compiler.compiler import (get_amdgpu_arch_fulldetails, llir_to_amdgcn_and_hsaco,
|
||||
llir_to_ptx, optimize_ttgir, optimize_ttir,
|
||||
ttgir_to_llir, ttir_to_ttgir, CUDA_DEFAULT_WARP_SIZE)
|
||||
|
||||
if __name__ == '__main__':
    # Command-line driver: parse an MLIR module from `src`, lower it stage by
    # stage and print the representation selected by --target.

    # valid source and target formats
    VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx', 'amdgcn']

    # set up the argument parser
    # TODO: conditional requirements
    parser = argparse.ArgumentParser()
    parser.add_argument('src', help="Source file to compile")
    parser.add_argument('--target', required=True,
                        help="Target format, one of: " + ', '.join(VALID_FORMATS))
    parser.add_argument('--sm', type=int, help="Compute capability to compile for")
    parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")
    parser.add_argument('--gfx', type=str, help="AMDGPU target to compile for")
    parser.add_argument('--triple', type=str, help="target triple, for example: amdgcn-amd-amdhsa")
    parser.add_argument('--features', type=str, help="target features, for example: +sramecc,-xnack")
    parser.add_argument('--num_warps', type=int, help="number of warps to compile ttgir for")

    # parse the args
    args = parser.parse_args()

    # TODO: clean-up and re-use triton.compiler primitive functions
    # check for validity of format arguments
    if args.target not in VALID_FORMATS:
        print("Invalid target format: " + args.target)
        # An invalid request is a failure: report it with a non-zero status.
        sys.exit(1)

    # parse source file to MLIR module
    context = ir.context()
    module = ir.parse_mlir_module(args.src, context)
    module.context = context

    # optimize triton-ir
    module = optimize_ttir(module, arch=args.sm)
    if args.target == 'triton-ir':
        print(module.str())
        sys.exit(0)

    if not args.num_warps:
        args.num_warps = 4

    # triton-ir -> amdgcn: this branch runs the whole pipeline and exits
    if args.target == 'amdgcn':
        # auto detect available architecture and features
        # if nothing detected, set with default values
        arch_details = get_amdgpu_arch_fulldetails()
        if not arch_details:
            arch_name = ""
            arch_triple = "amdgcn-amd-amdhsa"
            arch_features = ""
            arch_warpsize = 64
        else:
            arch_triple, arch_name, arch_features, arch_warpsize = arch_details

        # stop processing if architecture name is not automatically detected and is not set manually
        if not args.gfx and not arch_name:
            raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")

        # rewrite default and automatically detected values with manually provided data
        if args.gfx:
            arch_name = args.gfx
        if args.triple:
            arch_triple = args.triple
        if args.features:
            arch_features = args.features

        # triton-ir -> triton-gpu-ir
        # use compute_capability == 80
        module = ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=arch_warpsize)  # num_stages=3, compute_capability=80)
        module = optimize_ttgir(module, num_stages=3, arch=args.gfx)
        # triton-gpu-ir -> llvm-ir
        # use compute_capability == 80
        module = ttgir_to_llir(module, extern_libs=None, arch=args.gfx)
        # llvm-ir -> amdgcn asm, hsaco binary
        module, hsaco_path = llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)

        print(hsaco_path)
        print(module)
        sys.exit(0)

    # set arch depending on platform
    if args.gfx:
        arch = args.gfx
    elif args.sm:
        arch = args.sm
    else:
        raise argparse.ArgumentError(None, "Must specify --sm or --gfx for ttgir compilation")

    # triton-ir -> triton-gpu-ir
    module = ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=CUDA_DEFAULT_WARP_SIZE)
    module = optimize_ttgir(module, num_stages=3, arch=arch)
    if args.target == 'triton-gpu-ir':
        print(module.str())
        sys.exit(0)

    # triton-gpu-ir -> llvm-ir
    module = ttgir_to_llir(module, extern_libs=None, arch=arch)
    if args.target == 'llvm-ir':
        print(module)
        sys.exit(0)

    # llvm-ir -> ptx
    if args.target == 'ptx':
        if not args.sm:
            raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
        if not args.ptx_version:
            raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
        module = llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version)

    # The 'amdgcn' target has already been handled (and exited) above, so no
    # second amdgcn step is needed here.
    print(module)
|
||||
@@ -39,16 +39,10 @@ def _fwd_kernel(
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
<<<<<<< HEAD
|
||||
q_offset = off_hz * stride_qh
|
||||
kv_offset = off_hz * stride_kh
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + q_offset,
|
||||
=======
|
||||
qvk_offset = off_hz * stride_qh
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + qvk_offset,
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
@@ -56,26 +50,16 @@ def _fwd_kernel(
|
||||
order=(1, 0)
|
||||
)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
<<<<<<< HEAD
|
||||
base=K + kv_offset,
|
||||
shape=(BLOCK_DMODEL, N_CTX + P_SEQ),
|
||||
=======
|
||||
base=K + qvk_offset,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
<<<<<<< HEAD
|
||||
base=V + kv_offset,
|
||||
shape=(N_CTX + P_SEQ, BLOCK_DMODEL),
|
||||
=======
|
||||
base=V + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
@@ -97,11 +81,7 @@ def _fwd_kernel(
|
||||
q = (q * qk_scale).to(tl.float16)
|
||||
# loop over k, v and update accumulator
|
||||
lo = 0
|
||||
<<<<<<< HEAD
|
||||
hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ
|
||||
=======
|
||||
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
# -- load k, v --
|
||||
k = tl.load(K_block_ptr)
|
||||
@@ -109,11 +89,7 @@ def _fwd_kernel(
|
||||
# -- compute qk ---
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
if IS_CAUSAL:
|
||||
<<<<<<< HEAD
|
||||
qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
=======
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
qk += tl.dot(q, k)
|
||||
# -- compute scaling constant ---
|
||||
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
|
||||
@@ -135,11 +111,7 @@ def _fwd_kernel(
|
||||
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
# write back O
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
<<<<<<< HEAD
|
||||
base=Out + q_offset,
|
||||
=======
|
||||
base=Out + qvk_offset,
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
@@ -152,11 +124,7 @@ def _fwd_kernel(
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO,
|
||||
<<<<<<< HEAD
|
||||
NewDO, Delta,
|
||||
=======
|
||||
Delta,
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
@@ -164,12 +132,10 @@ def _bwd_preprocess(
|
||||
# load
|
||||
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
# compute
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
|
||||
tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
@@ -233,22 +199,13 @@ def _bwd_kernel(
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
<<<<<<< HEAD
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, tl.trans(k))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
if CAUSAL:
|
||||
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
else:
|
||||
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
|
||||
l_i = tl.load(l_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk * qk_scale - l_i[:, None])
|
||||
=======
|
||||
if CAUSAL:
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
|
||||
else:
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
qk *= qk_scale
|
||||
l_i = tl.load(l_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk - l_i[:, None])
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
|
||||
@@ -492,18 +449,13 @@ empty = torch.empty(128, device="cuda")
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
<<<<<<< HEAD
|
||||
def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
|
||||
=======
|
||||
def forward(ctx, q, k, v, causal, sm_scale):
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
BLOCK_M = 128
|
||||
<<<<<<< HEAD
|
||||
if torch.version.hip is None:
|
||||
BLOCK_N = 64 if Lk <= 64 else 32
|
||||
num_stages = 4 if Lk <= 64 else 3
|
||||
@@ -514,11 +466,6 @@ class _attention(torch.autograd.Function):
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
|
||||
=======
|
||||
BLOCK_N = 64
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
_fwd_kernel[grid](
|
||||
@@ -529,48 +476,36 @@ class _attention(torch.autograd.Function):
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
<<<<<<< HEAD
|
||||
q.shape[0], q.shape[1], q.shape[2], P_SEQ,
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages)
|
||||
=======
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=4)
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L)
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = Lk
|
||||
ctx.causal = causal
|
||||
<<<<<<< HEAD
|
||||
ctx.split_kernel = split_kernel
|
||||
ctx.P_SEQ = P_SEQ
|
||||
=======
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
<<<<<<< HEAD
|
||||
BLOCK = 64
|
||||
q, k, v, o, l = ctx.saved_tensors
|
||||
=======
|
||||
BLOCK = 128
|
||||
# configuration is not supported
|
||||
assert(not (ctx.split_kernel and not ctx.causal))
|
||||
if torch.version.hip is not None:
|
||||
BLOCK = 64
|
||||
else:
|
||||
BLOCK = 128
|
||||
q, k, v, o, L = ctx.saved_tensors
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
do = do.contiguous()
|
||||
dq = torch.zeros_like(q)
|
||||
dq = torch.zeros_like(q, dtype=torch.float32)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
<<<<<<< HEAD
|
||||
delta = torch.empty_like(L)
|
||||
do_scaled = torch.empty_like(do)
|
||||
delta = torch.empty_like(l)
|
||||
# Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
|
||||
# If the two are the same, we don't need this but the bwd pass block size
|
||||
# is smaller than the fwd so we need this scaling to ensure we loop over all
|
||||
@@ -588,8 +523,7 @@ class _attention(torch.autograd.Function):
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l,
|
||||
delta,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
@@ -597,15 +531,16 @@ class _attention(torch.autograd.Function):
|
||||
block_scale * ctx.grid[0],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
|
||||
CAUSAL=ctx.causal,
|
||||
num_stages=1,
|
||||
)
|
||||
else :
|
||||
dq = torch.zeros_like(q)
|
||||
_bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dk, dv,
|
||||
l,
|
||||
delta,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
@@ -618,8 +553,7 @@ class _attention(torch.autograd.Function):
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq,
|
||||
l,
|
||||
delta,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
@@ -630,36 +564,10 @@ class _attention(torch.autograd.Function):
|
||||
)
|
||||
# print(h.asm["ttgir"])
|
||||
return dq, dk, dv, None, None, None
|
||||
=======
|
||||
delta = torch.empty_like(L)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do,
|
||||
delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do,
|
||||
dq, dk, dv,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
ctx.grid[0],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
CAUSAL=ctx.causal,
|
||||
num_stages=1,
|
||||
)
|
||||
return dq, dk, dv, None, None
|
||||
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
|
||||
[(4, 48, 1024, 64, 128),
|
||||
(4, 48, 2048, 64, 128),
|
||||
@@ -702,16 +610,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
|
||||
v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
sm_scale = q.shape[-1] ** (-0.5)
|
||||
split_kernel = True
|
||||
=======
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)])
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
sm_scale = 0.5
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
|
||||
@@ -724,13 +622,8 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
<<<<<<< HEAD
|
||||
# # triton implementation
|
||||
tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
|
||||
=======
|
||||
# triton implementation
|
||||
tri_out = attention(q, k, v, causal, sm_scale).half()
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
@@ -771,11 +664,7 @@ configs = [triton.testing.Benchmark(
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal}
|
||||
<<<<<<< HEAD
|
||||
) for mode in ['fwd', 'bwd'] for causal in [True, False]]
|
||||
=======
|
||||
) for mode in ['fwd', 'bwd'] for causal in [False, True]]
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
@@ -793,11 +682,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
sm_scale = 1.3
|
||||
<<<<<<< HEAD
|
||||
fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
|
||||
=======
|
||||
fn = lambda: attention(q, k, v, causal, sm_scale)
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
|
||||
@@ -1638,25 +1638,6 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
}
|
||||
|
||||
// -----
|
||||
<<<<<<< HEAD
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f16
|
||||
tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: f16 {tt.difisibility = 16 : i32}) {
|
||||
%c1_i1 = arith.constant 1 : i1
|
||||
%1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x32xi32, #blocked>
|
||||
%3 = tt.broadcast %2 : (tensor<1x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<32x32x!tt.ptr<f16>, #blocked>
|
||||
%5 = tt.addptr %4, %3 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
|
||||
%6 = tt.splat %arg1 : (f16) -> tensor<32x32xf16, #blocked>
|
||||
%7 = tt.splat %c1_i1 : (i1) -> tensor<32x32xi1, #blocked>
|
||||
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.noftz.f16x2
|
||||
// GCN-COUNT-8: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f16, 1>, f16
|
||||
%8 = "tt.atomic_rmw"(%5, %6, %7) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
|
||||
=======
|
||||
|
||||
#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
@@ -1721,6 +1702,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: test_s8_to_bf16_conversion
|
||||
@@ -1735,6 +1717,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1]}>
|
||||
#dot = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
@@ -1746,7 +1729,31 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-NOT: llvm.inline_asm
|
||||
%out = arith.sitofp %in : tensor<16x16xi8, #mma> to tensor<16x16xbf16, #mma>
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f16
|
||||
tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: f16 {tt.difisibility = 16 : i32}) {
|
||||
%c1_i1 = arith.constant 1 : i1
|
||||
%1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x32xi32, #blocked>
|
||||
%3 = tt.broadcast %2 : (tensor<1x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<32x32x!tt.ptr<f16>, #blocked>
|
||||
%5 = tt.addptr %4, %3 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
|
||||
%6 = tt.splat %arg1 : (f16) -> tensor<32x32xf16, #blocked>
|
||||
%7 = tt.splat %c1_i1 : (i1) -> tensor<32x32xi1, #blocked>
|
||||
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.noftz.f16x2
|
||||
// GCN-COUNT-8: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f16, 1>, f16
|
||||
%8 = "tt.atomic_rmw"(%5, %6, %7) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx906 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx908 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx90a --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
|
||||
|
||||
// == LLVM IR check begin ==
|
||||
// CHECK-LABEL: {{^}}test_float16_load:
|
||||
// CHECK: global_load_dword
|
||||
// CHECK: global_load_dword
|
||||
// CHECK: global_store_dword
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
tt.func public @test_float16_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked>
|
||||
%1 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x!tt.ptr<f16>, #blocked>
|
||||
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf16, #blocked>
|
||||
%4 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<128x!tt.ptr<f16>, #blocked>
|
||||
%5 = tt.addptr %4, %0 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked>
|
||||
%6 = tt.load %5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf16, #blocked>
|
||||
%7 = arith.addf %3, %6 : tensor<128xf16, #blocked>
|
||||
%8 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x!tt.ptr<f16>, #blocked>
|
||||
%9 = tt.addptr %8, %0 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked>
|
||||
tt.store %9, %7 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf16, #blocked>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx906 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx908 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --num_warps=1 --target=amdgcn --gfx=gfx90a --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
|
||||
|
||||
// == LLVM IR check begin ==
|
||||
// CHECK-LABEL: {{^}}test_int16_load:
|
||||
// CHECK: global_load_dword
|
||||
// CHECK: global_load_dword
|
||||
// CHECK: global_store_dword
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
tt.func public @test_int16_load(%arg0: !tt.ptr<i16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i16> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked>
|
||||
%1 = tt.splat %arg1 : (!tt.ptr<i16>) -> tensor<128x!tt.ptr<i16>, #blocked>
|
||||
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<i16>, #blocked>, tensor<128xi32, #blocked>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xi16, #blocked>
|
||||
%4 = tt.splat %arg2 : (!tt.ptr<i16>) -> tensor<128x!tt.ptr<i16>, #blocked>
|
||||
%5 = tt.addptr %4, %0 : tensor<128x!tt.ptr<i16>, #blocked>, tensor<128xi32, #blocked>
|
||||
%6 = tt.load %5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xi16, #blocked>
|
||||
%7 = arith.addi %3, %6 : tensor<128xi16, #blocked>
|
||||
%8 = tt.splat %arg0 : (!tt.ptr<i16>) -> tensor<128x!tt.ptr<i16>, #blocked>
|
||||
%9 = tt.addptr %8, %0 : tensor<128x!tt.ptr<i16>, #blocked>, tensor<128xi32, #blocked>
|
||||
tt.store %9, %7 {cache = 1 : i32, evict = 1 : i32} : tensor<128xi16, #blocked>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --target=amdgcn --gfx=gfx906 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | FileCheck %s
|
||||
|
||||
// == LLVM IR check begin ==
|
||||
// CHECK-LABEL: {{^}}test_empty_kernel:
|
||||
// CHECK-NEXT: s_endpgm
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
// RUN: export ROCM_PATH=/opt/rocm
|
||||
// RUN: HSACO_PATH=$(%PYTHON -m triton.tools.aot %s --target=amdgcn --gfx=gfx906 --triple=amdgcn-amd-amdhsa --features="+sramecc,-xnack" | head -n 1)
|
||||
// RUN: llvm-readobj -a "${HSACO_PATH}" | FileCheck %s
|
||||
|
||||
// CHECK: Format: elf64-amdgpu
|
||||
// TODO: Arch: unknown
|
||||
// CHECK: AddressSize: 64bit
|
||||
// CHECK: ElfHeader {
|
||||
// CHECK-NEXT: Ident {
|
||||
// CHECK-NEXT: Magic: (7F 45 4C 46)
|
||||
// CHECK-NEXT: Class: 64-bit (0x2)
|
||||
// CHECK-NEXT: DataEncoding: LittleEndian (0x1)
|
||||
// CHECK-NEXT: FileVersion: 1
|
||||
// CHECK-NEXT: OS/ABI: AMDGPU_HSA (0x40)
|
||||
// CHECK-NEXT: ABIVersion: 2
|
||||
// CHECK-NEXT: Unused: (00 00 00 00 00 00 00)
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: Type: SharedObject (0x3)
|
||||
// CHECK-NEXT: Machine: EM_AMDGPU (0xE0)
|
||||
// CHECK-NEXT: Version: 1
|
||||
|
||||
// CHECK: Name: test_empty_kernel
|
||||
// CHECK: Size: 4
|
||||
// CHECK: Binding: Global
|
||||
// CHECK: Type: Function
|
||||
// CHECK: Section: .text
|
||||
|
||||
// CHECK: Type: NT_AMDGPU_METADATA (AMDGPU Metadata)
|
||||
// CHECK: .group_segment_fixed_size: 0
|
||||
// CHECK-NEXT: .kernarg_segment_align: 8
|
||||
// CHECK-NEXT: .kernarg_segment_size: 16
|
||||
// CHECK-NEXT: .max_flat_workgroup_size: 256
|
||||
// CHECK-NEXT: .name: test_empty_kernel
|
||||
// CHECK-NEXT: .private_segment_fixed_size: 0
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --gfx=90a | FileCheck %s
|
||||
|
||||
// == LLVM IR check begin ==
|
||||
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
|
||||
// CHECK: define amdgpu_kernel void @test_empty_kernel
|
||||
// XHECK: !nvvm.annotations
|
||||
// XHECK: !{ptr @test_empty_kernel, !"maxntidx", i32 128}
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --gfx=90a | FileCheck %s
|
||||
|
||||
// == LLVM IR check begin ==
|
||||
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
|
||||
// CHECK: define void @test_func
|
||||
// CHECK: define amdgpu_kernel void @test_kernel
|
||||
// CHECK: tail call void @test_func
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
tt.func @test_func(%lb : index, %A : !tt.ptr<f16>) attributes { noinline = true } {
|
||||
%0 = arith.constant 1.0 : f16
|
||||
tt.store %A, %0 : f16
|
||||
tt.return
|
||||
}
|
||||
|
||||
tt.func @test_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
tt.call @test_func(%lb, %A) : (index, !tt.ptr<f16>) -> ()
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user