[Alloc] Enhanced SharedMem Allocation for mutually exclusive but aliased buffers (#337)

* [Alloc] Enhanced for mutually exclusive but aliased buffers

- Use disjoint alias analysis to minimize shared memory requirements

* * fix for allocation test

* * added test
* fixed mfma_enc printer

* * fixed test
This commit is contained in:
SJW
2023-09-25 20:09:33 -05:00
committed by GitHub
parent 7af5e42fbe
commit 4db99e0139
5 changed files with 249 additions and 25 deletions

View File

@@ -29,6 +29,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
template <typename T> class Interval {
public:
Interval() {}
Interval(T S) : Start(S), End(S+1) {}
Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
T start() const { return Start; }
T end() const { return End; }
@@ -44,6 +45,16 @@ public:
bool operator<(const Interval &R) const {
return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
}
bool adjacent(T Addr) const {
return Addr+1 == Start || Addr == End;
}
bool adjacent(const Interval &R) const {
return adjacent(R.Start) || adjacent(R.End-1);
}
Interval merge(const Interval &R) const {
return Interval(std::min(Start, R.Start), std::max(End, R.End));
}
private:
T Start = std::numeric_limits<T>::min();

View File

@@ -38,6 +38,13 @@ def TritonGPU_Dialect : Dialect {
}
return threadsPerWarp.cast<IntegerAttr>().getInt();
}
static int getSharedSize(ModuleOp mod) {
Attribute sharedAttr = mod->getDiscardableAttr("triton_gpu.shared");
if(!sharedAttr) {
return 0;
}
return sharedAttr.cast<IntegerAttr>().getInt();
}
}];

View File

