Resolve merge conflicts; AMD adjustments for new LLVM version

This commit is contained in:
Jason Furmanek
2023-11-09 19:00:49 +00:00
parent 977d5aa267
commit 484852876e
11 changed files with 125 additions and 345 deletions

View File

@@ -147,17 +147,13 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::outs() << *llvmir << '\n';
} else if (targetKind == "ptx") {
llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(),
<<<<<<< HEAD
ptxVersion.getValue());
ptxVersion.getValue(),
enableFpFusion.getValue());
} else if (targetKind == "hsaco") {
auto [module, hsaco] = mlir::triton::translateLLVMIRToHSACO(
*llvmir, GCNArch.getValue(), GCNTriple.getValue(),
GCNFeatures.getValue());
llvm::outs() << hsaco;
=======
ptxVersion.getValue(),
enableFpFusion.getValue());
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
} else {
llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n";
return failure();

View File

@@ -6,24 +6,6 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
/* ----- FP8E5M2 ------ */
// This data-type is the standard FP8E5M2 format
static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
} else {
ret = "cvt.rn.satfinite.e5m2x2.f16x2 $0, $1; \n\t";
}
return ret;
}
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
@@ -50,15 +32,22 @@ Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
const std::string Fp16_to_Fp8E5M2 =
"{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
} else {
ret = "cvt.rn.satfinite.e5m2x2.f16x2 $0, $1; \n\t";
}
return ret;
}
#endif
#ifdef USE_ROCM
@@ -141,10 +130,18 @@ Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
const std::string Fp8E5M2_to_Fp16 = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
} else {
ret = "cvt.rn.f16x2.e5m2x2 $0, $1; \n\t";
}
return ret;
}
#endif
#ifdef USE_ROCM
@@ -195,11 +192,6 @@ Fp8E5M2FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
return result;
}
#else
const std::string Fp8E5M2FNUZ_to_Fp16 = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
#endif
#ifdef USE_ROCM
@@ -245,21 +237,47 @@ Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
const std::string Fp8E5M2_to_Bf16 =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
// exponent compensate = 112
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
".reg .b32 a<2>, b<2>, c<4>, d<4>, e112; \n" // if input = 0xf1f2f3f4
"mov.u32 e112, 0x77800000; \n"
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16
// position
"and.b32 c0, b0, 0xFFFF0000; \n" // c0 = f3
"shl.b32 c1, b0, 16; \n" // c1 = f4
"and.b32 c2, b1, 0xFFFF0000; \n" // c2 = f1
"shl.b32 c3, b1, 16; \n" // c3 = f2
"mul.f32 d0, c0, e112; \n" // d0 = c0 * 0x77800000
"mul.f32 d1, c1, e112; \n" // d1 = c1 * 0x77800000
"mul.f32 d2, c2, e112; \n" // d2 = c2 * 0x77800000
"mul.f32 d3, c3, e112; \n" // d3 = c3 * 0x77800000
"prmt.b32 b0, d0, d1, 0x3276; \n" // b0 = 0xd3d4
"prmt.b32 b1, d2, d3, 0x3276; \n" // b1 = 0xd1d2
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 =
// b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
} else {
ret = "{ \n"
".reg .b32 a; \n"
".reg .f16 a<2>; \n"
".reg .b16 b<2>; \n"
"cvt.rn.f16x2.e5m2x2 a, $1; \n"
"mov.b32 {a0, a1}, a; \n"
"cvt.bf16.f16 b0, a0; \n"
"cvt.bf16.f16 b1, a1; \n"
"mov.b32 $0, {b0, b1}; \n"
"}";
}
return ret;
}
#endif
#ifdef USE_ROCM
@@ -331,96 +349,6 @@ Bf16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
}
#else
const std::string Bf16_to_Fp8E5M2 =
"{ \n" // bf16=fp8>>3 + 112<<7
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000
".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111
"mov.u32 fp8_min, 0x38003800; \n" // so bf16_min = 0x3800
"mov.u32 fp8_max, 0x57e057e0; \n" // so bf16_max = 0x57e0
"mov.u32 rn_, 0x00100010; \n" // round to nearest
"and.b32 sign0, $1, 0x80008000; \n" // sign0=in0&0x80008000
"and.b32 sign1, $2, 0x80008000; \n" // (store sign)
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
"and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0=in0&0x7fff7fff
"and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign)
// nosign = clamp(nosign, min, max)
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
"min.u32 nosign_0_0, nosign_0_0, 0x57e00000; \n"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
"min.u32 nosign_0_1, nosign_0_1, 0x57e0; \n"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
"min.u32 nosign_1_0, nosign_1_0, 0x57e00000; \n"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
"min.u32 nosign_1_1, nosign_1_1, 0x57e0; \n"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
"add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_
"add.u32 nosign1, nosign1, rn_; \n" // (round to nearest)
"sub.u32 nosign0, nosign0, 0x38003800; \n" // nosign0-=0x38003800
"sub.u32 nosign1, nosign1, 0x38003800; \n" // (compensate offset)
"shl.b32 nosign0, nosign0, 3; \n" // nosign0 <<= 3
"shl.b32 nosign1, nosign1, 3; \n" // shift into to fp8e4
"prmt.b32 nosign, nosign0, nosign1, 0x7531; \n" // nosign0 = 0xf100f200
// nosign1 = 0xf300f400
// nosign = 0xf3f4f1f2
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
#endif
=======
static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
} else {
ret = "cvt.rn.f16x2.e5m2x2 $0, $1; \n\t";
}
return ret;
}
static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
ret =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
// exponent compensate = 112
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
} else {
ret = "{ \n"
".reg .b32 a; \n"
".reg .f16 a<2>; \n"
".reg .b16 b<2>; \n"
"cvt.rn.f16x2.e5m2x2 a, $1; \n"
"mov.b32 {a0, a1}, a; \n"
"cvt.bf16.f16 b0, a0; \n"
"cvt.bf16.f16 b1, a1; \n"
"mov.b32 $0, {b0, b1}; \n"
"}";
}
return ret;
}
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
std::string ret;
if (!hasNativeFP) {
@@ -477,15 +405,14 @@ static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
}
return ret;
}
#endif
/* ----- FP8E4M3B15 ------ */
// This data-type is a variant of the standard FP8E4M3 format.
// It was designed for fast software conversion to FP16 on
// nvidia GPUs that do not support it natively.
<<<<<<< HEAD
// Specifically, this data-type:
// - has infinities
// - has multiple nans (when all exponent bits are 1)
// - has an exponent bias of 15 (vs. 7 for fp8e4m3)
// This is the same format as FP8E4M3Nv, but:
// - the exponent bias is 15 instead of 7
// - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -525,11 +452,6 @@ Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
=======
// This is the same format as FP8E4M3Nv, but:
// - the exponent bias is 15 instead of 7
// - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
const std::string Fp8E4M3B15_to_Fp16 =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
@@ -542,7 +464,6 @@ const std::string Fp8E4M3B15_to_Fp16 =
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"shl.b32 $1, b1, 7; \n"
"} \n";
<<<<<<< HEAD
#endif
#ifdef USE_ROCM
@@ -590,11 +511,7 @@ Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
=======
static const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
std::string ret;
ret += "{ \n"
".reg .pred p<4>; \n"
@@ -640,7 +557,6 @@ static const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
// $0 = (($2 << 1) & 0x80008000u) | (($2 << 7) & 0x3f803f80u);
// $1 = (($2 << 0) & 0x80008000u) | (($2 << 0) & 0x3f803f80u);
// WARN: subnormal (0bs0000xxx) are not handled
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -671,10 +587,7 @@ Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
const std::string Fp8E4M3B15x4_to_Fp16 =
=======
static const std::string Fp8E4M3B15x4_to_Fp16 =
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
"{ \n"
".reg .b32 a<2>; \n"
"add.u32 a0, $2, $2; \n"
@@ -692,7 +605,6 @@ static const std::string Fp8E4M3B15x4_to_Fp16 =
// ((e4.y >> 0) & (0x80008000u >> 0)) |
// ((e4.y >> 0) & (0x3f803f80u >> 0)) ;
// WARN: subnormal (0bs0000xxx) are not handled
<<<<<<< HEAD
#ifdef USE_ROCM
static SmallVector<Value>
Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
@@ -726,10 +638,7 @@ Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
};
}
#else
const std::string Fp16_to_Fp8E4M3B15x4 =
=======
static const std::string Fp16_to_Fp8E4M3B15x4 =
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
"{ \n"
".reg .b32 a<2>; \n"
"shr.b32 a0, $1, 1; \n"
@@ -1040,7 +949,6 @@ const std::string Bf16_to_Fp8E4M3 =
// nosign = 0xf3f4f1f2
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
#endif
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
static const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
@@ -1052,7 +960,6 @@ static const std::string Fp16_to_Fp8E4M3Nv =
"cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
"}";
#ifndef USE_ROCM
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
static const std::string Fp8E4M3Nv_to_Bf16 =
"{ \n"
@@ -1090,13 +997,13 @@ static const std::string S8_to_Bf16 =
"prmt.b32 $0, f0, f1, 0x7632; \n" // f32->bf16 + pack
"prmt.b32 $1, f2, f3, 0x7632; \n" //
"}";
#endif
// Fp32 (x2) -> Fp8 (x2) (packed)
static const std::string Fp32_to_Fp8E4M3Nv =
"cvt.rn.satfinite.e4m3x2.f32 $0, $2, $1; \n";
static const std::string Fp32_to_Fp8E5M2 =
"cvt.rn.satfinite.e5m2x2.f32 $0, $2, $1; \n";
#endif
static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
Type inType, Type ouType) {
@@ -1529,49 +1436,46 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
<<<<<<< HEAD
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
#ifdef USE_ROCM
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
=======
#else
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
#endif
// F16 -> F8
#ifdef USE_ROCM
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
#else
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
#endif
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
<<<<<<< HEAD
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
#ifdef USE_ROCM
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
#ifndef USE_ROCM
=======
#else
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
#endif
// F8 -> BF16
#ifdef USE_ROCM
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
#else
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
#endif
// BF16 -> F8
<<<<<<< HEAD
// BF16 -> F8
#ifdef USE_ROCM
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
#ifndef USE_ROCM
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
#endif
=======
#else
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2(computeCapability >= 90)},
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
// F32 -> F8
{{F32TyID, F8E4M3TyID}, Fp32_to_Fp8E4M3Nv},
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
#endif
};
int inVecWidthBits = 32;
int outVecWidthBits = 32;

