[FRONTEND][BACKEND] Add flag to control accumulation for fp8 (#2300)

Change the dot to allow taking an initial accumulator and add a flag
that will allow the compiler to accumulate in a lower precision than the
output type.
On Hopper this flag is on by default which allows accumualting with
lower precision.
This only affect Hopper fp8 dot.
This commit is contained in:
Thomas Raoux
2023-09-15 18:42:54 -07:00
committed by GitHub
parent 78a0b5dc2a
commit 31b0c52142
44 changed files with 431 additions and 178 deletions

View File

@@ -148,12 +148,12 @@ def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType",
def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">;
def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC,
let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, Optional<LLVM_AnyStruct>:$opC,
I32Attr:$m, I32Attr:$n, I32Attr:$k,
WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB,
WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB);
let results = (outs LLVM_AnyStruct:$res);
let assemblyFormat = "$opA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)";
let assemblyFormat = "$opA `,` $opB (`,` $opC^)? attr-dict `:` functional-type(operands, $res)";
}
def NVGPU_CGABarrierSyncOp : NVGPU_Op<"cga_barrier_sync", []> {

View File

@@ -394,7 +394,12 @@ def TT_DotOp : TT_Op<"dot", [Pure,
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let arguments = (ins
TT_FpIntTensor:$a,
TT_FpIntTensor:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
I32Attr:$maxNumImpreciseAcc);
let results = (outs TT_FpIntTensor:$d);

View File

@@ -258,7 +258,11 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let arguments = (ins TT_FpIntTensor:$a,
TT_FpIntTensor:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
I32Attr:$maxNumImpreciseAcc);
let results = (outs TT_FpIntTensor:$d);

View File

@@ -379,6 +379,12 @@ bool supportMMA(triton::DotOp op, int version) {
aElemTy.isF32()))) {
return false;
}
// We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
if (op.getMaxNumImpreciseAcc() < 32 &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
op.getType().cast<RankedTensorType>().getElementType().isF32()) {
return false;
}
}
if (aElemTy.isF32() && bElemTy.isF32()) {
return op.getAllowTF32() && version >= 2;

View File

@@ -708,13 +708,13 @@ public:
// TODO (zahi): Return type must always be a struct for wgmma, currently
// we rely on the size of output constraints vector to determine whether
// the output is a struct or not. We should find a way to pass this info
auto opC = op.getOpC();
auto typeC = opC.getType();
auto resultType = op.getType();
auto structTypeC = typeC.dyn_cast<LLVM::LLVMStructType>();
uint32_t numCRegs = structTypeC.getBody().size();
std::string c = structTypeC.getBody().front().isF32() ? "=f" : "=r";
return std::vector<std::string>(numCRegs, c);
auto outputStructType = resultType.dyn_cast<LLVM::LLVMStructType>();
uint32_t numOutputRegs = outputStructType.getBody().size();
std::string output =
outputStructType.getBody().front().isF32() ? "=f" : "=r";
return std::vector<std::string>(numOutputRegs, output);
}
OperandsAndConstraints getOperandsAndConstraints(ttn::WGMMAOp op) const {
@@ -727,7 +727,8 @@ public:
auto structTypeA = typeA.dyn_cast<LLVM::LLVMStructType>();
// TODO (zahi): is this the best way to tie inputs/outputs ?
operandsAndConstraints.push_back({opC, "0"});
if (opC)
operandsAndConstraints.push_back({opC, "0"});
if (structTypeA) {
operandsAndConstraints.push_back({opA, "f"});
@@ -744,7 +745,6 @@ public:
using namespace ttn;
auto opA = op.getOpA();
auto opB = op.getOpB();
auto opC = op.getOpC();
auto m = op.getM();
auto n = op.getN();
auto k = op.getK();
@@ -757,12 +757,12 @@ public:
// Register checks
auto typeA = opA.getType();
auto typeB = opB.getType();
auto typeC = opC.getType();
auto typeOutput = op.getType();
auto structTypeA = typeA.dyn_cast<LLVM::LLVMStructType>();
auto structTypeB = typeB.dyn_cast<LLVM::LLVMStructType>();
auto structTypeC = typeC.dyn_cast<LLVM::LLVMStructType>();
auto structTypeOutput = typeOutput.dyn_cast<LLVM::LLVMStructType>();
assert(!structTypeB && "Operand B can not be registers");
assert(structTypeC && "Operand C must be registers");
assert(structTypeOutput && "Output and C operand must be registers");
// Element type, MNK shape and transposing support check
// Reference:
@@ -804,18 +804,20 @@ public:
// Operands
uint32_t asmOpIdx = 0;
// Operand C
uint32_t numCRegs = structTypeC.getBody().size();
std::string args = "";
// Output and operand C
uint32_t numCRegs = structTypeOutput.getBody().size();
args += "{";
for (uint32_t i = 0; i < numCRegs; ++i) {
args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ",");
}
args += "}, ";
asmOpIdx += numCRegs;
if (op.getOpC())
asmOpIdx += numCRegs;
// Operand A
if (structTypeA) {
uint32_t numARegs = m * k / 128;
@@ -833,8 +835,8 @@ public:
// Operand B (must be `desc`)
args += "$" + std::to_string(asmOpIdx++) + ", ";
// `scale-d` is 1 by default
args += "1";
// `scale-d` is 1 if we have a C operand.
args += op.getOpC() ? "1" : "0";
// `imm-scale-a`, and `imm-scale-b` are 1 by default only for float-based
// WGMMA

View File

@@ -260,11 +260,30 @@ SmallVector<Value> unpackAccumulator(ConversionPatternRewriter &rewriter,
return results;
}
static bool isFP8(triton::nvgpu::WGMMAEltType eltType) {
return eltType == triton::nvgpu::WGMMAEltType::e5m2 ||
eltType == triton::nvgpu::WGMMAEltType::e4m3;
}
static Value faddAccumulate(ConversionPatternRewriter &rewriter, Location loc,
Value a, Value b) {
int numEl = a.getType().cast<LLVM::LLVMStructType>().getBody().size();
Value newStruct = rewriter.create<LLVM::UndefOp>(loc, a.getType());
for (int i = 0; i < numEl; ++i) {
Value lhs = rewriter.create<LLVM::ExtractValueOp>(loc, a, i);
Value rhs = rewriter.create<LLVM::ExtractValueOp>(loc, b, i);
Value add = rewriter.create<LLVM::FAddOp>(loc, lhs, rhs);
newStruct = rewriter.create<LLVM::InsertValueOp>(loc, newStruct, add, i);
}
return newStruct;
}
LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc,
Operation *op, Value a, Value b, Value c, Value d,
Value loadedA, Value loadedB, Value loadedC,
bool allowTF32, const SharedMemoryObject &smemObjA,
bool allowTF32, uint32_t maxNumImpreciseAcc,
const SharedMemoryObject &smemObjA,
const SharedMemoryObject &smemObjB, bool sync,
Value thread) {
auto aTensorTy = a.getType().cast<RankedTensorType>();
@@ -311,7 +330,10 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
if (numTMADescs == 0)
rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
// WGMMA fp8 -> fp32 accumulates in lower precision than fp32.
bool needsPartialAccumulator = isFP8(eltTypeA) &&
eltTypeC == triton::nvgpu::WGMMAEltType::f32 &&
maxNumImpreciseAcc <= aTensorTy.getShape()[1];
SmallVector<Value> mmaResults;
for (int m = 0; m < numRepM; ++m) {
for (int n = 0; n < numRepN; ++n) {
@@ -323,13 +345,33 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
auto accTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
Value d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy);
uint32_t numLowPrecisionAcc = 0;
Value partialAcc;
for (int k = 0; k < numRepK; ++k) {
auto a = aLoader.smemLoad(m, k);
auto b = bLoader.smemLoad(n, k);
ValueRange operands{a, b, d};
d = rewriter.create<triton::nvgpu::WGMMAOp>(loc, accTy, a, b, d, M, N,
K, eltTypeC, eltTypeA,
eltTypeB, layoutA, layoutB);
numLowPrecisionAcc += K;
// If using native accumulation would cause use to do more low precion
// accumulation than allowed do a separate allocation.
bool requireAddAccumulator =
needsPartialAccumulator &&
(numLowPrecisionAcc >= maxNumImpreciseAcc || k == numRepK - 1);
Value mmaAcc = needsPartialAccumulator ? partialAcc : d;
mmaAcc = rewriter.create<triton::nvgpu::WGMMAOp>(
loc, accTy, a, b, mmaAcc, M, N, K, eltTypeC, eltTypeA, eltTypeB,
layoutA, layoutB);
if (needsPartialAccumulator)
partialAcc = mmaAcc;
else
d = mmaAcc;
// If we need accumulate separately to have higer precision, insert
// adds.
if (requireAddAccumulator) {
d = faddAccumulate(rewriter, loc, d, partialAcc);
numLowPrecisionAcc = 0;
partialAcc = Value();
}
}
auto acc = typeConverter->unpackLLElements(loc, d, rewriter, accTy);
for (int i = 0; i < acc.size(); ++i) {
@@ -398,8 +440,9 @@ LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
smemObjB, true, thread);
op.getD(), llA, llB, llC, op.getAllowTF32(),
op.getMaxNumImpreciseAcc(), smemObjA, smemObjB, true,
thread);
}
LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
@@ -426,6 +469,7 @@ LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
smemObjB, false, thread);
op.getD(), llA, llB, llC, op.getAllowTF32(),
op.getMaxNumImpreciseAcc(), smemObjA, smemObjB, false,
thread);
}