@@ -137,11 +137,68 @@ private:
using BufferT = Allocation::BufferT;
/// Value -> Liveness Range
using IntervalT = Interval<size_t>;
/// Use MapVector to ensure determinism.
using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
using BufferRangeMapT = llvm::MapVector<BufferT *, IntervalT>;
/// Nodes -> Nodes
using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
/// Set of Liveness Intervals
class LivenessR : public SmallVector<IntervalT, 4> {
public:
LivenessR() = default;
LivenessR(const LivenessR &) = default;
/// Disjointness
bool isDisjoint() const {
if (size() < 2)
return false;
// sorted so the first OOB proves disjoint
auto maxId = (*this)[0].end();
for (auto rng : *this) {
if (rng.start() <= maxId) {
// adjoining
maxId = std::max(maxId, rng.end());
} else
return true;
}
return false;
}
void sort() {
llvm::sort(*this, [](const auto &lhs, const auto &rhs) {
return lhs.start() <= rhs.start();
});
}
bool addAdjacent(size_t id) {
bool isAdjacent = false;
for (auto &interval : *this) {
if (interval.adjacent(id)) {
isAdjacent = true;
interval = interval.merge(IntervalT(id));
}
}
return isAdjacent;
}
void add(size_t id) {
if (!addAdjacent(id))
push_back(IntervalT(id));
}
IntervalT unionize() const {
IntervalT res;
if (size()) {
res = front();
for (auto &I : *this)
res = res.merge(I);
}
return res;
}
};
typedef function_ref<LivenessR(Value value)> LivenessF;
void run() {
getValuesAndSizes();
resolveLiveness();
@@ -289,33 +346,55 @@ private:
/// Computes the liveness range of the allocated value.
/// Each buffer is allocated only once.
void resolveExplicitBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
void resolveExplicitBufferLiveness(LivenessF getLiveness) {
for (auto valueBufferIter : allocation->valueBuffer) {
auto value = valueBufferIter.first;
auto *buffer = valueBufferIter.second;
bufferRange[buffer] = getLiveness(value);
auto ranges = getLiveness(value);
bufferRange[buffer] = ranges.unionize();
}
}
/// Extends the liveness range by unionizing the liveness range of the aliased
/// values because each allocated buffer could be an alias of others, if block
/// arguments are involved.
void resolveAliasBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
/// Only unionize adjacent live ranges to account for loop-carried buffers that
/// are mutually exclusive.
/// Example from stream pipeliner:
/// 3 %b0 = convert_layout %g0 -+
/// 4 %fr = for (.., %arg0 = %b0) { |
/// 5 %gn = load %pc |
/// 6 %bc = convert_layout %arg0 -+
/// 7 %v = add %bc, ...
/// 8 %bn = convert_layout %gn -+
/// 9 %pn = addptr %pc, %cst |
/// 10 } |
/// 11 %be = convert_layout %fr#1 -+
/// 12 %ve = add %be
void resolveAliasBufferLiveness(LivenessF getLiveness) {
for (auto aliasBufferIter : allocation->aliasBuffer) {
auto value = aliasBufferIter.first;
auto buffers = aliasBufferIter.second;
auto range = getLiveness(value);
auto aranges = getLiveness(value);
bool disjoint = aranges.isDisjoint();
for (auto *buffer : buffers) {
auto minId = range.start();
auto maxId = range.end();
auto range = aranges[0];
if (bufferRange.count(buffer)) {
// Extend the allocated buffer's range
minId = std::min(minId, bufferRange[buffer].start());
maxId = std::max(maxId, bufferRange[buffer].end());
auto brange = bufferRange[buffer];
if (disjoint) {
// find adjacent/intersecting
for (auto arange : aranges) {
if (arange.adjacent(brange) ||
arange.intersects(brange))
brange = arange.merge(brange);
}
range = brange;
} else {
// Extend the allocated buffer's range
range = range.merge(brange);
}
}
bufferRange[buffer] = Interval(minId, maxId);
bufferRange[buffer] = range;
}
}
}
@@ -366,18 +445,13 @@ private:
Liveness liveness(operation);
auto getValueLivenessRange = [&](Value value) {
auto liveOperations = liveness.resolveLiveness(value);
auto minId = std::numeric_limits<size_t>::max();
auto maxId = std::numeric_limits<size_t>::min();
LivenessR ranges;
std::for_each(liveOperations.begin(), liveOperations.end(),
[&](Operation *liveOp) {
if (operationId[liveOp] < minId) {
minId = operationId[liveOp];
}
if ((operationId[liveOp] + 1) > maxId) {
maxId = operationId[liveOp] + 1;
}
ranges.add(operationId[liveOp]);
});
return Interval(minId, maxId);
ranges.sort();
return ranges;
};
resolveExplicitBufferLiveness(getValueLivenessRange);
@@ -432,9 +506,9 @@ private:
// If the available triple's range is less than a given buffer range,
// we won't know if there has been an overlap without using graph coloring.
// Start -> Liveness Range
using TripleMapT = std::multimap<size_t, Interval<size_t>>;
using TripleMapT = std::multimap<size_t, IntervalT>;
TripleMapT tripleMap;
tripleMap.insert(std::make_pair(0, Interval<size_t>()));
tripleMap.insert(std::make_pair(0, IntervalT()));
SmallVector<BufferT *> xBuffers = buffers;
while (!xBuffers.empty()) {
auto tripleIt = tripleMap.begin();
@@ -542,6 +616,19 @@ private:
}
}
void dump() const {
llvm::outs() << "DUMP: " << "\n";
for (auto bufferIter : bufferRange) {
llvm::outs() << "ID= " << bufferIter.first->id << "\n";
// llvm::outs() << " Kind= " << kind << "\n";
llvm::outs() << " Size= " << bufferIter.first->size << "\n";
llvm::outs() << " Offs= " << bufferIter.first->offset << "\n";
llvm::outs() << " -> " << bufferIter.second.start() << "\n";
llvm::outs() << " -> " << bufferIter.second.end() << "\n";
}
}
private:
Operation *operation;
Allocation::FuncAllocMapT *funcAllocMap;

View File

