Mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-27 03:01:52 -04:00)
[FRONTEND][BACKEND] Add flag to control accumulation for fp8 (#2300)
Change the dot op to allow taking an initial accumulator, and add a flag that lets the compiler accumulate in a lower precision than the output type. On Hopper this flag is on by default, which allows accumulating at lower precision. This only affects the Hopper fp8 dot.
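The user-facing change is easiest to see in kernel code: tl.dot now takes an optional initial accumulator (acc) and a max_num_imprecise_acc bound. A minimal sketch modeled on the test kernel added later in this diff; names, shapes, and the contiguous-layout pointer arithmetic are illustrative, not part of the commit:

import triton
import triton.language as tl


@triton.jit
def fp8_matmul_sketch(a_ptr, b_ptr, c_ptr, K,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                      MAX_IMPRECISE_ACC: tl.constexpr):
    # Assumes row-major A (row stride K) and B (row stride BLOCK_N), one tile.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :]
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # New in this commit: feed the running accumulator into the dot and
        # bound how many K elements may be accumulated at fp8 wgmma precision
        # before an explicit fp32 add is inserted (0 = always accumulate in fp32).
        acc = tl.dot(a, b, acc=acc, max_num_imprecise_acc=MAX_IMPRECISE_ACC)
        a_ptrs += BLOCK_K
        b_ptrs += BLOCK_K * BLOCK_N
    c_ptrs = c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :]
    tl.store(c_ptrs, acc)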
@@ -148,12 +148,12 @@ def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType",
 def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">;

 def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
-  let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC,
+  let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, Optional<LLVM_AnyStruct>:$opC,
                        I32Attr:$m, I32Attr:$n, I32Attr:$k,
                        WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB,
                        WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB);
   let results = (outs LLVM_AnyStruct:$res);
-  let assemblyFormat = "$opA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)";
+  let assemblyFormat = "$opA `,` $opB (`,` $opC^)? attr-dict `:` functional-type(operands, $res)";
 }

 def NVGPU_CGABarrierSyncOp : NVGPU_Op<"cga_barrier_sync", []> {
@@ -394,7 +394,12 @@ def TT_DotOp : TT_Op<"dot", [Pure,
     $d = matrix_multiply($a, $b) + $c
   }];

-  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
+  let arguments = (ins
+    TT_FpIntTensor:$a,
+    TT_FpIntTensor:$b,
+    TT_FpIntTensor:$c,
+    BoolAttr:$allowTF32,
+    I32Attr:$maxNumImpreciseAcc);

   let results = (outs TT_FpIntTensor:$d);
@@ -258,7 +258,11 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
     $d = matrix_multiply($a, $b) + $c
   }];

-  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
+  let arguments = (ins TT_FpIntTensor:$a,
+                       TT_FpIntTensor:$b,
+                       TT_FpIntTensor:$c,
+                       BoolAttr:$allowTF32,
+                       I32Attr:$maxNumImpreciseAcc);

   let results = (outs TT_FpIntTensor:$d);
@@ -379,6 +379,12 @@ bool supportMMA(triton::DotOp op, int version) {
           aElemTy.isF32()))) {
       return false;
     }
+    // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
+    if (op.getMaxNumImpreciseAcc() < 32 &&
+        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
+        op.getType().cast<RankedTensorType>().getElementType().isF32()) {
+      return false;
+    }
   }
   if (aElemTy.isF32() && bElemTy.isF32()) {
     return op.getAllowTF32() && version >= 2;
@@ -708,13 +708,13 @@ public:
     // TODO (zahi): Return type must always be a struct for wgmma, currently
     // we rely on the size of output constraints vector to determine whether
     // the output is a struct or not. We should find a way to pass this info
-    auto opC = op.getOpC();
-    auto typeC = opC.getType();
+    auto resultType = op.getType();

-    auto structTypeC = typeC.dyn_cast<LLVM::LLVMStructType>();
-    uint32_t numCRegs = structTypeC.getBody().size();
-    std::string c = structTypeC.getBody().front().isF32() ? "=f" : "=r";
-    return std::vector<std::string>(numCRegs, c);
+    auto outputStructType = resultType.dyn_cast<LLVM::LLVMStructType>();
+    uint32_t numOutputRegs = outputStructType.getBody().size();
+    std::string output =
+        outputStructType.getBody().front().isF32() ? "=f" : "=r";
+    return std::vector<std::string>(numOutputRegs, output);
   }

   OperandsAndConstraints getOperandsAndConstraints(ttn::WGMMAOp op) const {
@@ -727,7 +727,8 @@ public:
     auto structTypeA = typeA.dyn_cast<LLVM::LLVMStructType>();

     // TODO (zahi): is this the best way to tie inputs/outputs ?
-    operandsAndConstraints.push_back({opC, "0"});
+    if (opC)
+      operandsAndConstraints.push_back({opC, "0"});

     if (structTypeA) {
       operandsAndConstraints.push_back({opA, "f"});
@@ -744,7 +745,6 @@ public:
     using namespace ttn;
     auto opA = op.getOpA();
     auto opB = op.getOpB();
-    auto opC = op.getOpC();
     auto m = op.getM();
     auto n = op.getN();
     auto k = op.getK();
@@ -757,12 +757,12 @@ public:
     // Register checks
     auto typeA = opA.getType();
     auto typeB = opB.getType();
-    auto typeC = opC.getType();
+    auto typeOutput = op.getType();
     auto structTypeA = typeA.dyn_cast<LLVM::LLVMStructType>();
     auto structTypeB = typeB.dyn_cast<LLVM::LLVMStructType>();
-    auto structTypeC = typeC.dyn_cast<LLVM::LLVMStructType>();
+    auto structTypeOutput = typeOutput.dyn_cast<LLVM::LLVMStructType>();
     assert(!structTypeB && "Operand B can not be registers");
-    assert(structTypeC && "Operand C must be registers");
+    assert(structTypeOutput && "Output and C operand must be registers");

     // Element type, MNK shape and transposing support check
     // Reference:
@@ -804,18 +804,20 @@ public:

     // Operands
     uint32_t asmOpIdx = 0;

-    // Operand C
-    uint32_t numCRegs = structTypeC.getBody().size();
     std::string args = "";
+
+    // Output and operand C
+    uint32_t numCRegs = structTypeOutput.getBody().size();
     args += "{";
     for (uint32_t i = 0; i < numCRegs; ++i) {
       args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ",");
     }
     args += "}, ";

-    asmOpIdx += numCRegs;
+    if (op.getOpC())
+      asmOpIdx += numCRegs;

     // Operand A
     if (structTypeA) {
       uint32_t numARegs = m * k / 128;
@@ -833,8 +835,8 @@ public:
     // Operand B (must be `desc`)
     args += "$" + std::to_string(asmOpIdx++) + ", ";

-    // `scale-d` is 1 by default
-    args += "1";
+    // `scale-d` is 1 if we have a C operand.
+    args += op.getOpC() ? "1" : "0";

     // `imm-scale-a`, and `imm-scale-b` are 1 by default only for float-based
     // WGMMA
@@ -260,11 +260,30 @@ SmallVector<Value> unpackAccumulator(ConversionPatternRewriter &rewriter,
   return results;
 }

+static bool isFP8(triton::nvgpu::WGMMAEltType eltType) {
+  return eltType == triton::nvgpu::WGMMAEltType::e5m2 ||
+         eltType == triton::nvgpu::WGMMAEltType::e4m3;
+}
+
+static Value faddAccumulate(ConversionPatternRewriter &rewriter, Location loc,
+                            Value a, Value b) {
+  int numEl = a.getType().cast<LLVM::LLVMStructType>().getBody().size();
+  Value newStruct = rewriter.create<LLVM::UndefOp>(loc, a.getType());
+  for (int i = 0; i < numEl; ++i) {
+    Value lhs = rewriter.create<LLVM::ExtractValueOp>(loc, a, i);
+    Value rhs = rewriter.create<LLVM::ExtractValueOp>(loc, b, i);
+    Value add = rewriter.create<LLVM::FAddOp>(loc, lhs, rhs);
+    newStruct = rewriter.create<LLVM::InsertValueOp>(loc, newStruct, add, i);
+  }
+  return newStruct;
+}
+
 LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Operation *op, Value a, Value b, Value c, Value d,
                          Value loadedA, Value loadedB, Value loadedC,
-                         bool allowTF32, const SharedMemoryObject &smemObjA,
+                         bool allowTF32, uint32_t maxNumImpreciseAcc,
+                         const SharedMemoryObject &smemObjA,
                          const SharedMemoryObject &smemObjB, bool sync,
                          Value thread) {
   auto aTensorTy = a.getType().cast<RankedTensorType>();
@@ -311,7 +330,10 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
   if (numTMADescs == 0)
     rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
   rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);

+  // WGMMA fp8 -> fp32 accumulates in lower precision than fp32.
+  bool needsPartialAccumulator = isFP8(eltTypeA) &&
+                                 eltTypeC == triton::nvgpu::WGMMAEltType::f32 &&
+                                 maxNumImpreciseAcc <= aTensorTy.getShape()[1];
   SmallVector<Value> mmaResults;
   for (int m = 0; m < numRepM; ++m) {
     for (int n = 0; n < numRepN; ++n) {
@@ -323,13 +345,33 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
       auto accTy =
           LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
       Value d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy);
+      uint32_t numLowPrecisionAcc = 0;
+      Value partialAcc;
       for (int k = 0; k < numRepK; ++k) {
         auto a = aLoader.smemLoad(m, k);
         auto b = bLoader.smemLoad(n, k);
         ValueRange operands{a, b, d};
-        d = rewriter.create<triton::nvgpu::WGMMAOp>(loc, accTy, a, b, d, M, N,
-                                                    K, eltTypeC, eltTypeA,
-                                                    eltTypeB, layoutA, layoutB);
+        numLowPrecisionAcc += K;
+        // If using native accumulation would cause us to do more low-precision
+        // accumulation than allowed, do a separate accumulation.
+        bool requireAddAccumulator =
+            needsPartialAccumulator &&
+            (numLowPrecisionAcc >= maxNumImpreciseAcc || k == numRepK - 1);
+        Value mmaAcc = needsPartialAccumulator ? partialAcc : d;
+        mmaAcc = rewriter.create<triton::nvgpu::WGMMAOp>(
+            loc, accTy, a, b, mmaAcc, M, N, K, eltTypeC, eltTypeA, eltTypeB,
+            layoutA, layoutB);
+        if (needsPartialAccumulator)
+          partialAcc = mmaAcc;
+        else
+          d = mmaAcc;
+        // If we need to accumulate separately for higher precision, insert
+        // adds.
+        if (requireAddAccumulator) {
+          d = faddAccumulate(rewriter, loc, d, partialAcc);
+          numLowPrecisionAcc = 0;
+          partialAcc = Value();
+        }
       }
       auto acc = typeConverter->unpackLLElements(loc, d, rewriter, accTy);
       for (int i = 0; i < acc.size(); ++i) {
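To make the cadence above concrete, here is a small Python model of when faddAccumulate fires, using the same quantities as the C++ loop (a sketch, assuming needsPartialAccumulator is true); the expected positions match the lit tests added near the end of this diff:

# K here is the per-wgmma K extent (instrShape K), num_rep_k the number of
# K iterations of the loop above.
def fadd_positions(num_rep_k, k_per_wgmma, max_num_imprecise_acc):
    positions = []
    num_low_precision_acc = 0
    for k in range(num_rep_k):
        num_low_precision_acc += k_per_wgmma
        if num_low_precision_acc >= max_num_imprecise_acc or k == num_rep_k - 1:
            positions.append(k)  # faddAccumulate runs after this wgmma
            num_low_precision_acc = 0
    return positions

# With 4 wgmma ops of K=32 each (a 128-deep dot):
assert fadd_positions(4, 32, 32) == [0, 1, 2, 3]  # add after every wgmma
assert fadd_positions(4, 32, 64) == [1, 3]        # add after every other wgmma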
@@ -398,8 +440,9 @@ LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
   auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
   auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
   return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
-                    op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
-                    smemObjB, true, thread);
+                    op.getD(), llA, llB, llC, op.getAllowTF32(),
+                    op.getMaxNumImpreciseAcc(), smemObjA, smemObjB, true,
+                    thread);
 }

 LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
@@ -426,6 +469,7 @@ LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
   auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
   auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
   return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
-                    op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
-                    smemObjB, false, thread);
+                    op.getD(), llA, llB, llC, op.getAllowTF32(),
+                    op.getMaxNumImpreciseAcc(), smemObjA, smemObjB, false,
+                    thread);
 }
@@ -342,7 +342,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
     c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);

     addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
-                      op, retType, a, b, c, adaptor.getAllowTF32()),
+                      op, retType, a, b, c, adaptor.getAllowTF32(),
+                      adaptor.getMaxNumImpreciseAcc()),
                   adaptor.getAttributes());
     return success();
   }
@@ -181,8 +181,9 @@ public:
         op->getLoc(), newAccType,
         rewriter.create<arith::ConstantOp>(op->getLoc(),
                                            rewriter.getF32FloatAttr(0)));
-    rewriter.replaceOpWithNewOp<triton::DotOp>(
-        op, expandLhsOp.getOperand(), expandRhsOp.getOperand(), newAcc, true);
+    rewriter.replaceOpWithNewOp<triton::DotOp>(op, expandLhsOp.getOperand(),
+                                               expandRhsOp.getOperand(), newAcc,
+                                               true, 0);
     return mlir::success();
   }
 };
@@ -12,22 +12,24 @@ include "mlir/IR/PatternBase.td"
 // AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
 // AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
 def CombineDotAddIPattern : Pat<
-        (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
-        (TT_DotOp $a, $b, $d, $allowTF32),
+        (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc)),
+        (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc),
         [(Constraint<CPred<"isZero($0)">> $c)]>;
 def CombineDotAddFPattern : Pat<
-        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32), $fastmath),
-        (TT_DotOp $a, $b, $d, $allowTF32),
-        [(Constraint<CPred<"isZero($0)">> $c)]>;
+        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $fastmath),
+        (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc),
+        [(Constraint<CPred<"isZero($0)">> $c),
+         (Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc)]>;

 def CombineDotAddIRevPattern : Pat<
-        (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
-        (TT_DotOp $a, $b, $d, $allowTF32),
+        (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d),
+        (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc),
         [(Constraint<CPred<"isZero($0)">> $c)]>;
 def CombineDotAddFRevPattern : Pat<
-        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d, $fastmath),
-        (TT_DotOp $a, $b, $d, $allowTF32),
-        [(Constraint<CPred<"isZero($0)">> $c)]>;
+        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $fastmath),
+        (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc),
+        [(Constraint<CPred<"isZero($0)">> $c),
+         (Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc)]>;

 // TODO: this fails for addptr(addptr(ptr, i32), i64)
 // Commented out until fixed
@@ -316,7 +316,8 @@ public:
     }
     // convert dot instruction
     auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
-                                             newAcc, dotOp.getAllowTF32());
+                                             newAcc, dotOp.getAllowTF32(),
+                                             dotOp.getMaxNumImpreciseAcc());

     rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
                                                       newDot.getResult());
@@ -1640,7 +1640,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
     auto dotOp = cast<tt::DotOp>(dot.getDefiningOp());
     builder.setInsertionPoint(dot.getDefiningOp());
     auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
-        loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32());
+        loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
+        dotOp.getMaxNumImpreciseAcc());
     dot.replaceAllUsesWith(dotAsync.getResult());
     updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1);
     dot.getDefiningOp()->erase();
@@ -117,7 +117,8 @@ public:
         op->getLoc(), dotOp.getResult().getType(), _0f);
     auto newDot = rewriter.create<triton::DotOp>(
         op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0),
-        dotOp.getOperand(1), _0, dotOp.getAllowTF32());
+        dotOp.getOperand(1), _0, dotOp.getAllowTF32(),
+        dotOp.getMaxNumImpreciseAcc());
     auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
         op->getLoc(), dstTy, newDot.getResult());
     rewriter.replaceOpWithNewOp<arith::AddFOp>(op, newCvt, cvtOp.getOperand());
@@ -50,9 +50,9 @@ public:
                            .getEncoding()
                            .dyn_cast<ttg::MmaEncodingAttr>();
     auto isHopperEncoding = mmaEncoding && mmaEncoding.isHopper();
-    if (isHopperEncoding && (isa<ttg::ConvertLayoutOp>(a.getDefiningOp()) &&
+    if (isHopperEncoding && (a.getDefiningOp<ttg::ConvertLayoutOp>() &&
                              ttg::isSharedEncoding(a)) ||
-                            (isa<ttg::ConvertLayoutOp>(b.getDefiningOp()) &&
+                            (b.getDefiningOp<ttg::ConvertLayoutOp>() &&
                              ttg::isSharedEncoding(b))) {
       // TODO: check whether cluster fence is needed
@@ -736,7 +736,7 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
         auto dotAsync =
             builder.createWithAgentIds<triton::nvidia_gpu::DotAsyncOp>(
                 loc, dotOp.getA(), dotOp.getB(), dotOp.getC(),
-                dotOp.getAllowTF32());
+                dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
         dot.replaceAllUsesWith(dotAsync.getResult());
         builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(loc, 1);
@@ -1482,9 +1482,10 @@ void init_triton_ir(py::module &&m) {
            })
       .def("create_dot",
            [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b,
-              mlir::Value &c, bool allowTF32) -> mlir::Value {
-             return self.create<mlir::triton::DotOp>(c.getType(), a, b, c,
-                                                     allowTF32);
+              mlir::Value &c, bool allowTF32,
+              int maxNumImpreciseAcc) -> mlir::Value {
+             return self.create<mlir::triton::DotOp>(
+                 c.getType(), a, b, c, allowTF32, maxNumImpreciseAcc);
            })
       .def("create_exp",
            [](TritonOpBuilder &self, mlir::Value &val) -> mlir::Value {
@@ -131,8 +131,8 @@ def check_type_supported(dtype, device):
         cc = torch.cuda.get_device_capability()
         if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
             pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
-        if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4"):
-            pytest.skip("float8e4 is only supported on NVGPU with cc >= 90")
+        if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4nv"):
+            pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90")


 class MmaLayout:
@@ -3750,3 +3750,86 @@ def test_ptx_cast(dtype_str, device):
     buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype)
     kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2)
     assert buf14.to(torch.float32).mean() == -2.0
+
+
+# -----------------------
+# test fp8 -> fp32 dot
+# -----------------------
+
+
+def f8_to_f16(x, dtype):
+
+    @triton.jit
+    def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(0)
+        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offs < N
+        x = tl.load(X + offs, mask=mask)
+        tl.store(Y + offs, x, mask=mask)
+
+    ret = torch.empty(x.shape, dtype=torch.float16, device=x.device)
+    grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)
+    dtype = getattr(tl, dtype)
+    kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024)
+    return ret
+
+
+@triton.jit
+def matmul_kernel(
+        a_ptr, b_ptr, c_ptr,
+        M, N, K,
+        stride_am, stride_ak,
+        stride_bk, stride_bn,
+        stride_cm, stride_cn,
+        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+        low_precision_acc: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    pid_m = pid % num_pid_m
+    pid_n = pid // num_pid_m
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs)
+        b = tl.load(b_ptrs)
+        accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc)
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    tl.store(c_ptrs, accumulator)
+
+
+@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv'])
+@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128])
+def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
+    check_type_supported(in_type_str, device)
+    M, N, K = 128, 256, 256
+    BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128
+    A = numpy_random((M, K), dtype_str=in_type_str)
+    B = numpy_random((K, N), dtype_str=in_type_str)
+    Bt = B.T
+    C = torch.empty((M, N), dtype=torch.float32, device='cuda')
+    num_warps = 8
+    a = to_triton(A, device='cuda', dst_type=in_type_str)
+    b = to_triton(B, device='cuda', dst_type=in_type_str)
+    grid = (triton.cdiv(M, BLOCK_M), 1)
+    matmul_kernel[grid](a, b, C, M, N, K,
+                        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
+                        C.stride(0), C.stride(1),
+                        BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps)
+    torch_a = torch.from_numpy(A)
+    th_a = f8_to_f16(torch_a.cuda(), in_type_str)
+    torch_b = torch.from_numpy(B)
+    th_b = f8_to_f16(torch_b.cuda(), in_type_str)
+    ref_out = torch.matmul(th_a, th_b).to(torch.float32)
+    if in_type_str == 'float8e4nv':
+        torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01)
+    elif low_precision_acc > 32:
+        torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
+    else:
+        torch.testing.assert_close(ref_out, C)
@@ -26,61 +26,61 @@ def f8_to_f16(x, dtype):


 @pytest.mark.parametrize(
-    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32",
+    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM",
     itertools.chain(
         *[
             [
                 # 1 warp
-                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
                 # 2 warp
-                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
                 # 4 warp
-                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
                 # 8 warp
-                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
-                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
                 # variable input
-                (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True),
-                (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True),
-                (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True),
-                (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True),
+                (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True),
             ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
         ],
         # n-stage
        *[
            [
-                (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True),
-                (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True),
-                (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True),
-                (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True),
-                (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True),
+                (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True),
+                (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True),
+                (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4]
        ],
        # mixed-precision
        *[
            [
-                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True),
-                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True),
-                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True),
+                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
+                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
+                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
            ] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"),
                                     ("float8e4nv", "float8e4nv"),
                                     ("float8e5", "float8e4nv"),
@@ -91,14 +91,14 @@ def f8_to_f16(x, dtype):
                                     ("float16", "float32"),
                                     ("float32", "float16"),
                                     ("bfloat16", "float32"),
-                                    ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
+                                    ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]
        ],
        # mixed-precision block layout
        *[
            [
-                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False),
-                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False),
-                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False),
+                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
+                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
+                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True),
            ] for ADTYPE, BDTYPE in [("float8e4nv", "float16"),
                                     ("float16", "float8e5"),
                                     ("float16", "float32"),
@@ -108,7 +108,7 @@ def f8_to_f16(x, dtype):
        ],
    ),
 )
-def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32):
+def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM):
     capability = torch.cuda.get_device_capability()
     if capability[0] < 7:
         pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -176,7 +176,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
         a = triton.reinterpret(a, getattr(tl, ADTYPE))
         if b_fp8:
             b = triton.reinterpret(b, getattr(tl, BDTYPE))
-        tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32)
+        tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32, F8_FASTACCUM)
         torch.testing.assert_close(th_c, tt_c)
     except triton.OutOfResources as e:
         pytest.skip(str(e))
@@ -985,7 +985,7 @@ def expand_dims(input, axis, _builder=None):


 @builtin
-def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
+def dot(input, other, acc=None, allow_tf32=True, max_num_imprecise_acc=None, out_dtype=float32, _builder=None):
     """
     Returns the matrix product of two blocks.

@@ -998,7 +998,7 @@ def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
     """
     allow_tf32 = _constexpr_to_value(allow_tf32)
     out_dtype = _constexpr_to_value(out_dtype)
-    return semantic.dot(input, other, allow_tf32, out_dtype, _builder)
+    return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder)


 # -----------------------
@@ -1265,7 +1265,9 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:

 def dot(lhs: tl.tensor,
         rhs: tl.tensor,
+        acc: tl.tensor,
         allow_tf32: bool,
+        max_num_imprecise_acc: int,
         out_dtype: tl.dtype,
         builder: ir.builder) -> tl.tensor:
     def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch):
@@ -1343,10 +1345,20 @@ def dot(lhs: tl.tensor,
         ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
                         ret_ty)
         return cast(ret, ret_scalar_ty, builder)

-    _0 = builder.create_splat(_0, [M, N])
     ret_ty = tl.block_type(ret_scalar_ty, [M, N])
-    return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
+    if acc is None:
+        acc_handle = builder.create_splat(_0, [M, N])
+    else:
+        acc_handle = acc.handle
+        assert acc.type == ret_ty
+
+    # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
+    if not (_is_cuda(builder.arch) and builder.arch == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()):
+        max_num_imprecise_acc = 0
+    if max_num_imprecise_acc is None:
+        max_num_imprecise_acc = 2**30
+
+    return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc),
                      ret_ty)
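Restating the normalization above as a standalone sketch (the 2**30 sentinel is simply larger than any realistic K, so "no bound" means pure low-precision accumulation on the fast path):

# Model of the frontend normalization above (assumption: the caller-supplied
# value only matters for a Hopper (sm_90) fp8 x fp8 -> fp32 dot; on any other
# combination it is forced to 0, i.e. full-precision accumulation).
def normalize_max_num_imprecise_acc(value, is_sm90_fp8_to_fp32):
    if not is_sm90_fp8_to_fp32:
        return 0            # always accumulate at full precision
    if value is None:
        return 2**30        # effectively unbounded: accumulate entirely in wgmma
    return value

assert normalize_max_num_imprecise_acc(None, True) == 2**30
assert normalize_max_num_imprecise_acc(64, True) == 64
assert normalize_max_num_imprecise_acc(64, False) == 0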
@@ -82,6 +82,7 @@ def _kernel(A, B, C, M, N, K,
             stride_cm, stride_cn,
             dot_out_dtype: tl.constexpr,
             allow_tf32: tl.constexpr,
+            fp8_fast_accum: tl.constexpr,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
             GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
             ):
@@ -118,7 +119,10 @@ def _kernel(A, B, C, M, N, K,
         if AB_DTYPE:
             a = a.to(C.dtype.element_ty)
             b = b.to(C.dtype.element_ty)
-        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
+        if fp8_fast_accum:
+            acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
+        else:
+            acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
         A += BLOCK_K * SPLIT_K * stride_ak
         B += BLOCK_K * SPLIT_K * stride_bk
     acc = acc.to(C.dtype.element_ty)
@@ -140,7 +144,7 @@ class _matmul(torch.autograd.Function):
     _locks = {}

     @staticmethod
-    def _call(a, b, dot_out_dtype, allow_tf32):
+    def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum):
         device = a.device
         # handle non-contiguous inputs if necessary
         if a.stride(0) > 1 and a.stride(1) > 1:
@@ -182,12 +186,13 @@ class _matmul(torch.autograd.Function):
                 c.stride(0), c.stride(1),
                 dot_out_dtype=dot_out_dtype,
                 allow_tf32=allow_tf32,
+                fp8_fast_accum=fp8_fast_accum,
                 GROUP_M=8, AB_DTYPE=ab_dtype)
         return c

     @staticmethod
-    def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True):
-        return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
+    def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True):
+        return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum)


 matmul = _matmul.apply
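With the extended forward signature, the new flag is passed positionally, matching how the tests in this diff call it; a usage sketch (tensors here are illustrative, and the flag only changes numerics for fp8 inputs on Hopper):

import torch
import triton

a = torch.randn(512, 512, device='cuda', dtype=torch.float16)
b = torch.randn(512, 512, device='cuda', dtype=torch.float16)

# dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True (the default)
c_fast = triton.ops.matmul(a, b, None, True, True)
# Force the precise path: accumulate outside the MMA op
c_precise = triton.ops.matmul(a, b, None, True, False)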
@@ -26,7 +26,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
     %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
-    %c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
+    %c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>

     %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
     %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -34,7 +34,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
     // CHECK-NEXT: offset = 0, size = 4224
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>

-    %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
+    %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>

     %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
     %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -64,11 +64,11 @@ tt.func @reusable(%A : !tt.ptr<f16>) {
     %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
     // CHECK-NEXT: offset = 0, size = 4608
     %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
-    %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
+    %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
     %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
     // CHECK-NEXT: offset = 0, size = 1152
     %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT>
-    %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
+    %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
     tt.return
     // CHECK-NEXT: size = 4608
 }
@@ -32,7 +32,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
     %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
-    %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
+    %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>

     %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
     %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -6,7 +6,7 @@
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   tt.func @convert_dot(%A: tensor<16x16xf32, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) {
     // expected-error@+1 {{element types of operands A and B must have same bit width}}
-    %D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} :
+    %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} :
       tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
     tt.return
   }
@@ -20,7 +20,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   tt.func @convert_dot(%A: tensor<16x16xf16>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) {
     // expected-error@+1 {{mismatching encoding between A and B operands}}
-    %D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} :
+    %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} :
       tensor<16x16xf16> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
     tt.return
   }
@@ -34,7 +34,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) {
     // expected-error@+1 {{mismatching kWidth between A and B operands}}
-    %D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} :
+    %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} :
      tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
     tt.return
   }
@@ -161,13 +161,13 @@ tt.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
   %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>

   // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
-  %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
+  %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
   // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
-  %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
+  %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
   // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
-  %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
+  %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
   // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
-  %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>
+  %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>

   %ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
   %ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>
@@ -5,7 +5,7 @@ tt.func @ops() {
   %a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
   %b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
   %c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
-  %0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
+  %0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
   tt.return
 }
@@ -834,7 +834,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
   // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
   // CHECK: llvm.inline_asm
   // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
-  %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
+  %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>

   tt.return
 }
@@ -967,7 +967,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   %a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
   %b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>

-  %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
+  %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
   %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked>

   %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
@@ -993,7 +993,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   %a_mat = triton_gpu.convert_layout %a : (tensor<32x64xf16, #shared0>) -> tensor<32x64xf16, #dot_operand_a>
   %b_mat = triton_gpu.convert_layout %b : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #dot_operand_b>

-  %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma>
+  %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma>
   %38 = triton_gpu.convert_layout %28 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked>
   %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
   %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x64x!tt.ptr<f32>, #blocked>
@@ -1016,7 +1016,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   %a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
   %b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>

-  %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked>
+  %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = false, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked>
   %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
   %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
   tt.store %36, %28 : tensor<32x32xf32, #blocked>
@@ -1053,7 +1053,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
   // CHECK: llvm.inline_asm
   // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
-  %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
+  %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
   %38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>

   %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
@@ -1265,7 +1265,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   // CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
   // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
   %b_mat = arith.constant dense<1.000000e+00> : tensor<16x32xf32, #dot_operand_b>
-  %28 = tt.dot %a, %b_mat, %c {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
+  %28 = tt.dot %a, %b_mat, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
   %38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
   %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
   %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
@@ -1295,7 +1295,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
   %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
   %cst_2 = arith.constant dense<32> : tensor<32x1xi32, #blocked>
-  %0 = tt.dot %cst_0, %cst_1, %cst {allowTF32 = true} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
+  %0 = tt.dot %cst_0, %cst_1, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
   %1 = triton_gpu.convert_layout %0 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
   %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
   %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked>
@@ -78,3 +78,74 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     tt.return
   }
 }
+
+// -----
+
+#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
+#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
+  // CHECK-LABEL: @dot_high_precision_acc
+  tt.func @dot_high_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
+    // CHECK: nvgpu.wgmma
+    // CHECK-COUNT-128: llvm.fadd
+    // CHECK: nvgpu.wgmma
+    // CHECK-COUNT-128: llvm.fadd
+    // CHECK: nvgpu.wgmma
+    // CHECK-COUNT-128: llvm.fadd
+    // CHECK: nvgpu.wgmma
+    // CHECK-COUNT-128: llvm.fadd
+    %m = triton_nvidia_gpu.dot_async %a, %b, %c
+      {maxNumImpreciseAcc = 32 : i32, allowTF32 = true} :
+      tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
+    tt.return
+  }
+}
+
+// -----
+
+#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
+#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
+  // CHECK-LABEL: @dot_low_precision_acc
+  tt.func @dot_low_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
+    // CHECK: nvgpu.wgmma
+    // CHECK-NOT: llvm.fadd
+    // CHECK: nvgpu.wgmma
+    // CHECK-NOT: llvm.fadd
+    // CHECK: nvgpu.wgmma
+    // CHECK-NOT: llvm.fadd
+    // CHECK: nvgpu.wgmma
+    // CHECK-NOT: llvm.fadd
+    // CHECK: llvm.return
+    %m = triton_nvidia_gpu.dot_async %a, %b, %c
+      {maxNumImpreciseAcc = 129 : i32, allowTF32 = true} :
+      tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
+    tt.return
+  }
+}
+
+// -----
+
+#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
+#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
+  // CHECK-LABEL: @dot_mix_precision_acc
+  tt.func @dot_mix_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
+    // CHECK: nvgpu.wgmma
+    // CHECK-NOT: llvm.fadd
+    // CHECK: nvgpu.wgmma
+    // CHECK-COUNT-128: llvm.fadd
+    // CHECK: nvgpu.wgmma
+    // CHECK-NOT: llvm.fadd
+    // CHECK: nvgpu.wgmma
+    // CHECK-COUNT-128: llvm.fadd
+    // CHECK: llvm.return
+    %m = triton_nvidia_gpu.dot_async %a, %b, %c
+      {maxNumImpreciseAcc = 64 : i32, allowTF32 = true} :
+      tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
+    tt.return
+  }
+}
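The three thresholds above line up with the needsPartialAccumulator predicate from the WGMMA lowering and a per-wgmma K of 32 (instrShape = [16, 256, 32], tensor K = 128, so four wgmma ops per dot); a quick model:

# Model of needsPartialAccumulator from the lowering earlier in this diff
# (assumption: fp8 A operand and f32 accumulator, as in these three tests).
def needs_partial_accumulator(max_num_imprecise_acc, k_dim):
    return max_num_imprecise_acc <= k_dim

assert needs_partial_accumulator(32, 128)       # high precision: fadd after every wgmma
assert needs_partial_accumulator(64, 128)       # mixed: fadd after every other wgmma
assert not needs_partial_accumulator(129, 128)  # low precision: no fadds at all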
@@ -17,3 +17,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2
    tt.return
  }
} // end module

// -----

module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
  tt.func @wgmma_no_acc(%descA: i64, %descB: i64) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127}, $128, $129, 0, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l" %0, %1 : (i64, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
%acc0 = nvgpu.wgmma %descA, %descB
{eltTypeA = 3 : i32, eltTypeB = 3 : i32, eltTypeC = 7 : i32, k = 32 : i32, layoutA = 0 : i32, layoutB = 1 : i32, m = 64 : i32, n = 256 : i32} :
(i64, i64) ->
!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
tt.return
}
}
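
// Note: the wgmma above takes no accumulator operand; in the PTX this shows up
// as the trailing "0, 1, 1" immediates, whose leading 0 is scale-d, so the
// instruction computes D = A * B rather than D = A * B + C. For contrast, a
// minimal sketch of the same op with an explicit accumulator %acc0 (struct
// types elided, operand value assumed):
//
//   %acc1 = nvgpu.wgmma %descA, %descB, %acc0
//     {eltTypeA = 3 : i32, eltTypeB = 3 : i32, eltTypeC = 7 : i32, k = 32 : i32,
//      layoutA = 0 : i32, layoutB = 1 : i32, m = 64 : i32, n = 256 : i32} :
//     (i64, i64, !llvm.struct<(...)>) -> !llvm.struct<(...)>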

@@ -10,12 +10,12 @@ tt.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128x
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
%d = arith.constant dense<3.0> : tensor<128x128xf32>

%dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%dot_out = tt.dot %a, %b, %zero {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>

// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res0 = arith.addf %dot_out, %d : tensor<128x128xf32>

// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>

tt.return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
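
// Note: maxNumImpreciseAcc = 0 : i32 is the conservative default materialized
// on existing tt.dot ops here and in the tests below: read literally, it
// allows zero low-precision accumulation steps, so every partial product is
// accumulated at the accumulator's precision. The combine pattern itself is
// unchanged: a dot into a %zero accumulator followed by arith.addf with %d
// still folds into a single tt.dot taking %d as its initial accumulator, with
// the new attribute carried along.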

@@ -1543,7 +1543,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
%26 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>>
%27 = triton_gpu.convert_layout %25 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>>
%28 = triton_gpu.convert_layout %cst : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #blocked5>
%29 = tt.dot %26, %27, %28 {allowTF32 = true} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5>
%29 = tt.dot %26, %27, %28 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5>
%30 = triton_gpu.convert_layout %29 : (tensor<32x32xf32, #blocked5>) -> tensor<32x32xf32, #blocked>
%31:2 = "tt.reduce"(%30, %11) <{axis = 1 : i32}> ({
^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32):
@@ -1690,7 +1690,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%117 = tt.load %116 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked3>
%118 = triton_gpu.convert_layout %41 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
%119 = triton_gpu.convert_layout %97 : (tensor<64x64xf16, #blocked6>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
%120 = tt.dot %118, %119, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked>
%120 = tt.dot %118, %119, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked>
%121 = triton_gpu.convert_layout %120 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #blocked2>
%122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2>
%123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({
@@ -1719,7 +1719,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%142 = triton_gpu.convert_layout %141 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
%143 = triton_gpu.convert_layout %117 : (tensor<64x64xf16, #blocked3>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
%144 = triton_gpu.convert_layout %140 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #blocked>
%145 = tt.dot %142, %143, %144 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked>
%145 = tt.dot %142, %143, %144 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked>
%146 = triton_gpu.convert_layout %145 : (tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked2>
%147 = arith.mulf %arg24, %127 : tensor<128xf32, #blocked1>
%148 = "tt.reduce"(%133) <{axis = 1 : i32}> ({

@@ -36,7 +36,7 @@ tt.func @push_elementwise(
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2k4>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2k4>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2>
tt.return %newc : tensor<16x16xf32, #Cv2>
}

@@ -58,7 +58,7 @@ tt.func @succeeds_if_arg_is_not_convert_layout(
%dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4>
%dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2k4>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2>
tt.return %newc : tensor<16x16xf32, #Cv2>
}

@@ -82,7 +82,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capabil
// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
tt.func @push_convert_both_operands(
%pa: tensor<16x16x!tt.ptr<f16>, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
@@ -93,7 +93,7 @@ tt.func @push_convert_both_operands(
%be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB>
%al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%bl = triton_gpu.convert_layout %be : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
%r = tt.dot %al, %bl, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
tt.return %r : tensor<16x16xf32, #mma>
}

@@ -119,7 +119,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capabil
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
tt.func @update_kwidth_slice(
%pa: tensor<16x16x!tt.ptr<f16>, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
@@ -132,7 +132,7 @@ tt.func @update_kwidth_slice(
%add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB>
%al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%bl = triton_gpu.convert_layout %add : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
%r = tt.dot %al, %bl, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
tt.return %r : tensor<16x16xf32, #mma>
}


@@ -74,7 +74,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -151,7 +151,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -220,7 +220,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
@@ -293,7 +293,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
//
// %sa = triton_gpu.convert_layout %a : (tensor<128x32xf16, #BA>) -> tensor<128x32xf16, #SA>
// %sb = triton_gpu.convert_layout %b : (tensor<32x128xf16, #BB>) -> tensor<32x128xf16, #SB>
// %c = tt.dot %sa, %sb, %prev_c {allowTF32 = true} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C>
// %c = tt.dot %sa, %sb, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C>
//
// %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr<tensor<128x32xf16>, 1>
// %b_tileptr_next = tt.advance %b_tileptr, [%c32_i32, %c0] : !tt.ptr<tensor<32x128xf16>, 1>

@@ -84,7 +84,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
%b_ = triton_gpu.convert_layout %b__ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B>

%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -157,7 +157,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -224,7 +224,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
%loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
@@ -266,7 +266,7 @@ tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
%87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A>
%89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B>
%90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C>
%90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C>
%91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr<f16>, #AL>, tensor<16x16xi32, #AL>
%92 = tt.addptr %arg21, %c1_i32 : !tt.ptr<i64>, i32
scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr<f16>, #AL>, !tt.ptr<i64>
@@ -312,7 +312,7 @@ tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt
%87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A>
%89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B>
%90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C>
%90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C>
%91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr<f16>, #AL>, tensor<16x16xi32, #AL>
%92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr<i64>, #BLs1>, tensor<16xi32, #BLs1>
scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr<f16>, #AL>, tensor<16x!tt.ptr<i64>, #BLs1>
@@ -362,7 +362,7 @@ tt.func @post_load_inv(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%116 = tt.load %arg12, %115, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
%117 = triton_gpu.convert_layout %112 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>>
%118 = triton_gpu.convert_layout %116 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>>
%119 = tt.dot %117, %118, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
%119 = tt.dot %117, %118, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
%131 = arith.index_cast %arg9 : index to i32
%120 = arith.addi %131, %c1_i32 : i32
%121 = arith.muli %120, %c32_i32 : i32
@@ -425,7 +425,7 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%150 = tt.load %arg12, %149, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
%151 = triton_gpu.convert_layout %146 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>>
%152 = triton_gpu.convert_layout %150 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>>
%153 = tt.dot %151, %152, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
%153 = tt.dot %151, %152, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
%162 = arith.index_cast %arg9 : index to i32
%154 = arith.addi %162, %c2_i32 : i32
%155 = arith.muli %154, %c32_i32 : i32
@@ -497,7 +497,7 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%199 = tt.load %arg24, %198, %88 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%200 = triton_gpu.convert_layout %193 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>>
%201 = triton_gpu.convert_layout %199 : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>>
%202 = tt.dot %200, %201, %arg23 {allowTF32 = true} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C>
%202 = tt.dot %200, %201, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C>
%203 = tt.addptr %arg24, %90 : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi64, #BL>
scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr<i32>, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr<f16>, #BL>
}

@@ -52,7 +52,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%8 = tt.load %6 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x16xf16, #blockedB0>, 1> -> tensor<16x16xf16, #blockedB1>
%9 = triton_gpu.convert_layout %7 : (tensor<64x16xf16, #blockedA1>) -> tensor<64x16xf16, #sharedA>
%10 = triton_gpu.convert_layout %8 : (tensor<16x16xf16, #blockedB1>) -> tensor<16x16xf16, #sharedB>
%11 = tt.dot %9, %10, %cst {allowTF32 = true} : tensor<64x16xf16, #sharedA> * tensor<16x16xf16, #sharedB> -> tensor<64x16xf32, #mma>
%11 = tt.dot %9, %10, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #sharedA> * tensor<16x16xf16, #sharedB> -> tensor<64x16xf32, #mma>
%12 = triton_gpu.convert_layout %11 : (tensor<64x16xf32, #mma>) -> tensor<64x16xf32, #blockedA1>
%13 = arith.truncf %12 : tensor<64x16xf32, #blockedA1> to tensor<64x16xf16, #blockedA1>
%14 = arith.extsi %arg8 : i32 to i64

@@ -62,7 +62,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>
%77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>
%78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
%78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
%79 = arith.addf %arg13, %78 : tensor<64x64xf32>
%80 = arith.muli %arg7, %c64_i32 : i32
%81 = tt.splat %80 : (i32) -> tensor<64x64xi32>

@@ -53,7 +53,7 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr
%a_op_ = triton_gpu.convert_layout %a : (tensor<128x32xf8E5M2, #A>) -> tensor<128x32xf8E5M2, #A_OP>
%a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP>
%b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP>
%c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>
%c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>

@@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
%9 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked>
%10 = triton_gpu.convert_layout %9 : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #shared>
%11 = triton_gpu.convert_layout %10 : (tensor<32x32xf32, #shared>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
%12 = tt.dot %11, %cst_0, %cst {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
%12 = tt.dot %11, %cst_0, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
%13 = triton_gpu.convert_layout %12 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
tt.store %arg0, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked>
tt.return
@@ -41,7 +41,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
%A = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked>
%AS = triton_gpu.convert_layout %A : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #shared>
%AD = triton_gpu.convert_layout %AS : (tensor<32x32xf32, #shared>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
%12 = tt.dot %AD, %BD, %cst {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
%12 = tt.dot %AD, %BD, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
%13 = triton_gpu.convert_layout %12 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
tt.store %arg0, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked>
tt.return

@@ -46,7 +46,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%30 = triton_gpu.convert_layout %28 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>
%31 = triton_gpu.convert_layout %29 : (tensor<64x128xf16, #blocked1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>>
%32 = triton_gpu.convert_layout %arg12 : (tensor<128x128xf32, #blocked>) -> tensor<128x128xf32, #blocked2>
%33 = tt.dot %30, %31, %32 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2>
%33 = tt.dot %30, %31, %32 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2>
%34 = triton_gpu.convert_layout %33 : (tensor<128x128xf32, #blocked2>) -> tensor<128x128xf32, #blocked>
// CHECK-NOT: tt.advance
%35 = tt.advance %arg13, [%c0_i32, %c64_i32] : <tensor<128x64xf16, #blocked>, 1>

@@ -97,7 +97,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
%91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -208,7 +208,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: %90 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
// CHECK-NEXT: %91 = triton_gpu.convert_layout %89 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
// CHECK-NEXT: %92 = triton_gpu.convert_layout %90 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
// CHECK-NEXT: %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %94 = tt.addptr %arg11, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
// CHECK-NEXT: %95 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
// CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -336,7 +336,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
scf.yield %91, %92 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>
}
%93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -452,7 +452,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: %96 = triton_gpu.convert_layout %94 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
// CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %95, %96 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>
// CHECK-NEXT: } {async_agent = dense<[0, 1]> : vector<2xi32>}
// CHECK-NEXT: %90 = tt.dot %89#0, %89#1, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %90 = tt.dot %89#0, %89#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %91 = tt.addptr %arg11, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
// CHECK-NEXT: %92 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
// CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %90, %91, %92 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -587,7 +587,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
%91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%base_94 = scf.if %96 -> (tensor<128x32x!tt.ptr<f16, 1>, #blocked1>) {
%r1_0 = arith.select %96, %c31_i32, %c127_i32 : i32
%r1_1 = tt.splat %r1_0 : (i32) -> tensor<128x32xi32, #blocked1>
@@ -717,7 +717,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: %92 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
// CHECK-NEXT: %93 = triton_gpu.convert_layout %91 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
// CHECK-NEXT: %94 = triton_gpu.convert_layout %92 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
// CHECK-NEXT: %95 = tt.dot %93, %94, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %95 = tt.dot %93, %94, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %96 = scf.if %90 -> (tensor<128x32x!tt.ptr<f16, 1>, #blocked1>) {
// CHECK-NEXT: %99 = arith.select %90, %c31_i32, %c127_i32 {async_agent = dense<1> : vector<1xi32>} : i32
// CHECK-NEXT: %100 = tt.splat %99 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x32xi32, #blocked1>

@@ -177,7 +177,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%63 = triton_gpu.extract_slice %0[%arg11, 0, 0] [1, 32, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x32x128xf16, #shared> to tensor<32x128xf16, #shared>
%64 = triton_gpu.convert_layout %62 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #shared1>) -> tensor<128x32xf16, #shared1>
%65 = triton_gpu.convert_layout %63 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #shared>
%66 = tt.dot %64, %65, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma>
%66 = tt.dot %64, %65, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma>
%c1_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
%c3_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 3 : i32
%67 = arith.addi %arg11, %c1_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32
@@ -384,7 +384,7 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability"
%50 = triton_gpu.convert_layout %49 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
%51 = triton_gpu.extract_slice %1[%48, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
%52 = triton_gpu.convert_layout %51 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
%53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
triton_nvidia_gpu.consumer_release %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%c1_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
%54 = arith.addi %arg13, %c1_i32_6 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32

@@ -141,7 +141,7 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability"
%40 = triton_gpu.convert_layout %39 {async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
%41 = triton_gpu.extract_slice %1[%38, 0, 0] [1, 16, 64] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
%42 = triton_gpu.convert_layout %41 {async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
%43 = tt.dot %40, %42, %arg12 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%43 = tt.dot %40, %42, %arg12 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
triton_nvidia_gpu.consumer_release %2, %38 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%c1_i32_5 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
%44 = arith.addi %arg13, %c1_i32_5 {async_agent = dense<1> : vector<1xi32>} : i32

@@ -120,7 +120,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
%91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>

@@ -96,7 +96,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
%91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -226,7 +226,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
scf.yield %91, %92 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>
}
%93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -362,7 +362,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
%91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%base_94 = scf.if %96 -> (tensor<128x32x!tt.ptr<f16, 1>, #blocked1>) {
%r1_0 = arith.select %96, %c31_i32, %c127_i32 : i32
%r1_1 = tt.splat %r1_0 : (i32) -> tensor<128x32xi32, #blocked1>
@@ -438,7 +438,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%47 = tt.load %arg12 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<16x64xf16, #blocked3>
%48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
%49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%51 = tt.advance %arg11, [%c0_i32, %c16_i32] : <tensor<64x16xf16, #blocked>, 1>
%52 = tt.advance %arg12, [%c16_i32, %c0_i32] : <tensor<16x64xf16, #blocked1>, 1>
scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>
@@ -518,7 +518,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%47 = tt.load %arg12 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<16x64xf16, #blocked3>
%48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
%49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%51 = tt.advance %arg11, [%c0_i32, %c16_i32] : <tensor<64x16xf16, #blocked>, 1>
%52 = tt.advance %arg12, [%c16_i32, %c0_i32] : <tensor<16x64xf16, #blocked1>, 1>
scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>
@@ -600,7 +600,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%47 = tt.load %arg12 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<16x64xf16, #blocked3>
%48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
%49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%51 = tt.advance %arg11, [%c0_i32, %c16_i32] : <tensor<64x16xf16, #blocked>, 1>
%52 = tt.advance %arg12, [%c16_i32, %c0_i32] : <tensor<16x64xf16, #blocked1>, 1>
scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>
@@ -686,7 +686,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%47 = tt.load %arg12 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<16x64xf16, #blocked3>
%48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
%49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%51 = tt.advance %arg11, [%c0_i32, %c16_i32] : <tensor<64x16xf16, #blocked>, 1>
%52 = tt.advance %arg12, [%c16_i32, %c0_i32] : <tensor<16x64xf16, #blocked1>, 1>
scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>
@@ -799,7 +799,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%44 = tt.load %arg17 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<64x128xf16, #blocked1>, 1> -> tensor<64x128xf16, #blocked4>
%45 = triton_gpu.convert_layout %43 : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared>
%46 = triton_gpu.convert_layout %44 : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1>
%47 = tt.dot %45, %46, %arg15 {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%47 = tt.dot %45, %46, %arg15 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%48 = tt.advance %arg16, [%c0_i32, %c64_i32] : <tensor<256x64xf16, #blocked>, 1>
%49 = tt.advance %arg17, [%c64_i32, %c0_i32] : <tensor<64x128xf16, #blocked1>, 1>
scf.yield %47, %48, %49 : tensor<256x128xf32, #mma>, !tt.ptr<tensor<256x64xf16, #blocked>, 1>, !tt.ptr<tensor<64x128xf16, #blocked1>, 1>
@@ -852,7 +852,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%b = tt.load %arg1 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<64x128xf16, #blocked1>, 1> -> tensor<64x128xf16, #blocked4>
%shm_a = triton_gpu.convert_layout %a : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared>
%shm_b = triton_gpu.convert_layout %b : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1>
%d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%out = triton_gpu.convert_layout %d : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2>
tt.store %arg2, %out {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2>
}
@@ -887,7 +887,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%b = tt.load %arg1 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x128x!tt.ptr<f16, 1>, #blocked4> -> tensor<64x128xf16, #blocked4>
%shm_a = triton_gpu.convert_layout %a : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared>
%shm_b = triton_gpu.convert_layout %b : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1>
%d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%out = triton_gpu.convert_layout %d : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2>
tt.store %arg2, %out {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2>
}
@@ -1000,7 +1000,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%92 = tt.load %arg19 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<16x64xf16, #blocked4>
%93 = triton_gpu.convert_layout %91 : (tensor<64x16xf16, #blocked3>) -> tensor<64x16xf16, #shared>
%94 = triton_gpu.convert_layout %92 : (tensor<16x64xf16, #blocked4>) -> tensor<16x64xf16, #shared1>
%95 = tt.dot %93, %94, %arg17 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%95 = tt.dot %93, %94, %arg17 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%96 = tt.advance %arg18, [%c0_i32, %c16_i32] : <tensor<64x16xf16, #blocked>, 1>
%97 = tt.advance %arg19, [%c16_i32, %c0_i32] : <tensor<16x64xf16, #blocked1>, 1>
scf.yield %95, %96, %97 : tensor<64x64xf32, #mma>, !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>