View File

@@ -342,7 +342,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, c, adaptor.getAllowTF32()),
op, retType, a, b, c, adaptor.getAllowTF32(),
adaptor.getMaxNumImpreciseAcc()),
adaptor.getAttributes());
return success();
}

View File

@@ -181,8 +181,9 @@ public:
op->getLoc(), newAccType,
rewriter.create<arith::ConstantOp>(op->getLoc(),
rewriter.getF32FloatAttr(0)));
rewriter.replaceOpWithNewOp<triton::DotOp>(
op, expandLhsOp.getOperand(), expandRhsOp.getOperand(), newAcc, true);
rewriter.replaceOpWithNewOp<triton::DotOp>(op, expandLhsOp.getOperand(),
expandRhsOp.getOperand(), newAcc,
true, 0);
return mlir::success();
}
};

View File

@@ -12,22 +12,24 @@ include "mlir/IR/PatternBase.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc)),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32), $fastmath),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $fastmath),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc)]>;
def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d, $fastmath),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $fastmath),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc)]>;
// TODO: this fails for addptr(addptr(ptr, i32), i64)
// Commented out until fixed

View File

@@ -316,7 +316,8 @@ public:
}
// convert dot instruction
auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
newAcc, dotOp.getAllowTF32());
newAcc, dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());
rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
newDot.getResult());