View File

@@ -1150,7 +1150,7 @@ struct AtomicRMWOpConversion
#ifdef USE_ROCM
/// Try to match the mlir::triton::RMWOp to LLVM::AtomicBinOp.
static Optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
static std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
switch (atomicOp) {
case RMWOp::AND:
return LLVM::AtomicBinOp::_and;

View File

@@ -489,22 +489,11 @@ struct GetNumProgramsOpConversion
ConversionPatternRewriter &rewriter) const override {
#ifdef USE_ROCM
Location loc = op->getLoc();
assert(op.getAxis() < 3);
// Seem like GridDimOp returns the number of threads (not the number of
// workgroups) in a kernel (a bug in llvm https://reviews.llvm.org/D156009),
// so as a workaround here, we divide by the number of threads
// per workgroup to get the number of workgroups in a kernel.
// TODO: when we do upstream to include llvm fix, we can remove this workaround
// The unit test added in this PR can guarantee that.
Value threadsPerGrid =
Value blockId =
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
Value threadsPerBlock =
rewriter.create<::mlir::gpu::BlockDimOp>(loc, dims[op.getAxis()]);
Value threadNumPerGrid = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerGrid);
Value threadNumPerBlock = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerBlock);
rewriter.replaceOpWithNewOp<LLVM::UDivOp>(op, threadNumPerGrid, threadNumPerBlock);
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
return success();
#else
// It is not easy to get the compute capability here, so we use numCTAs to