@@ -1095,7 +1095,7 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
printer << "<{"
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
if (mmaParent && mmaParent.isAmpere())
if ((mmaParent && mmaParent.isAmpere()) || getParent().isa<MfmaEncodingAttr>())
printer << ", kWidth = " << getKWidth();
printer << "}>";
}
@@ -1221,6 +1221,9 @@ public:
if (auto mmaAttr = attr.dyn_cast<MmaEncodingAttr>()) {
os << "mma";
return AliasResult::FinalAlias;
} else if (attr.isa<MfmaEncodingAttr>()) {
os << "mfma";
return AliasResult::FinalAlias;
} else if (auto sharedAttr = attr.dyn_cast<SharedEncodingAttr>()) {
os << "shared";
return AliasResult::FinalAlias;

View File

@@ -0,0 +1,116 @@
// RUN: triton-opt --convert-triton-gpu-to-llvm %s | FileCheck %s
// CHECK: module attributes {{{.*}}, triton_gpu.shared = 9216 : i32
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>
%cst_0 = arith.constant dense<32> : tensor<64x32xi32, #blocked>
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c4_i32 = arith.constant 4 : i32
%0 = tt.get_program_id x : i32
%1 = arith.addi %arg3, %c63_i32 : i32
%2 = arith.divsi %1, %c64_i32 : i32
%3 = arith.addi %arg4, %c63_i32 : i32
%4 = arith.divsi %3, %c64_i32 : i32
%5 = arith.muli %4, %c4_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c4_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = "triton_gpu.cmpi"(%8, %c4_i32) <{predicate = 2 : i64}> : (i32, i32) -> i1
%10 = arith.select %9, %8, %c4_i32 : i32
%11 = arith.remsi %0, %10 : i32
%12 = arith.addi %7, %11 : i32
%13 = arith.remsi %0, %5 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c64_i32 : i32
%16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%19 = tt.splat %15 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%20 = tt.splat %15 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%21 = arith.addi %19, %16 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%22 = arith.addi %20, %18 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%23 = arith.muli %14, %c64_i32 : i32
%24 = tt.splat %23 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%25 = arith.addi %24, %17 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%26 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked>
%27 = tt.expand_dims %22 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
%28 = tt.splat %arg6 : (i32) -> tensor<64x1xi32, #blocked>
%29 = arith.muli %26, %28 : tensor<64x1xi32, #blocked>
%30 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked>
%31 = tt.addptr %30, %29 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
%32 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%33 = tt.expand_dims %32 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x32xi32, #blocked>
%34 = tt.broadcast %31 : (tensor<64x1x!tt.ptr<f16>, #blocked>) -> tensor<64x32x!tt.ptr<f16>, #blocked>
%35 = tt.broadcast %33 : (tensor<1x32xi32, #blocked>) -> tensor<64x32xi32, #blocked>
%36 = tt.addptr %34, %35 : tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<64x32xi32, #blocked>
%37 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<32x1xi32, #blocked1>
%39 = tt.splat %arg7 : (i32) -> tensor<32x1xi32, #blocked1>
%40 = arith.muli %38, %39 : tensor<32x1xi32, #blocked1>
%41 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<32x1x!tt.ptr<f16>, #blocked1>
%42 = tt.addptr %41, %40 : tensor<32x1x!tt.ptr<f16>, #blocked1>, tensor<32x1xi32, #blocked1>
%43 = tt.expand_dims %25 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%44 = tt.broadcast %42 : (tensor<32x1x!tt.ptr<f16>, #blocked1>) -> tensor<32x64x!tt.ptr<f16>, #blocked1>
%45 = tt.broadcast %43 : (tensor<1x64xi32, #blocked1>) -> tensor<32x64xi32, #blocked1>
%46 = tt.addptr %44, %45 : tensor<32x64x!tt.ptr<f16>, #blocked1>, tensor<32x64xi32, #blocked1>
%47 = arith.addi %arg5, %c31_i32 : i32
%48 = arith.divsi %47, %c32_i32 : i32
%49 = arith.muli %arg7, %c32_i32 : i32
%50 = tt.splat %49 : (i32) -> tensor<32x64xi32, #blocked1>
%51 = tt.load %36 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16, #blocked>
%52 = triton_gpu.convert_layout %51 : (tensor<64x32xf16, #blocked>) -> tensor<64x32xf16, #shared>
%53 = tt.load %46 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked1>
%54 = triton_gpu.convert_layout %53 : (tensor<32x64xf16, #blocked1>) -> tensor<32x64xf16, #shared1>
%55 = tt.addptr %36, %cst_0 : tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<64x32xi32, #blocked>
%56 = tt.addptr %46, %50 : tensor<32x64x!tt.ptr<f16>, #blocked1>, tensor<32x64xi32, #blocked1>
%57 = arith.subi %48, %c1_i32 : i32
cf.br ^bb1(%c0_i32, %cst, %52, %54, %55, %56 : i32, tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>, tensor<64x32xf16, #shared>, tensor<32x64xf16, #shared1>, tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<32x64x!tt.ptr<f16>, #blocked1>)
^bb1(%58: i32, %59: tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>, %60: tensor<64x32xf16, #shared>, %61: tensor<32x64xf16, #shared1>, %62: tensor<64x32x!tt.ptr<f16>, #blocked>, %63: tensor<32x64x!tt.ptr<f16>, #blocked1>): // 2 preds: ^bb0, ^bb2
%64 = arith.cmpi slt, %58, %57 : i32
cf.cond_br %64, ^bb2, ^bb3
^bb2: // pred: ^bb1
%65 = tt.load %62 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16, #blocked>
%66 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked1>
%67 = triton_gpu.convert_layout %60 : (tensor<64x32xf16, #shared>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>>
%68 = triton_gpu.convert_layout %61 : (tensor<32x64xf16, #shared1>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>>
%69 = tt.dot %67, %68, %59 {allowTF32 = true} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>> -> tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>
%70 = tt.addptr %62, %cst_0 : tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<64x32xi32, #blocked>
%71 = tt.addptr %63, %50 : tensor<32x64x!tt.ptr<f16>, #blocked1>, tensor<32x64xi32, #blocked1>
%72 = triton_gpu.convert_layout %65 : (tensor<64x32xf16, #blocked>) -> tensor<64x32xf16, #shared>
%73 = triton_gpu.convert_layout %66 : (tensor<32x64xf16, #blocked1>) -> tensor<32x64xf16, #shared1>
%74 = arith.addi %58, %c1_i32 : i32
cf.br ^bb1(%74, %69, %72, %73, %70, %71 : i32, tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>, tensor<64x32xf16, #shared>, tensor<32x64xf16, #shared1>, tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<32x64x!tt.ptr<f16>, #blocked1>)
^bb3: // pred: ^bb1
%75 = triton_gpu.convert_layout %60 : (tensor<64x32xf16, #shared>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>>
%76 = triton_gpu.convert_layout %61 : (tensor<32x64xf16, #shared1>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>>
%77 = tt.dot %75, %76, %59 {allowTF32 = true} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>> -> tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>
%78 = arith.truncf %77 : tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>> to tensor<64x64xf16, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>
%79 = tt.splat %arg8 : (i32) -> tensor<64x1xi32, #blocked1>
%80 = arith.muli %79, %27 : tensor<64x1xi32, #blocked1>
%81 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked1>
%82 = tt.addptr %81, %80 : tensor<64x1x!tt.ptr<f16>, #blocked1>, tensor<64x1xi32, #blocked1>
%83 = tt.broadcast %82 : (tensor<64x1x!tt.ptr<f16>, #blocked1>) -> tensor<64x64x!tt.ptr<f16>, #blocked1>
%84 = tt.broadcast %43 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
%85 = tt.addptr %83, %84 : tensor<64x64x!tt.ptr<f16>, #blocked1>, tensor<64x64xi32, #blocked1>
%86 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked1>
%87 = "triton_gpu.cmpi"(%27, %86) <{predicate = 2 : i64}> : (tensor<64x1xi32, #blocked1>, tensor<64x1xi32, #blocked1>) -> tensor<64x1xi1, #blocked1>
%88 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked1>
%89 = "triton_gpu.cmpi"(%43, %88) <{predicate = 2 : i64}> : (tensor<1x64xi32, #blocked1>, tensor<1x64xi32, #blocked1>) -> tensor<1x64xi1, #blocked1>
%90 = tt.broadcast %87 : (tensor<64x1xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
%91 = tt.broadcast %89 : (tensor<1x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
%92 = arith.andi %90, %91 : tensor<64x64xi1, #blocked1>
%93 = triton_gpu.convert_layout %78 : (tensor<64x64xf16, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>) -> tensor<64x64xf16, #blocked1>
tt.store %85, %93, %92 {cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf16, #blocked1>
tt.return
}
}