View File

@@ -1640,7 +1640,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
auto dotOp = cast<tt::DotOp>(dot.getDefiningOp());
builder.setInsertionPoint(dot.getDefiningOp());
auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32());
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());
dot.replaceAllUsesWith(dotAsync.getResult());
updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1);
dot.getDefiningOp()->erase();

View File

@@ -117,7 +117,8 @@ public:
op->getLoc(), dotOp.getResult().getType(), _0f);
auto newDot = rewriter.create<triton::DotOp>(
op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0),
dotOp.getOperand(1), _0, dotOp.getAllowTF32());
dotOp.getOperand(1), _0, dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), dstTy, newDot.getResult());
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, newCvt, cvtOp.getOperand());

View File

@@ -50,9 +50,9 @@ public:
.getEncoding()
.dyn_cast<ttg::MmaEncodingAttr>();
auto isHopperEncoding = mmaEncoding && mmaEncoding.isHopper();
if (isHopperEncoding && (isa<ttg::ConvertLayoutOp>(a.getDefiningOp()) &&
if (isHopperEncoding && (a.getDefiningOp<ttg::ConvertLayoutOp>() &&
ttg::isSharedEncoding(a)) ||
(isa<ttg::ConvertLayoutOp>(b.getDefiningOp()) &&
(b.getDefiningOp<ttg::ConvertLayoutOp>() &&
ttg::isSharedEncoding(b))) {
// TODO: check whether cluster fence is needed

View File

@@ -736,7 +736,7 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
auto dotAsync =
builder.createWithAgentIds<triton::nvidia_gpu::DotAsyncOp>(
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(),
dotOp.getAllowTF32());
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
dot.replaceAllUsesWith(dotAsync.getResult());
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(loc, 1);

View File

@@ -1482,9 +1482,10 @@ void init_triton_ir(py::module &&m) {
})
.def("create_dot",
[](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b,
mlir::Value &c, bool allowTF32) -> mlir::Value {
return self.create<mlir::triton::DotOp>(c.getType(), a, b, c,
allowTF32);
mlir::Value &c, bool allowTF32,
int maxNumImpreciseAcc) -> mlir::Value {
return self.create<mlir::triton::DotOp>(
c.getType(), a, b, c, allowTF32, maxNumImpreciseAcc);
})
.def("create_exp",
[](TritonOpBuilder &self, mlir::Value &val) -> mlir::Value {

View File

@@ -131,8 +131,8 @@ def check_type_supported(dtype, device):
cc = torch.cuda.get_device_capability()
if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4"):
pytest.skip("float8e4 is only supported on NVGPU with cc >= 90")
if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4nv"):
pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90")
class MmaLayout:
@@ -3750,3 +3750,86 @@ def test_ptx_cast(dtype_str, device):
buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype)
kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2)
assert buf14.to(torch.float32).mean() == -2.0
# -----------------------
# test fp8 -> fp32 dot
# -----------------------
def f8_to_f16(x, dtype):
@triton.jit
def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < N
x = tl.load(X + offs, mask=mask)
tl.store(Y + offs, x, mask=mask)
ret = torch.empty(x.shape, dtype=torch.float16, device=x.device)
grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)
dtype = getattr(tl, dtype)
kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024)
return ret
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
low_precision_acc: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(c_ptrs, accumulator)
@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv'])
@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128])
def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
check_type_supported(in_type_str, device)
M, N, K = 128, 256, 256
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128
A = numpy_random((M, K), dtype_str=in_type_str)
B = numpy_random((K, N), dtype_str=in_type_str)
Bt = B.T
C = torch.empty((M, N), dtype=torch.float32, device='cuda')
num_warps = 8
a = to_triton(A, device='cuda', dst_type=in_type_str)
b = to_triton(B, device='cuda', dst_type=in_type_str)
grid = (triton.cdiv(M, BLOCK_M), 1)
matmul_kernel[grid](a, b, C, M, N, K,
a.stride(0), a.stride(1), b.stride(0), b.stride(
1), C.stride(0), C.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps)
torch_a = torch.from_numpy(A)
th_a = f8_to_f16(torch_a.cuda(), in_type_str)
torch_b = torch.from_numpy(B)
th_b = f8_to_f16(torch_b.cuda(), in_type_str)
ref_out = torch.matmul(th_a, th_b).to(torch.float32)
if in_type_str == 'float8e4nv':
torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01)
elif low_precision_acc > 32:
torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
else:
torch.testing.assert_close(ref_out, C)

View File