View File

@@ -26,7 +26,12 @@ warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, filter);
mlir::ForwardSliceOptions fwdOpt;
fwdOpt.filter = filter;
mlir::BackwardSliceOptions bwdOpt;
bwdOpt.omitBlockArguments = true;
bwdOpt.filter = filter;
auto slices = mlir::getSlice(dotOp, bwdOpt, fwdOpt);
for (Operation *op : slices)
if (isa<tt::DotOp>(op) && (op != dotOp))
return {(unsigned)numWarps, 1};
@@ -71,7 +76,12 @@ public:
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, filter);
mlir::ForwardSliceOptions fwdOpt;
fwdOpt.filter = filter;
mlir::BackwardSliceOptions bwdOpt;
bwdOpt.omitBlockArguments = true;
bwdOpt.filter = filter;
auto slices = mlir::getSlice(dotOp, bwdOpt, fwdOpt);
for (Operation *op : slices) {
if (isa<tt::DotOp>(op) && (op != dotOp))
return true;

View File

@@ -31,81 +31,7 @@ using triton::gpu::SliceEncodingAttr;
//
// -----------------------------------------------------------------------------
<<<<<<< HEAD
// convert(blocked, dot_operand) ->
// convert(blocked, mma) + convert(mma, dot_operand)
// if this value is itself the result of a dot operation
// this is a heuristic to accommodate some pattern seen in fused attention
// kernels.
// TODO: replace this by something more generic, i.e. layout-aware CSE
class DecomposeDotOperand : public mlir::RewritePattern {
public:
explicit DecomposeDotOperand(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
template <typename encTy>
mlir::LogicalResult processEncoding(encTy encoding,
triton::gpu::ConvertLayoutOp convert,
RankedTensorType &dstType,
mlir::PatternRewriter &rewriter) const {
SetVector<Operation *> bwdSlices;
mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
if (llvm::find_if(bwdSlices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) == bwdSlices.end())
return mlir::failure();
auto tmpType = RankedTensorType::get(dstType.getShape(),
dstType.getElementType(), encoding);
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), tmpType, convert.getOperand());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(convert, dstType,
tmp);
return mlir::success();
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
auto dstDotOperand =
dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto dstParent = dstDotOperand.getParent();
if (dstDotOperand.getOpIdx() == 1 ||
(!dstParent.isa<triton::gpu::MmaEncodingAttr>() &&
!dstParent.isa<triton::gpu::MfmaEncodingAttr>()))
return mlir::failure();
if (dstParent.isa<triton::gpu::MmaEncodingAttr>()) {
auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
if (dstParentMma.isVolta() || dstParentMma.getWarpsPerCTA()[1] > 1)
return mlir::failure();
return processEncoding(dstParentMma, convert, dstType, rewriter);
}
if (dstParent.isa<triton::gpu::MfmaEncodingAttr>()) {
auto dstParentMfma = dstParent.cast<triton::gpu::MfmaEncodingAttr>();
if (dstParentMfma.getWarpsPerCTA()[1] > 1)
return mlir::failure();
return processEncoding(dstParentMfma, convert, dstType, rewriter);
}
}
return mlir::failure();
}
};
//
=======
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
class ConvertDotConvert : public mlir::RewritePattern {
public:
ConvertDotConvert(mlir::MLIRContext *context)

View File

@@ -188,7 +188,7 @@ void LoopPipeliner::collectValueDep(Value v, int stage,
return;
// Loop-invariant value, skip
if (v.getParentRegion() != &forOp.getLoopBody())
if (v.getParentRegion() != &forOp.getRegion())
return;
if (Operation *op = v.getDefiningOp()) {
@@ -598,7 +598,7 @@ SmallVector<Value> LoopPipeliner::collectNewLoopArgs() {
// We need this to update operands for yield
// original block arg => new arg's idx
SmallVector<Value> newLoopArgs;
for (auto v : forOp.getIterOperands()) {
for (auto v : forOp.getInitArgs()) {
newLoopArgs.push_back(lookupOrDefault(v, numStages - 1));/*1*/
}

View File

@@ -116,7 +116,7 @@ initialize_module(llvm::Module *module, const std::string &triple,
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
std::nullopt, llvm::CodeGenOpt::Aggressive);
std::nullopt, llvm::CodeGenOptLevel::Aggressive);
module->setDataLayout(machine->createDataLayout());
@@ -141,7 +141,7 @@ std::string generate_amdgcn_assembly(llvm::Module *module,
// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
llvm::CodeGenFileType::AssemblyFile);
pass.run(*module);
std::string amdgcn(buffer.begin(), buffer.end());
@@ -210,7 +210,7 @@ std::string generate_hsaco(llvm::Module *module, const std::string &triple,
// emit
llvm::legacy::PassManager pass;
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr,
llvm::CGFT_ObjectFile);
llvm::CodeGenFileType::ObjectFile);
pass.run(*module);
// generate HASCO file
@@ -792,4 +792,4 @@ translateTritonIRToHSACO(mlir::ModuleOp module, std::string gfx_arch,
}
} // namespace triton
} // namespace mlir
} // namespace mlir

