mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Resolve merge conflicts; AMD adjustments for new LLVM version
This commit is contained in:
@@ -147,17 +147,13 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||
llvm::outs() << *llvmir << '\n';
|
||||
} else if (targetKind == "ptx") {
|
||||
llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(),
|
||||
<<<<<<< HEAD
|
||||
ptxVersion.getValue());
|
||||
ptxVersion.getValue(),
|
||||
enableFpFusion.getValue());
|
||||
} else if (targetKind == "hsaco") {
|
||||
auto [module, hsaco] = mlir::triton::translateLLVMIRToHSACO(
|
||||
*llvmir, GCNArch.getValue(), GCNTriple.getValue(),
|
||||
GCNFeatures.getValue());
|
||||
llvm::outs() << hsaco;
|
||||
=======
|
||||
ptxVersion.getValue(),
|
||||
enableFpFusion.getValue());
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
} else {
|
||||
llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n";
|
||||
return failure();
|
||||
|
||||
@@ -6,24 +6,6 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
|
||||
/* ----- FP8E5M2 ------ */
|
||||
// This data-type is the standard FP8E5M2 format
|
||||
static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
|
||||
std::string ret;
|
||||
if (!hasNativeFP) {
|
||||
ret = "{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
|
||||
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
|
||||
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
|
||||
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
|
||||
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
|
||||
"}";
|
||||
} else {
|
||||
ret = "cvt.rn.satfinite.e5m2x2.f16x2 $0, $1; \n\t";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
|
||||
@@ -50,15 +32,22 @@ Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
|
||||
};
|
||||
}
|
||||
#else
|
||||
const std::string Fp16_to_Fp8E5M2 =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
|
||||
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
|
||||
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
|
||||
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
|
||||
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
|
||||
"}";
|
||||
static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
|
||||
std::string ret;
|
||||
if (!hasNativeFP) {
|
||||
ret = "{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
|
||||
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
|
||||
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
|
||||
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
|
||||
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
|
||||
"}";
|
||||
} else {
|
||||
ret = "cvt.rn.satfinite.e5m2x2.f16x2 $0, $1; \n\t";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -141,10 +130,18 @@ Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
};
|
||||
}
|
||||
#else
|
||||
const std::string Fp8E5M2_to_Fp16 = "{ \n"
|
||||
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
|
||||
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
|
||||
"}";
|
||||
static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
|
||||
std::string ret;
|
||||
if (!hasNativeFP) {
|
||||
ret = "{ \n"
|
||||
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
|
||||
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
|
||||
"}";
|
||||
} else {
|
||||
ret = "cvt.rn.f16x2.e5m2x2 $0, $1; \n\t";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -195,11 +192,6 @@ Fp8E5M2FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
|
||||
return result;
|
||||
}
|
||||
#else
|
||||
const std::string Fp8E5M2FNUZ_to_Fp16 = "{ \n"
|
||||
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
|
||||
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
|
||||
"}";
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -245,21 +237,47 @@ Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
};
|
||||
}
|
||||
#else
|
||||
const std::string Fp8E5M2_to_Bf16 =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
|
||||
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
|
||||
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
|
||||
// exponent compensate = 112
|
||||
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
|
||||
std::string ret;
|
||||
if (!hasNativeFP) {
|
||||
ret = "{ \n"
|
||||
".reg .b32 a<2>, b<2>, c<4>, d<4>, e112; \n" // if input = 0xf1f2f3f4
|
||||
"mov.u32 e112, 0x77800000; \n"
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
|
||||
"shr.b32 b1, b1, 3; \n" // shift into bf16
|
||||
// position
|
||||
"and.b32 c0, b0, 0xFFFF0000; \n" // c0 = f3
|
||||
"shl.b32 c1, b0, 16; \n" // c1 = f4
|
||||
"and.b32 c2, b1, 0xFFFF0000; \n" // c2 = f1
|
||||
"shl.b32 c3, b1, 16; \n" // c3 = f2
|
||||
"mul.f32 d0, c0, e112; \n" // d0 = c0 * 0x77800000
|
||||
"mul.f32 d1, c1, e112; \n" // d1 = c1 * 0x77800000
|
||||
"mul.f32 d2, c2, e112; \n" // d2 = c2 * 0x77800000
|
||||
"mul.f32 d3, c3, e112; \n" // d3 = c3 * 0x77800000
|
||||
"prmt.b32 b0, d0, d1, 0x3276; \n" // b0 = 0xd3d4
|
||||
"prmt.b32 b1, d2, d3, 0x3276; \n" // b1 = 0xd1d2
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 =
|
||||
// b0|(0x80008000&a0)
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
} else {
|
||||
ret = "{ \n"
|
||||
".reg .b32 a; \n"
|
||||
".reg .f16 a<2>; \n"
|
||||
".reg .b16 b<2>; \n"
|
||||
"cvt.rn.f16x2.e5m2x2 a, $1; \n"
|
||||
"mov.b32 {a0, a1}, a; \n"
|
||||
"cvt.bf16.f16 b0, a0; \n"
|
||||
"cvt.bf16.f16 b1, a1; \n"
|
||||
"mov.b32 $0, {b0, b1}; \n"
|
||||
"}";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -331,96 +349,6 @@ Bf16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
|
||||
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
|
||||
}
|
||||
#else
|
||||
const std::string Bf16_to_Fp8E5M2 =
|
||||
"{ \n" // bf16=fp8>>3 + 112<<7
|
||||
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000
|
||||
".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111
|
||||
"mov.u32 fp8_min, 0x38003800; \n" // so bf16_min = 0x3800
|
||||
"mov.u32 fp8_max, 0x57e057e0; \n" // so bf16_max = 0x57e0
|
||||
"mov.u32 rn_, 0x00100010; \n" // round to nearest
|
||||
"and.b32 sign0, $1, 0x80008000; \n" // sign0=in0&0x80008000
|
||||
"and.b32 sign1, $2, 0x80008000; \n" // (store sign)
|
||||
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
|
||||
"and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0=in0&0x7fff7fff
|
||||
"and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign)
|
||||
|
||||
// nosign = clamp(nosign, min, max)
|
||||
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
|
||||
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
|
||||
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
|
||||
"min.u32 nosign_0_0, nosign_0_0, 0x57e00000; \n"
|
||||
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
|
||||
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
|
||||
"min.u32 nosign_0_1, nosign_0_1, 0x57e0; \n"
|
||||
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
|
||||
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
|
||||
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
|
||||
"min.u32 nosign_1_0, nosign_1_0, 0x57e00000; \n"
|
||||
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
|
||||
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
|
||||
"min.u32 nosign_1_1, nosign_1_1, 0x57e0; \n"
|
||||
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
|
||||
|
||||
"add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_
|
||||
"add.u32 nosign1, nosign1, rn_; \n" // (round to nearest)
|
||||
"sub.u32 nosign0, nosign0, 0x38003800; \n" // nosign0-=0x38003800
|
||||
"sub.u32 nosign1, nosign1, 0x38003800; \n" // (compensate offset)
|
||||
"shl.b32 nosign0, nosign0, 3; \n" // nosign0 <<= 3
|
||||
"shl.b32 nosign1, nosign1, 3; \n" // shift into to fp8e4
|
||||
"prmt.b32 nosign, nosign0, nosign1, 0x7531; \n" // nosign0 = 0xf100f200
|
||||
// nosign1 = 0xf300f400
|
||||
// nosign = 0xf3f4f1f2
|
||||
"or.b32 $0, nosign, sign; \n" // restore sign
|
||||
"}";
|
||||
#endif
|
||||
=======
|
||||
static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
|
||||
std::string ret;
|
||||
if (!hasNativeFP) {
|
||||
ret = "{ \n"
|
||||
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
|
||||
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
|
||||
"}";
|
||||
} else {
|
||||
ret = "cvt.rn.f16x2.e5m2x2 $0, $1; \n\t";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
|
||||
std::string ret;
|
||||
if (!hasNativeFP) {
|
||||
ret =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
|
||||
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
|
||||
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
|
||||
// exponent compensate = 112
|
||||
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
} else {
|
||||
ret = "{ \n"
|
||||
".reg .b32 a; \n"
|
||||
".reg .f16 a<2>; \n"
|
||||
".reg .b16 b<2>; \n"
|
||||
"cvt.rn.f16x2.e5m2x2 a, $1; \n"
|
||||
"mov.b32 {a0, a1}, a; \n"
|
||||
"cvt.bf16.f16 b0, a0; \n"
|
||||
"cvt.bf16.f16 b1, a1; \n"
|
||||
"mov.b32 $0, {b0, b1}; \n"
|
||||
"}";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
|
||||
static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
|
||||
std::string ret;
|
||||
if (!hasNativeFP) {
|
||||
@@ -477,15 +405,14 @@ static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
/* ----- FP8E4M3B15 ------ */
|
||||
// This data-type is a variant of the standard FP8E4M3 format.
|
||||
// It was designed for fast software conversion to FP16 on
|
||||
// nvidia GPUs that do not support it natively.
|
||||
<<<<<<< HEAD
|
||||
// Specifically, this data-type:
|
||||
// - has infinities
|
||||
// - has multiple nans (when all exponent bits are 1)
|
||||
// - has an exponent bias of 15 (vs. 7 for fp8e4m3)
|
||||
// This is the same format as FP8E4M3Nv, but:
|
||||
// - the exponent bias is 15 instead of 7
|
||||
// - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
@@ -525,11 +452,6 @@ Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
};
|
||||
}
|
||||
#else
|
||||
=======
|
||||
// This is the same format as FP8E4M3Nv, but:
|
||||
// - the exponent bias is 15 instead of 7
|
||||
// - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
const std::string Fp8E4M3B15_to_Fp16 =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n"
|
||||
@@ -542,7 +464,6 @@ const std::string Fp8E4M3B15_to_Fp16 =
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
|
||||
"shl.b32 $1, b1, 7; \n"
|
||||
"} \n";
|
||||
<<<<<<< HEAD
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -590,11 +511,7 @@ Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
|
||||
};
|
||||
}
|
||||
#else
|
||||
const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
|
||||
=======
|
||||
|
||||
static const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
std::string ret;
|
||||
ret += "{ \n"
|
||||
".reg .pred p<4>; \n"
|
||||
@@ -640,7 +557,6 @@ static const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
|
||||
// $0 = (($2 << 1) & 0x80008000u) | (($2 << 7) & 0x3f803f80u);
|
||||
// $1 = (($2 << 0) & 0x80008000u) | (($2 << 0) & 0x3f803f80u);
|
||||
// WARN: subnormal (0bs0000xxx) are not handled
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
@@ -671,10 +587,7 @@ Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
};
|
||||
}
|
||||
#else
|
||||
const std::string Fp8E4M3B15x4_to_Fp16 =
|
||||
=======
|
||||
static const std::string Fp8E4M3B15x4_to_Fp16 =
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
"{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"add.u32 a0, $2, $2; \n"
|
||||
@@ -692,7 +605,6 @@ static const std::string Fp8E4M3B15x4_to_Fp16 =
|
||||
// ((e4.y >> 0) & (0x80008000u >> 0)) |
|
||||
// ((e4.y >> 0) & (0x3f803f80u >> 0)) ;
|
||||
// WARN: subnormal (0bs0000xxx) are not handled
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
@@ -726,10 +638,7 @@ Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
};
|
||||
}
|
||||
#else
|
||||
const std::string Fp16_to_Fp8E4M3B15x4 =
|
||||
=======
|
||||
static const std::string Fp16_to_Fp8E4M3B15x4 =
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
"{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"shr.b32 a0, $1, 1; \n"
|
||||
@@ -1040,7 +949,6 @@ const std::string Bf16_to_Fp8E4M3 =
|
||||
// nosign = 0xf3f4f1f2
|
||||
"or.b32 $0, nosign, sign; \n" // restore sign
|
||||
"}";
|
||||
#endif
|
||||
|
||||
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
|
||||
static const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
|
||||
@@ -1052,7 +960,6 @@ static const std::string Fp16_to_Fp8E4M3Nv =
|
||||
"cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
|
||||
"}";
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
|
||||
static const std::string Fp8E4M3Nv_to_Bf16 =
|
||||
"{ \n"
|
||||
@@ -1090,13 +997,13 @@ static const std::string S8_to_Bf16 =
|
||||
"prmt.b32 $0, f0, f1, 0x7632; \n" // f32->bf16 + pack
|
||||
"prmt.b32 $1, f2, f3, 0x7632; \n" //
|
||||
"}";
|
||||
#endif
|
||||
|
||||
// Fp32 (x2) -> Fp8 (x2) (packed)
|
||||
static const std::string Fp32_to_Fp8E4M3Nv =
|
||||
"cvt.rn.satfinite.e4m3x2.f32 $0, $2, $1; \n";
|
||||
static const std::string Fp32_to_Fp8E5M2 =
|
||||
"cvt.rn.satfinite.e5m2x2.f32 $0, $2, $1; \n";
|
||||
#endif
|
||||
|
||||
static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
|
||||
Type inType, Type ouType) {
|
||||
@@ -1529,49 +1436,46 @@ struct FpToFpOpConversion
|
||||
// F8 -> F16
|
||||
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
|
||||
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
|
||||
<<<<<<< HEAD
|
||||
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
|
||||
#ifdef USE_ROCM
|
||||
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
|
||||
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
|
||||
=======
|
||||
#else
|
||||
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
|
||||
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
#endif
|
||||
// F16 -> F8
|
||||
#ifdef USE_ROCM
|
||||
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
|
||||
#else
|
||||
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
|
||||
#endif
|
||||
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
|
||||
<<<<<<< HEAD
|
||||
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
|
||||
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
|
||||
#ifdef USE_ROCM
|
||||
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
|
||||
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
|
||||
// F8 -> BF16
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
|
||||
#ifndef USE_ROCM
|
||||
=======
|
||||
#else
|
||||
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
|
||||
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
|
||||
// F8 -> BF16
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
#endif
|
||||
// F8 -> BF16
|
||||
#ifdef USE_ROCM
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
|
||||
#else
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
|
||||
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
|
||||
#endif
|
||||
// BF16 -> F8
|
||||
<<<<<<< HEAD
|
||||
// BF16 -> F8
|
||||
#ifdef USE_ROCM
|
||||
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
|
||||
#ifndef USE_ROCM
|
||||
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
|
||||
#endif
|
||||
=======
|
||||
#else
|
||||
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2(computeCapability >= 90)},
|
||||
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
|
||||
// F32 -> F8
|
||||
{{F32TyID, F8E4M3TyID}, Fp32_to_Fp8E4M3Nv},
|
||||
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
#endif
|
||||
};
|
||||
int inVecWidthBits = 32;
|
||||
int outVecWidthBits = 32;
|
||||
|
||||
@@ -1150,7 +1150,7 @@ struct AtomicRMWOpConversion
|
||||
|
||||
#ifdef USE_ROCM
|
||||
/// Try to match the mlir::triton::RMWOp to LLVM::AtomicBinOp.
|
||||
static Optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
|
||||
static std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
|
||||
switch (atomicOp) {
|
||||
case RMWOp::AND:
|
||||
return LLVM::AtomicBinOp::_and;
|
||||
|
||||
@@ -489,22 +489,11 @@ struct GetNumProgramsOpConversion
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxis() < 3);
|
||||
// Seem like GridDimOp returns the number of threads (not the number of
|
||||
// workgroups) in a kernel (a bug in llvm https://reviews.llvm.org/D156009),
|
||||
// so as a workaround here, we divide by the number of threads
|
||||
// per workgroup to get the number of workgroups in a kernel.
|
||||
// TODO: when we do upstream to include llvm fix, we can remove this workaround
|
||||
// The unit test added in this PR can guarantee that.
|
||||
Value threadsPerGrid =
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
|
||||
Value threadsPerBlock =
|
||||
rewriter.create<::mlir::gpu::BlockDimOp>(loc, dims[op.getAxis()]);
|
||||
Value threadNumPerGrid = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerGrid);
|
||||
Value threadNumPerBlock = rewriter.create<arith::TruncIOp>(loc, i32_ty, threadsPerBlock);
|
||||
rewriter.replaceOpWithNewOp<LLVM::UDivOp>(op, threadNumPerGrid, threadNumPerBlock);
|
||||
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
|
||||
return success();
|
||||
#else
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
|
||||
@@ -26,7 +26,12 @@ warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, filter);
|
||||
mlir::ForwardSliceOptions fwdOpt;
|
||||
fwdOpt.filter = filter;
|
||||
mlir::BackwardSliceOptions bwdOpt;
|
||||
bwdOpt.omitBlockArguments = true;
|
||||
bwdOpt.filter = filter;
|
||||
auto slices = mlir::getSlice(dotOp, bwdOpt, fwdOpt);
|
||||
for (Operation *op : slices)
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp))
|
||||
return {(unsigned)numWarps, 1};
|
||||
@@ -71,7 +76,12 @@ public:
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, filter);
|
||||
mlir::ForwardSliceOptions fwdOpt;
|
||||
fwdOpt.filter = filter;
|
||||
mlir::BackwardSliceOptions bwdOpt;
|
||||
bwdOpt.omitBlockArguments = true;
|
||||
bwdOpt.filter = filter;
|
||||
auto slices = mlir::getSlice(dotOp, bwdOpt, fwdOpt);
|
||||
for (Operation *op : slices) {
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp))
|
||||
return true;
|
||||
|
||||
@@ -31,81 +31,7 @@ using triton::gpu::SliceEncodingAttr;
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
<<<<<<< HEAD
|
||||
// convert(blocked, dot_operand) ->
|
||||
// convert(blocked, mma) + convert(mma, dot_operand)
|
||||
// if this value is itself the result of a dot operation
|
||||
// this is a heuristic to accommodate some pattern seen in fused attention
|
||||
// kernels.
|
||||
// TODO: replace this by something more generic, i.e. layout-aware CSE
|
||||
class DecomposeDotOperand : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
explicit DecomposeDotOperand(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
template <typename encTy>
|
||||
mlir::LogicalResult processEncoding(encTy encoding,
|
||||
triton::gpu::ConvertLayoutOp convert,
|
||||
RankedTensorType &dstType,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
SetVector<Operation *> bwdSlices;
|
||||
mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
|
||||
if (llvm::find_if(bwdSlices, [](Operation *op) {
|
||||
return isa<triton::DotOp>(op);
|
||||
}) == bwdSlices.end())
|
||||
return mlir::failure();
|
||||
|
||||
auto tmpType = RankedTensorType::get(dstType.getShape(),
|
||||
dstType.getElementType(), encoding);
|
||||
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
convert.getLoc(), tmpType, convert.getOperand());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(convert, dstType,
|
||||
tmp);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
|
||||
return mlir::failure();
|
||||
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = convert.getType().cast<RankedTensorType>();
|
||||
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
|
||||
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
auto dstDotOperand =
|
||||
dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto dstParent = dstDotOperand.getParent();
|
||||
if (dstDotOperand.getOpIdx() == 1 ||
|
||||
(!dstParent.isa<triton::gpu::MmaEncodingAttr>() &&
|
||||
!dstParent.isa<triton::gpu::MfmaEncodingAttr>()))
|
||||
return mlir::failure();
|
||||
|
||||
if (dstParent.isa<triton::gpu::MmaEncodingAttr>()) {
|
||||
auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (dstParentMma.isVolta() || dstParentMma.getWarpsPerCTA()[1] > 1)
|
||||
return mlir::failure();
|
||||
return processEncoding(dstParentMma, convert, dstType, rewriter);
|
||||
}
|
||||
|
||||
if (dstParent.isa<triton::gpu::MfmaEncodingAttr>()) {
|
||||
auto dstParentMfma = dstParent.cast<triton::gpu::MfmaEncodingAttr>();
|
||||
if (dstParentMfma.getWarpsPerCTA()[1] > 1)
|
||||
return mlir::failure();
|
||||
return processEncoding(dstParentMfma, convert, dstType, rewriter);
|
||||
}
|
||||
}
|
||||
return mlir::failure();
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
=======
|
||||
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
class ConvertDotConvert : public mlir::RewritePattern {
|
||||
public:
|
||||
ConvertDotConvert(mlir::MLIRContext *context)
|
||||
|
||||
@@ -188,7 +188,7 @@ void LoopPipeliner::collectValueDep(Value v, int stage,
|
||||
return;
|
||||
|
||||
// Loop-invariant value, skip
|
||||
if (v.getParentRegion() != &forOp.getLoopBody())
|
||||
if (v.getParentRegion() != &forOp.getRegion())
|
||||
return;
|
||||
|
||||
if (Operation *op = v.getDefiningOp()) {
|
||||
@@ -598,7 +598,7 @@ SmallVector<Value> LoopPipeliner::collectNewLoopArgs() {
|
||||
// We need this to update operands for yield
|
||||
// original block arg => new arg's idx
|
||||
SmallVector<Value> newLoopArgs;
|
||||
for (auto v : forOp.getIterOperands()) {
|
||||
for (auto v : forOp.getInitArgs()) {
|
||||
newLoopArgs.push_back(lookupOrDefault(v, numStages - 1));/*1*/
|
||||
}
|
||||
|
||||
|
||||
@@ -116,7 +116,7 @@ initialize_module(llvm::Module *module, const std::string &triple,
|
||||
opt.NoNaNsFPMath = true;
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(
|
||||
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
||||
std::nullopt, llvm::CodeGenOpt::Aggressive);
|
||||
std::nullopt, llvm::CodeGenOptLevel::Aggressive);
|
||||
|
||||
module->setDataLayout(machine->createDataLayout());
|
||||
|
||||
@@ -141,7 +141,7 @@ std::string generate_amdgcn_assembly(llvm::Module *module,
|
||||
|
||||
// emit
|
||||
machine->addPassesToEmitFile(pass, stream, nullptr,
|
||||
llvm::CodeGenFileType::CGFT_AssemblyFile);
|
||||
llvm::CodeGenFileType::AssemblyFile);
|
||||
pass.run(*module);
|
||||
|
||||
std::string amdgcn(buffer.begin(), buffer.end());
|
||||
@@ -210,7 +210,7 @@ std::string generate_hsaco(llvm::Module *module, const std::string &triple,
|
||||
// emit
|
||||
llvm::legacy::PassManager pass;
|
||||
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr,
|
||||
llvm::CGFT_ObjectFile);
|
||||
llvm::CodeGenFileType::ObjectFile);
|
||||
pass.run(*module);
|
||||
|
||||
// generate HASCO file
|
||||
@@ -792,4 +792,4 @@ translateTritonIRToHSACO(mlir::ModuleOp module, std::string gfx_arch,
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
|
||||
@@ -409,12 +409,9 @@ def compile(fn, **kwargs):
|
||||
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
|
||||
num_ctas = kwargs.get("num_ctas", 1)
|
||||
num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
|
||||
<<<<<<< HEAD
|
||||
waves_per_eu = kwargs.get("waves_per_eu", 0)
|
||||
matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0)
|
||||
=======
|
||||
enable_fp_fusion = kwargs.get("enable_fp_fusion", True)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
# TODO[shuhaoj]: Default should be to enable warp specialization once possible
|
||||
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
|
||||
# TODO[shuhaoj]: persistent can be decoupled with warp specialization
|
||||
|
||||
@@ -281,21 +281,13 @@ class JITFunction(KernelInterface[T]):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
<<<<<<< HEAD
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
|
||||
=======
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization,enable_fp_fusion, extern_libs, configs):
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
<<<<<<< HEAD
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
|
||||
=======
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
@@ -305,11 +297,7 @@ class JITFunction(KernelInterface[T]):
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
<<<<<<< HEAD
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
|
||||
=======
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
|
||||
@@ -363,11 +351,7 @@ class JITFunction(KernelInterface[T]):
|
||||
def regular_args_v(args_proxy):
|
||||
return [args_proxy[arg_name] for arg_name in regular_args]
|
||||
|
||||
<<<<<<< HEAD
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type):
|
||||
=======
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
|
||||
from ..compiler import (CompiledKernel, compile,
|
||||
get_arch_default_num_stages,
|
||||
get_arch_default_num_warps)
|
||||
@@ -418,11 +402,7 @@ class JITFunction(KernelInterface[T]):
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
<<<<<<< HEAD
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, self.debug)
|
||||
=======
|
||||
key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, self.debug)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
@@ -450,13 +430,8 @@ class JITFunction(KernelInterface[T]):
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
<<<<<<< HEAD
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
=======
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
@@ -471,13 +446,8 @@ class JITFunction(KernelInterface[T]):
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
src = f"""
|
||||
import triton
|
||||
<<<<<<< HEAD
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type)
|
||||
=======
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
|
||||
"""
|
||||
scope = {"launcher_body": launcher_body}
|
||||
exec(src, scope)
|
||||
|
||||
@@ -2043,28 +2043,16 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
|
||||
|
||||
// -----
|
||||
|
||||
<<<<<<< HEAD
|
||||
// CHECK-LABEL: copyitem
|
||||
// CHECK-LABEL: reduce_slice
|
||||
// GCN: llvm.store
|
||||
// GCN: llvm.load
|
||||
// PTX: st.shared.b8
|
||||
// PTX: ld.shared.b8
|
||||
// PTX-NOT: st.shared.b1
|
||||
// PTX-NOT: ld.shared.b1
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
tt.func public @copyitem() attributes {noinline = false} {
|
||||
%cst = arith.constant dense<true> : tensor<4x1xi1, #blocked>
|
||||
=======
|
||||
// CHECK-LABEL: reduce_slice
|
||||
// CHECK-NOT: st.shared
|
||||
// CHECK-NOT: ld.shared
|
||||
// PTX-NOT: st.shared
|
||||
// PTX-NOT: ld.shared
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}>
|
||||
#sliced2 = #triton_gpu.slice<{dim = 2, parent = #blocked}>
|
||||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
tt.func public @reduce_slice() attributes {noinline = false} {
|
||||
%cst = arith.constant dense<true> : tensor<4x1xi1, #sliced2>
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
%0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({
|
||||
^bb0(%arg0: i1, %arg1: i1):
|
||||
%1 = arith.ori %arg0, %arg1 : i1
|
||||
|
||||
Reference in New Issue
Block a user