@@ -26,61 +26,61 @@ def f8_to_f16(x, dtype):
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32",
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM",
itertools.chain(
*[
[
# 1 warp
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 2 warp
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 4 warp
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 8 warp
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# variable input
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True),
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True),
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True),
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True),
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True),
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
],
# n-stage
*[
[
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True),
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True),
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True),
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True),
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True),
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True),
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4]
],
# mixed-precision
*[
[
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True),
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"),
("float8e4nv", "float8e4nv"),
("float8e5", "float8e4nv"),
@@ -91,14 +91,14 @@ def f8_to_f16(x, dtype):
("float16", "float32"),
("float32", "float16"),
("bfloat16", "float32"),
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]
],
# mixed-precision block layout
*[
[
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False),
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True),
] for ADTYPE, BDTYPE in [("float8e4nv", "float16"),
("float16", "float8e5"),
("float16", "float32"),
@@ -108,7 +108,7 @@ def f8_to_f16(x, dtype):
],
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32):
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -176,7 +176,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
a = triton.reinterpret(a, getattr(tl, ADTYPE))
if b_fp8:
b = triton.reinterpret(b, getattr(tl, BDTYPE))
tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32)
tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32, F8_FASTACCUM)
torch.testing.assert_close(th_c, tt_c)
except triton.OutOfResources as e:
pytest.skip(str(e))

View File

@@ -985,7 +985,7 @@ def expand_dims(input, axis, _builder=None):
@builtin
def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
def dot(input, other, acc=None, allow_tf32=True, max_num_imprecise_acc=None, out_dtype=float32, _builder=None):
"""
Returns the matrix product of two blocks.
@@ -998,7 +998,7 @@ def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
"""
allow_tf32 = _constexpr_to_value(allow_tf32)
out_dtype = _constexpr_to_value(out_dtype)
return semantic.dot(input, other, allow_tf32, out_dtype, _builder)
return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder)
# -----------------------

View File

@@ -1265,7 +1265,9 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
def dot(lhs: tl.tensor,
rhs: tl.tensor,
acc: tl.tensor,
allow_tf32: bool,
max_num_imprecise_acc: int,
out_dtype: tl.dtype,
builder: ir.builder) -> tl.tensor:
def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch):
@@ -1343,10 +1345,20 @@ def dot(lhs: tl.tensor,
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)
return cast(ret, ret_scalar_ty, builder)
_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
if acc is None:
acc_handle = builder.create_splat(_0, [M, N])
else:
acc_handle = acc.handle
assert acc.type == ret_ty
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
if not (_is_cuda(builder.arch) and builder.arch == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()):
max_num_imprecise_acc = 0
if max_num_imprecise_acc is None:
max_num_imprecise_acc = 2**30
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc),
ret_ty)

View File

@@ -82,6 +82,7 @@ def _kernel(A, B, C, M, N, K,
stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
allow_tf32: tl.constexpr,
fp8_fast_accum: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
):
@@ -118,7 +119,10 @@ def _kernel(A, B, C, M, N, K,
if AB_DTYPE:
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
if fp8_fast_accum:
acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
else:
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(C.dtype.element_ty)
@@ -140,7 +144,7 @@ class _matmul(torch.autograd.Function):
_locks = {}
@staticmethod
def _call(a, b, dot_out_dtype, allow_tf32):
def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum):
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
@@ -182,12 +186,13 @@ class _matmul(torch.autograd.Function):
c.stride(0), c.stride(1),
dot_out_dtype=dot_out_dtype,
allow_tf32=allow_tf32,
fp8_fast_accum=fp8_fast_accum,
GROUP_M=8, AB_DTYPE=ab_dtype)
return c
@staticmethod
def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True):
return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True):
return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum)
matmul = _matmul.apply

View File

@@ -26,7 +26,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
%c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>

View File

@@ -34,7 +34,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
// CHECK-NEXT: offset = 0, size = 4224
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -64,11 +64,11 @@ tt.func @reusable(%A : !tt.ptr<f16>) {
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
// CHECK-NEXT: offset = 0, size = 4608
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
// CHECK-NEXT: offset = 0, size = 1152
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
tt.return
// CHECK-NEXT: size = 4608
}

View File

@@ -32,7 +32,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>

View File

@@ -6,7 +6,7 @@
module attributes {"triton_gpu.num-warps" = 1 : i32} {
tt.func @convert_dot(%A: tensor<16x16xf32, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) {
// expected-error@+1 {{element types of operands A and B must have same bit width}}
%D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} :
%D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} :
tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
tt.return
}
@@ -20,7 +20,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
tt.func @convert_dot(%A: tensor<16x16xf16>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) {
// expected-error@+1 {{mismatching encoding between A and B operands}}
%D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} :
%D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} :
tensor<16x16xf16> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
tt.return
}
@@ -34,7 +34,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) {
// expected-error@+1 {{mismatching kWidth between A and B operands}}
%D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} :
%D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} :
tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
tt.return
}

View File

@@ -161,13 +161,13 @@ tt.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
%zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
%r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
%r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
%r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
%r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
%r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
%r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
%r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>
%r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>
%ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
%ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>

View File

@@ -5,7 +5,7 @@ tt.func @ops() {
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
%0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
%0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
tt.return
}

View File