View File

@@ -409,12 +409,9 @@ def compile(fn, **kwargs):
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
num_ctas = kwargs.get("num_ctas", 1)
num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
<<<<<<< HEAD
waves_per_eu = kwargs.get("waves_per_eu", 0)
matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0)
=======
enable_fp_fusion = kwargs.get("enable_fp_fusion", True)
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
# TODO[shuhaoj]: Default should be to enable warp specialization once possible
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
# TODO[shuhaoj]: persistent can be decoupled with warp specialization

View File

@@ -281,21 +281,13 @@ class JITFunction(KernelInterface[T]):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
<<<<<<< HEAD
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
=======
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
<<<<<<< HEAD
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
=======
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
key = str(key)
class LegacyCompiler:
@@ -305,11 +297,7 @@ class JITFunction(KernelInterface[T]):
pass
kwargs = dict(signature=signature, device=device, constants=constants,
<<<<<<< HEAD
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
=======
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
configs=configs)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
@@ -363,11 +351,7 @@ class JITFunction(KernelInterface[T]):
def regular_args_v(args_proxy):
return [args_proxy[arg_name] for arg_name in regular_args]
<<<<<<< HEAD
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type):
=======
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
from ..compiler import (CompiledKernel, compile,
get_arch_default_num_stages,
get_arch_default_num_warps)
@@ -418,11 +402,7 @@ class JITFunction(KernelInterface[T]):
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
<<<<<<< HEAD
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, self.debug)
=======
key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, self.debug)
if extern_libs is not None:
key = (key, tuple(extern_libs.items()))
@@ -450,13 +430,8 @@ class JITFunction(KernelInterface[T]):
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
<<<<<<< HEAD
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
=======
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:
@@ -471,13 +446,8 @@ class JITFunction(KernelInterface[T]):
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
src = f"""
import triton
<<<<<<< HEAD
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type)
=======
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
"""
scope = {"launcher_body": launcher_body}
exec(src, scope)

View File

@@ -2043,28 +2043,16 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
// -----
<<<<<<< HEAD
// CHECK-LABEL: copyitem
// CHECK-LABEL: reduce_slice
// GCN: llvm.store
// GCN: llvm.load
// PTX: st.shared.b8
// PTX: ld.shared.b8
// PTX-NOT: st.shared.b1
// PTX-NOT: ld.shared.b1
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @copyitem() attributes {noinline = false} {
%cst = arith.constant dense<true> : tensor<4x1xi1, #blocked>
=======
// CHECK-LABEL: reduce_slice
// CHECK-NOT: st.shared
// CHECK-NOT: ld.shared
// PTX-NOT: st.shared
// PTX-NOT: ld.shared
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}>
#sliced2 = #triton_gpu.slice<{dim = 2, parent = #blocked}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @reduce_slice() attributes {noinline = false} {
%cst = arith.constant dense<true> : tensor<4x1xi1, #sliced2>
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
%0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({
^bb0(%arg0: i1, %arg1: i1):
%1 = arith.ori %arg0, %arg1 : i1