@@ -834,7 +834,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
%D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
%D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
tt.return
}
@@ -967,7 +967,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
%38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked>
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
@@ -993,7 +993,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%a_mat = triton_gpu.convert_layout %a : (tensor<32x64xf16, #shared0>) -> tensor<32x64xf16, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #dot_operand_b>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma>
%38 = triton_gpu.convert_layout %28 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked>
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x64x!tt.ptr<f32>, #blocked>
@@ -1016,7 +1016,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = false, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked>
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
tt.store %36, %28 : tensor<32x32xf32, #blocked>
@@ -1053,7 +1053,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
%38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
@@ -1265,7 +1265,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
%b_mat = arith.constant dense<1.000000e+00> : tensor<16x32xf32, #dot_operand_b>
%28 = tt.dot %a, %b_mat, %c {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
%28 = tt.dot %a, %b_mat, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
%38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
@@ -1295,7 +1295,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%cst_2 = arith.constant dense<32> : tensor<32x1xi32, #blocked>
%0 = tt.dot %cst_0, %cst_1, %cst {allowTF32 = true} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
%0 = tt.dot %cst_0, %cst_1, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
%1 = triton_gpu.convert_layout %0 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
%2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked>

View File

@@ -78,3 +78,74 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
tt.return
}
}
// -----
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
// CHECK-LABEL: @dot_high_precision_acc
tt.func @dot_high_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
// CHECK: nvgpu.wgmma
// CHECK-COUNT-128: llvm.fadd
// CHECK: nvgpu.wgmma
// CHECK-COUNT-128: llvm.fadd
// CHECK: nvgpu.wgmma
// CHECK-COUNT-128: llvm.fadd
// CHECK: nvgpu.wgmma
// CHECK-COUNT-128: llvm.fadd
%m = triton_nvidia_gpu.dot_async %a, %b, %c
{maxNumImpreciseAcc = 32 : i32, allowTF32 = true} :
tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
tt.return
}
}
// -----
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
// CHECK-LABEL: @dot_low_precision_acc
tt.func @dot_low_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
// CHECK: nvgpu.wgmma
// CHECK-NOT: llvm.fadd
// CHECK: nvgpu.wgmma
// CHECK-NOT: llvm.fadd
// CHECK: nvgpu.wgmma
// CHECK-NOT: llvm.fadd
// CHECK: nvgpu.wgmma
// CHECK-NOT: llvm.fadd
// CHECK: llvm.return
%m = triton_nvidia_gpu.dot_async %a, %b, %c
{maxNumImpreciseAcc = 129 : i32, allowTF32 = true} :
tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
tt.return
}
}
// -----
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
// CHECK-LABEL: @dot_mix_precision_acc
tt.func @dot_mix_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
// CHECK: nvgpu.wgmma
// CHECK-NOT: llvm.fadd
// CHECK: nvgpu.wgmma
// CHECK-COUNT-128: llvm.fadd
// CHECK: nvgpu.wgmma
// CHECK-NOT: llvm.fadd
// CHECK: nvgpu.wgmma
// CHECK-COUNT-128: llvm.fadd
// CHECK: llvm.return
%m = triton_nvidia_gpu.dot_async %a, %b, %c
{maxNumImpreciseAcc = 64 : i32, allowTF32 = true} :
tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
tt.return
}
}

View File

@@ -17,3 +17,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2
tt.return
}
} // end module
// -----
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
tt.func @wgmma_no_acc(%descA: i64, %descB: i64) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127}, $128, $129, 0, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l" %0, %1 : (i64, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
%acc0 = nvgpu.wgmma %descA, %descB
{eltTypeA = 3 : i32, eltTypeB = 3 : i32, eltTypeC = 7 : i32, k = 32 : i32, layoutA = 0 : i32, layoutB = 1 : i32, m = 64 : i32, n = 256 : i32} :
(i64, i64) ->
!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
tt.return
}
}

View File

@@ -10,12 +10,12 @@ tt.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128x
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
%d = arith.constant dense<3.0> : tensor<128x128xf32>
%dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%dot_out = tt.dot %a, %b, %zero {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res0 = arith.addf %dot_out, %d : tensor<128x128xf32>
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>
tt.return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>

View File

@@ -1543,7 +1543,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
%26 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>>
%27 = triton_gpu.convert_layout %25 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>>
%28 = triton_gpu.convert_layout %cst : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #blocked5>
%29 = tt.dot %26, %27, %28 {allowTF32 = true} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5>
%29 = tt.dot %26, %27, %28 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5>
%30 = triton_gpu.convert_layout %29 : (tensor<32x32xf32, #blocked5>) -> tensor<32x32xf32, #blocked>
%31:2 = "tt.reduce"(%30, %11) <{axis = 1 : i32}> ({
^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32):
@@ -1690,7 +1690,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%117 = tt.load %116 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked3>
%118 = triton_gpu.convert_layout %41 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
%119 = triton_gpu.convert_layout %97 : (tensor<64x64xf16, #blocked6>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
%120 = tt.dot %118, %119, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked>
%120 = tt.dot %118, %119, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked>
%121 = triton_gpu.convert_layout %120 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #blocked2>
%122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2>
%123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({
@@ -1719,7 +1719,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%142 = triton_gpu.convert_layout %141 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
%143 = triton_gpu.convert_layout %117 : (tensor<64x64xf16, #blocked3>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
%144 = triton_gpu.convert_layout %140 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #blocked>
%145 = tt.dot %142, %143, %144 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked>
%145 = tt.dot %142, %143, %144 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked>
%146 = triton_gpu.convert_layout %145 : (tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked2>
%147 = arith.mulf %arg24, %127 : tensor<128xf32, #blocked1>
%148 = "tt.reduce"(%133) <{axis = 1 : i32}> ({

View File

@@ -36,7 +36,7 @@ tt.func @push_elementwise(
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2k4>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2k4>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2>
tt.return %newc : tensor<16x16xf32, #Cv2>
}
@@ -58,7 +58,7 @@ tt.func @succeeds_if_arg_is_not_convert_layout(
%dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4>
%dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2k4>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2>
tt.return %newc : tensor<16x16xf32, #Cv2>
}
@@ -82,7 +82,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capabil
// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
tt.func @push_convert_both_operands(
%pa: tensor<16x16x!tt.ptr<f16>, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
@@ -93,7 +93,7 @@ tt.func @push_convert_both_operands(
%be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB>
%al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%bl = triton_gpu.convert_layout %be : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
%r = tt.dot %al, %bl, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
tt.return %r : tensor<16x16xf32, #mma>
}
@@ -119,7 +119,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capabil
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
tt.func @update_kwidth_slice(
%pa: tensor<16x16x!tt.ptr<f16>, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
@@ -132,7 +132,7 @@ tt.func @update_kwidth_slice(
%add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB>
%al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%bl = triton_gpu.convert_layout %add : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
%r = tt.dot %al, %bl, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
tt.return %r : tensor<16x16xf32, #mma>
}

View File

@@ -74,7 +74,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -151,7 +151,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -220,7 +220,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
@@ -293,7 +293,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
//
// %sa = triton_gpu.convert_layout %a : (tensor<128x32xf16, #BA>) -> tensor<128x32xf16, #SA>
// %sb = triton_gpu.convert_layout %b : (tensor<32x128xf16, #BB>) -> tensor<32x128xf16, #SB>
// %c = tt.dot %sa, %sb, %prev_c {allowTF32 = true} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C>
// %c = tt.dot %sa, %sb, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C>
//
// %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr<tensor<128x32xf16>, 1>
// %b_tileptr_next = tt.advance %b_tileptr, [%c32_i32, %c0] : !tt.ptr<tensor<32x128xf16>, 1>

View File

@@ -84,7 +84,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
%b_ = triton_gpu.convert_layout %b__ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -157,7 +157,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
@@ -224,7 +224,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
%loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
@@ -266,7 +266,7 @@ tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
%87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A>
%89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B>
%90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C>
%90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C>
%91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr<f16>, #AL>, tensor<16x16xi32, #AL>
%92 = tt.addptr %arg21, %c1_i32 : !tt.ptr<i64>, i32
scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr<f16>, #AL>, !tt.ptr<i64>
@@ -312,7 +312,7 @@ tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt
%87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A>
%89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B>
%90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C>
%90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C>
%91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr<f16>, #AL>, tensor<16x16xi32, #AL>
%92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr<i64>, #BLs1>, tensor<16xi32, #BLs1>
scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr<f16>, #AL>, tensor<16x!tt.ptr<i64>, #BLs1>
@@ -362,7 +362,7 @@ tt.func @post_load_inv(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%116 = tt.load %arg12, %115, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
%117 = triton_gpu.convert_layout %112 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>>
%118 = triton_gpu.convert_layout %116 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>>
%119 = tt.dot %117, %118, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
%119 = tt.dot %117, %118, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
%131 = arith.index_cast %arg9 : index to i32
%120 = arith.addi %131, %c1_i32 : i32
%121 = arith.muli %120, %c32_i32 : i32
@@ -425,7 +425,7 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%150 = tt.load %arg12, %149, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
%151 = triton_gpu.convert_layout %146 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>>
%152 = triton_gpu.convert_layout %150 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>>
%153 = tt.dot %151, %152, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
%153 = tt.dot %151, %152, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
%162 = arith.index_cast %arg9 : index to i32
%154 = arith.addi %162, %c2_i32 : i32
%155 = arith.muli %154, %c32_i32 : i32
@@ -497,7 +497,7 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%199 = tt.load %arg24, %198, %88 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%200 = triton_gpu.convert_layout %193 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>>
%201 = triton_gpu.convert_layout %199 : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>>
%202 = tt.dot %200, %201, %arg23 {allowTF32 = true} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C>
%202 = tt.dot %200, %201, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C>
%203 = tt.addptr %arg24, %90 : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi64, #BL>
scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr<i32>, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr<f16>, #BL>
}

View File

@@ -52,7 +52,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%8 = tt.load %6 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x16xf16, #blockedB0>, 1> -> tensor<16x16xf16, #blockedB1>
%9 = triton_gpu.convert_layout %7 : (tensor<64x16xf16, #blockedA1>) -> tensor<64x16xf16, #sharedA>
%10 = triton_gpu.convert_layout %8 : (tensor<16x16xf16, #blockedB1>) -> tensor<16x16xf16, #sharedB>
%11 = tt.dot %9, %10, %cst {allowTF32 = true} : tensor<64x16xf16, #sharedA> * tensor<16x16xf16, #sharedB> -> tensor<64x16xf32, #mma>
%11 = tt.dot %9, %10, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #sharedA> * tensor<16x16xf16, #sharedB> -> tensor<64x16xf32, #mma>
%12 = triton_gpu.convert_layout %11 : (tensor<64x16xf32, #mma>) -> tensor<64x16xf32, #blockedA1>
%13 = arith.truncf %12 : tensor<64x16xf32, #blockedA1> to tensor<64x16xf16, #blockedA1>
%14 = arith.extsi %arg8 : i32 to i64

View File

@@ -62,7 +62,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>
%77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>
%78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
%78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
%79 = arith.addf %arg13, %78 : tensor<64x64xf32>
%80 = arith.muli %arg7, %c64_i32 : i32
%81 = tt.splat %80 : (i32) -> tensor<64x64xi32>

View File

@@ -53,7 +53,7 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr
%a_op_ = triton_gpu.convert_layout %a : (tensor<128x32xf8E5M2, #A>) -> tensor<128x32xf8E5M2, #A_OP>
%a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP>
%b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP>
%c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>
%c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>

View File

@@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
%9 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked>
%10 = triton_gpu.convert_layout %9 : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #shared>
%11 = triton_gpu.convert_layout %10 : (tensor<32x32xf32, #shared>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
%12 = tt.dot %11, %cst_0, %cst {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
%12 = tt.dot %11, %cst_0, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
%13 = triton_gpu.convert_layout %12 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
tt.store %arg0, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked>
tt.return
@@ -41,7 +41,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
%A = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked>
%AS = triton_gpu.convert_layout %A : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #shared>
%AD = triton_gpu.convert_layout %AS : (tensor<32x32xf32, #shared>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
%12 = tt.dot %AD, %BD, %cst {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
%12 = tt.dot %AD, %BD, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
%13 = triton_gpu.convert_layout %12 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
tt.store %arg0, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked>
tt.return

View File

@@ -46,7 +46,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%30 = triton_gpu.convert_layout %28 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>
%31 = triton_gpu.convert_layout %29 : (tensor<64x128xf16, #blocked1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>>
%32 = triton_gpu.convert_layout %arg12 : (tensor<128x128xf32, #blocked>) -> tensor<128x128xf32, #blocked2>
%33 = tt.dot %30, %31, %32 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2>
%33 = tt.dot %30, %31, %32 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2>
%34 = triton_gpu.convert_layout %33 : (tensor<128x128xf32, #blocked2>) -> tensor<128x128xf32, #blocked>
// CHECK-NOT: tt.advance
%35 = tt.advance %arg13, [%c0_i32, %c64_i32] : <tensor<128x64xf16, #blocked>, 1>

View File

@@ -97,7 +97,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
%91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -208,7 +208,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: %90 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
// CHECK-NEXT: %91 = triton_gpu.convert_layout %89 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
// CHECK-NEXT: %92 = triton_gpu.convert_layout %90 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
// CHECK-NEXT: %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %94 = tt.addptr %arg11, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
// CHECK-NEXT: %95 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
// CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -336,7 +336,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
scf.yield %91, %92 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>
}
%93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -452,7 +452,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: %96 = triton_gpu.convert_layout %94 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
// CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %95, %96 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>
// CHECK-NEXT: } {async_agent = dense<[0, 1]> : vector<2xi32>}
// CHECK-NEXT: %90 = tt.dot %89#0, %89#1, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %90 = tt.dot %89#0, %89#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %91 = tt.addptr %arg11, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
// CHECK-NEXT: %92 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
// CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %90, %91, %92 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -587,7 +587,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
%91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%base_94 = scf.if %96 -> (tensor<128x32x!tt.ptr<f16, 1>, #blocked1>) {
%r1_0 = arith.select %96, %c31_i32, %c127_i32 : i32
%r1_1 = tt.splat %r1_0 : (i32) -> tensor<128x32xi32, #blocked1>
@@ -717,7 +717,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: %92 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
// CHECK-NEXT: %93 = triton_gpu.convert_layout %91 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
// CHECK-NEXT: %94 = triton_gpu.convert_layout %92 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
// CHECK-NEXT: %95 = tt.dot %93, %94, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %95 = tt.dot %93, %94, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
// CHECK-NEXT: %96 = scf.if %90 -> (tensor<128x32x!tt.ptr<f16, 1>, #blocked1>) {
// CHECK-NEXT: %99 = arith.select %90, %c31_i32, %c127_i32 {async_agent = dense<1> : vector<1xi32>} : i32
// CHECK-NEXT: %100 = tt.splat %99 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x32xi32, #blocked1>

View File

@@ -177,7 +177,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%63 = triton_gpu.extract_slice %0[%arg11, 0, 0] [1, 32, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x32x128xf16, #shared> to tensor<32x128xf16, #shared>
%64 = triton_gpu.convert_layout %62 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #shared1>) -> tensor<128x32xf16, #shared1>
%65 = triton_gpu.convert_layout %63 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #shared>
%66 = tt.dot %64, %65, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma>
%66 = tt.dot %64, %65, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma>
%c1_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
%c3_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 3 : i32
%67 = arith.addi %arg11, %c1_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32
@@ -384,7 +384,7 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability"
%50 = triton_gpu.convert_layout %49 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
%51 = triton_gpu.extract_slice %1[%48, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
%52 = triton_gpu.convert_layout %51 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
%53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
triton_nvidia_gpu.consumer_release %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%c1_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
%54 = arith.addi %arg13, %c1_i32_6 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32

View File

@@ -141,7 +141,7 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability"
%40 = triton_gpu.convert_layout %39 {async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
%41 = triton_gpu.extract_slice %1[%38, 0, 0] [1, 16, 64] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
%42 = triton_gpu.convert_layout %41 {async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
%43 = tt.dot %40, %42, %arg12 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%43 = tt.dot %40, %42, %arg12 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
triton_nvidia_gpu.consumer_release %2, %38 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
%c1_i32_5 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
%44 = arith.addi %arg13, %c1_i32_5 {async_agent = dense<1> : vector<1xi32>} : i32

View File

@@ -120,7 +120,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
%91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>

View File

@@ -96,7 +96,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
%91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -226,7 +226,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
scf.yield %91, %92 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>
}
%93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<128x32xi32, #blocked1>
%95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr<f16, 1>, #blocked>, tensor<32x128xi32, #blocked>
scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr<f16, 1>, #blocked1>, tensor<32x128x!tt.ptr<f16, 1>, #blocked>
@@ -362,7 +362,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
%91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
%92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
%base_94 = scf.if %96 -> (tensor<128x32x!tt.ptr<f16, 1>, #blocked1>) {
%r1_0 = arith.select %96, %c31_i32, %c127_i32 : i32
%r1_1 = tt.splat %r1_0 : (i32) -> tensor<128x32xi32, #blocked1>
@@ -438,7 +438,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%47 = tt.load %arg12 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<16x64xf16, #blocked3>
%48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
%49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%51 = tt.advance %arg11, [%c0_i32, %c16_i32] : <tensor<64x16xf16, #blocked>, 1>
%52 = tt.advance %arg12, [%c16_i32, %c0_i32] : <tensor<16x64xf16, #blocked1>, 1>
scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>
@@ -518,7 +518,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%47 = tt.load %arg12 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<16x64xf16, #blocked3>
%48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
%49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%51 = tt.advance %arg11, [%c0_i32, %c16_i32] : <tensor<64x16xf16, #blocked>, 1>
%52 = tt.advance %arg12, [%c16_i32, %c0_i32] : <tensor<16x64xf16, #blocked1>, 1>
scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>
@@ -600,7 +600,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%47 = tt.load %arg12 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<16x64xf16, #blocked3>
%48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
%49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%51 = tt.advance %arg11, [%c0_i32, %c16_i32] : <tensor<64x16xf16, #blocked>, 1>
%52 = tt.advance %arg12, [%c16_i32, %c0_i32] : <tensor<16x64xf16, #blocked1>, 1>
scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>
@@ -686,7 +686,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%47 = tt.load %arg12 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<16x64xf16, #blocked3>
%48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
%49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%51 = tt.advance %arg11, [%c0_i32, %c16_i32] : <tensor<64x16xf16, #blocked>, 1>
%52 = tt.advance %arg12, [%c16_i32, %c0_i32] : <tensor<16x64xf16, #blocked1>, 1>
scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>
@@ -799,7 +799,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%44 = tt.load %arg17 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<64x128xf16, #blocked1>, 1> -> tensor<64x128xf16, #blocked4>
%45 = triton_gpu.convert_layout %43 : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared>
%46 = triton_gpu.convert_layout %44 : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1>
%47 = tt.dot %45, %46, %arg15 {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%47 = tt.dot %45, %46, %arg15 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%48 = tt.advance %arg16, [%c0_i32, %c64_i32] : <tensor<256x64xf16, #blocked>, 1>
%49 = tt.advance %arg17, [%c64_i32, %c0_i32] : <tensor<64x128xf16, #blocked1>, 1>
scf.yield %47, %48, %49 : tensor<256x128xf32, #mma>, !tt.ptr<tensor<256x64xf16, #blocked>, 1>, !tt.ptr<tensor<64x128xf16, #blocked1>, 1>
@@ -852,7 +852,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%b = tt.load %arg1 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<64x128xf16, #blocked1>, 1> -> tensor<64x128xf16, #blocked4>
%shm_a = triton_gpu.convert_layout %a : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared>
%shm_b = triton_gpu.convert_layout %b : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1>
%d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%out = triton_gpu.convert_layout %d : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2>
tt.store %arg2, %out {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2>
}
@@ -887,7 +887,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%b = tt.load %arg1 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x128x!tt.ptr<f16, 1>, #blocked4> -> tensor<64x128xf16, #blocked4>
%shm_a = triton_gpu.convert_layout %a : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared>
%shm_b = triton_gpu.convert_layout %b : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1>
%d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
%out = triton_gpu.convert_layout %d : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2>
tt.store %arg2, %out {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2>
}
@@ -1000,7 +1000,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%92 = tt.load %arg19 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x64xf16, #blocked1>, 1> -> tensor<16x64xf16, #blocked4>
%93 = triton_gpu.convert_layout %91 : (tensor<64x16xf16, #blocked3>) -> tensor<64x16xf16, #shared>
%94 = triton_gpu.convert_layout %92 : (tensor<16x64xf16, #blocked4>) -> tensor<16x64xf16, #shared1>
%95 = tt.dot %93, %94, %arg17 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%95 = tt.dot %93, %94, %arg17 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%96 = tt.advance %arg18, [%c0_i32, %c16_i32] : <tensor<64x16xf16, #blocked>, 1>
%97 = tt.advance %arg19, [%c16_i32, %c0_i32] : <tensor<16x64xf16, #blocked1>, 1>
scf.yield %95, %96, %97 : tensor<64x64xf32, #mma>, !tt.ptr<tensor<64x16xf16, #blocked>, 1>, !tt.ptr<tensor<16x64xf16, #blocked1>, 1>