[BACKEND] Fix topological sort and add new test cases (#1132)

Previous https://github.com/openai/triton/pull/1113 forgot to consider
that a node may have multiple parents, visiting the instruction before
any parent violates the semantic of topological sort.

The fixed implementation exhaustively add all operations into a
candidate subgraph and move an operation to the "ready" queue once all
of its operands have been visited.
This commit is contained in:
Keren Zhou
2023-01-31 23:41:20 -08:00
committed by GitHub
parent fc846e5e1e
commit 5dd8ce3745
6 changed files with 632 additions and 12 deletions

View File

@@ -80,6 +80,14 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
triton::gpu::DotOperandEncodingAttr &dotOperandLayout);
/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operation in the `toSort` SetVector.
/// Returns a topologically sorted SetVector.
/// It is faster than mlir::topologicalSort because it prunes nodes that have
/// been visited before.
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort);
} // namespace mlir
#endif // TRITON_ANALYSIS_UTILITY_H

View File

@@ -228,7 +228,8 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
// Shape Manipulation Ops
//
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
SameOperandsAndResultElementType]> {
SameOperandsAndResultElementType,
SameOperandsAndResultEncoding]> {
let summary = "splat";
let arguments = (ins TT_Type:$src);
@@ -253,7 +254,8 @@ def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
}
def TT_ViewOp : TT_Op<"view", [NoSideEffect,
SameOperandsAndResultElementType]> {
SameOperandsAndResultElementType,
SameOperandsAndResultEncoding]> {
let summary = "view";
let arguments = (ins TT_Tensor:$src);
@@ -265,7 +267,8 @@ def TT_ViewOp : TT_Op<"view", [NoSideEffect,
}
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
SameOperandsAndResultElementType]> {
SameOperandsAndResultElementType,
SameOperandsAndResultEncoding]> {
let summary = "broadcast. No left-padding as of now.";
let arguments = (ins TT_Type:$src);
@@ -278,7 +281,8 @@ def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
}
def TT_CatOp : TT_Op<"cat", [NoSideEffect,
SameOperandsAndResultElementType]> {
SameOperandsAndResultElementType,
SameOperandsAndResultEncoding]> {
let summary = "concatenate 2 tensors";
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
@@ -289,7 +293,9 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
}
def TT_TransOp : TT_Op<"trans", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType,
SameOperandsAndResultEncoding]> {
let summary = "transpose a tensor";

View File

@@ -2,6 +2,7 @@
#include "mlir/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <deque>
namespace mlir {
@@ -164,4 +165,126 @@ bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
dotOperandLayout.getParent() == mmaLayout;
}
namespace {
/// A data structure similar to SetVector but maintains
/// a deque instead of a vector to allow for efficient
/// push_back and pop_front operations.
/// Using SetVector doesn't suffice our needs because
/// it only pushes and pops from the back.
/// For example, if we have a queue like this:
/// 0->4 1->2->3
/// ^--------
/// where 3 depends on 4, once we pop 3, we found
/// 4 is not ready, so we check 2 and push 3 back
/// to the queue.
struct DFSSubgraphState {
DFSSubgraphState() : set(), deque() {}
DenseSet<Operation *> set;
std::deque<Operation *> deque;
bool push_back(Operation *op) {
if (set.insert(op).second) {
deque.push_back(op);
return true;
}
return false;
}
Operation *pop_front() {
Operation *op = deque.front();
deque.pop_front();
set.erase(op);
return op;
}
bool empty() { return deque.empty(); }
};
/// DFS post-order implementation that maintains a global count to work across
/// multiple invocations, to help implement topological sort on multi-root DAGs.
/// We traverse all operations but only record the ones that appear in
/// `toSort` for the final result.
struct DFSState {
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
const SetVector<Operation *> &toSort;
SmallVector<Operation *, 16> topologicalCounts;
DenseSet<Operation *> seen;
/// We mark each op as ready if all its operands are seen. If an op is ready,
/// we add it to the queue. Otherwise, we keep adding its operands to the
/// ancestors set.
void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph,
SmallVector<Operation *, 4> &readyQueue) {
bool ready = true;
for (Value operand : op->getOperands()) {
auto def = operand.getDefiningOp();
if (def && !seen.count(def)) {
subGraph.push_back(def);
ready = false;
}
}
if (ready)
readyQueue.push_back(op);
}
};
void dfsPostorder(Operation *root, DFSState *state) {
DFSSubgraphState subGraph;
subGraph.push_back(root);
SmallVector<Operation *> ops;
while (!subGraph.empty()) {
// Nodes in the ready queue are ready to be processed.
// Meaning that either their operands are all seen or they have null
// operands.
SmallVector<Operation *, 4> readyQueue;
auto *current = subGraph.pop_front();
state->addToReadyQueue(current, subGraph, readyQueue);
while (!readyQueue.empty()) {
Operation *current = readyQueue.pop_back_val();
if (!state->seen.insert(current).second)
continue;
ops.push_back(current);
for (Value result : current->getResults()) {
for (Operation *op : result.getUsers())
state->addToReadyQueue(op, subGraph, readyQueue);
}
for (Region &region : current->getRegions()) {
for (Operation &op : region.getOps())
state->addToReadyQueue(&op, subGraph, readyQueue);
}
}
}
for (Operation *op : llvm::reverse(ops)) {
if (state->toSort.count(op) > 0)
state->topologicalCounts.push_back(op);
}
}
} // namespace
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
if (toSort.empty()) {
return toSort;
}
// Run from each root with global count and `seen` set.
DFSState state(toSort);
for (auto *s : toSort) {
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
dfsPostorder(s, &state);
}
// Reorder and return.
SetVector<Operation *> res;
for (auto it = state.topologicalCounts.rbegin(),
eit = state.topologicalCounts.rend();
it != eit; ++it) {
res.insert(*it);
}
return res;
}
} // namespace mlir

View File

@@ -376,12 +376,12 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
newOp->getResult(0).setType(newType);
auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
if (typeInfer) {
SmallVector<Type, 1> newType;
SmallVector<Type, 1> newTypes;
auto success = typeInfer.inferReturnTypes(
newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
newOp->getAttrDictionary(), newOp->getRegions(), newType);
newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
if (succeeded(success))
newOp->getResult(0).setType(newType.front());
newOp->getResult(0).setType(newTypes.front());
}
return newOp;
}
@@ -652,7 +652,7 @@ public:
else
sortedValues.push_back(v);
}
tmp = mlir::topologicalSort(tmp);
tmp = mlir::multiRootTopologicalSort(tmp);
for (Operation *op : tmp)
sortedValues.push_back(op->getResult(0));

View File

@@ -905,6 +905,13 @@ def kernel_suffix(signature, specialization):
# ------------------------------------------------------------------------------
def parse_mlir_module(path, context):
module = _triton.ir.parse_mlir_module(path, context)
# module takes ownership of the context
module.context = context
return module
def build_triton_ir(fn, signature, specialization, constants):
# canonicalize signature
if isinstance(signature, str):
@@ -1541,9 +1548,9 @@ def compile(fn, **kwargs):
# build compilation stages
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),

View File

@@ -234,4 +234,480 @@ func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f6
scf.yield %30 : tensor<1x512xf64, #blocked2>
}
return
}
}
// Make sure the following IR doesn't hang the compiler.
// CHECK-LABEL: long_func
func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0>
%cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0>
%cst_1 = arith.constant dense<0.999499976> : tensor<1024xf32, #blocked0>
%cst_2 = arith.constant dense<1.000000e+04> : tensor<1024xf32, #blocked0>
%cst_3 = arith.constant dense<5000> : tensor<1024xi32, #blocked0>
%cst_4 = arith.constant dense<150> : tensor<1024xi32, #blocked0>
%cst_5 = arith.constant dense<false> : tensor<1024xi1, #blocked0>
%cst_6 = arith.constant dense<2> : tensor<1024xi32, #blocked0>
%cst_7 = arith.constant dense<4999> : tensor<1024xi32, #blocked0>
%cst_8 = arith.constant dense<2499> : tensor<1024xi32, #blocked0>
%cst_9 = arith.constant dense<2500> : tensor<1024xi32, #blocked0>
%cst_10 = arith.constant dense<0.91629076> : tensor<1024xf32, #blocked0>
%c2499_i32 = arith.constant 2499 : i32
%cst_11 = arith.constant dense<1024> : tensor<1024xi32, #blocked0>
%c1024_i32 = arith.constant 1024 : i32
%cst_12 = arith.constant dense<1> : tensor<1024xi32, #blocked0>
%cst_13 = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked0>
%cst_14 = arith.constant dense<0> : tensor<1024xi32, #blocked0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked0>
%5 = "triton_gpu.cmpi"(%4, %cst_11) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%6 = tt.splat %arg5 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%8 = triton_gpu.convert_layout %7 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked1>
%9 = triton_gpu.convert_layout %5 : (tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked1>
%10 = tt.load %8, %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked1>
%11 = triton_gpu.convert_layout %10 : (tensor<1024xf32, #blocked1>) -> tensor<1024xf32, #blocked0>
%12 = tt.splat %arg7 : (!tt.ptr<i64>) -> tensor<1024x!tt.ptr<i64>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<1024x!tt.ptr<i64>, #blocked0>, tensor<1024xi32, #blocked0>
%14 = triton_gpu.convert_layout %13 : (tensor<1024x!tt.ptr<i64>, #blocked0>) -> tensor<1024x!tt.ptr<i64>, #blocked2>
%15 = triton_gpu.convert_layout %5 : (tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked2>
%16 = tt.load %14, %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xi64, #blocked2>
%17 = triton_gpu.convert_layout %16 : (tensor<1024xi64, #blocked2>) -> tensor<1024xi64, #blocked0>
%18 = tt.splat %arg8 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%19 = tt.addptr %18, %4 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%20 = triton_gpu.convert_layout %19 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked1>
%21 = triton_gpu.convert_layout %5 : (tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked1>
%22 = tt.load %20, %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked1>
%23 = triton_gpu.convert_layout %22 : (tensor<1024xf32, #blocked1>) -> tensor<1024xf32, #blocked0>
%24 = arith.subf %cst_13, %11 : tensor<1024xf32, #blocked0>
%25 = math.exp %24 : tensor<1024xf32, #blocked0>
%26 = arith.sitofp %cst_12 : tensor<1024xi32, #blocked0> to tensor<1024xf32, #blocked0>
%27 = arith.addf %25, %26 : tensor<1024xf32, #blocked0>
%28 = arith.divf %26, %27 : tensor<1024xf32, #blocked0>
%29 = tt.addptr %arg6, %c2499_i32 : !tt.ptr<f32>, i32
%30 = tt.load %29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32
%31 = arith.subf %11, %cst_10 : tensor<1024xf32, #blocked0>
%32 = arith.subf %cst_13, %31 : tensor<1024xf32, #blocked0>
%33 = math.exp %32 : tensor<1024xf32, #blocked0>
%34 = arith.addf %33, %26 : tensor<1024xf32, #blocked0>
%35 = arith.divf %26, %34 : tensor<1024xf32, #blocked0>
%36 = tt.splat %30 : (f32) -> tensor<1024xf32, #blocked0>
%37 = "triton_gpu.cmpf"(%36, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%38 = "triton_gpu.select"(%37, %cst_14, %cst_9) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%39 = "triton_gpu.select"(%37, %cst_8, %cst_7) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%40 = arith.subi %39, %38 : tensor<1024xi32, #blocked0>
%41 = "triton_gpu.cmpi"(%40, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%42 = "triton_gpu.cmpi"(%41, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%43 = arith.remsi %40, %cst_6 : tensor<1024xi32, #blocked0>
%44 = "triton_gpu.cmpi"(%43, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%45 = arith.divsi %40, %cst_6 : tensor<1024xi32, #blocked0>
%46 = arith.subi %45, %cst_12 : tensor<1024xi32, #blocked0>
%47 = "triton_gpu.select"(%44, %46, %45) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%48 = "triton_gpu.select"(%42, %47, %45) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%49 = arith.addi %38, %48 : tensor<1024xi32, #blocked0>
%50 = "triton_gpu.cmpi"(%38, %39) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%51 = "triton_gpu.select"(%50, %49, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%52 = tt.splat %arg6 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%53 = tt.addptr %52, %51 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%54 = triton_gpu.convert_layout %53 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%55 = tt.load %54 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%56 = "triton_gpu.cmpf"(%55, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%57 = "triton_gpu.cmpi"(%56, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%58 = arith.andi %57, %50 : tensor<1024xi1, #blocked0>
%59 = arith.addi %51, %cst_12 : tensor<1024xi32, #blocked0>
%60 = "triton_gpu.select"(%58, %59, %38) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%61 = arith.andi %56, %50 : tensor<1024xi1, #blocked0>
%62 = "triton_gpu.select"(%61, %51, %39) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%63 = "triton_gpu.cmpi"(%60, %62) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%64 = arith.subi %62, %60 : tensor<1024xi32, #blocked0>
%65 = "triton_gpu.cmpi"(%64, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%66 = "triton_gpu.cmpi"(%65, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%67 = arith.remsi %64, %cst_6 : tensor<1024xi32, #blocked0>
%68 = "triton_gpu.cmpi"(%67, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%69 = arith.divsi %64, %cst_6 : tensor<1024xi32, #blocked0>
%70 = arith.subi %69, %cst_12 : tensor<1024xi32, #blocked0>
%71 = "triton_gpu.select"(%68, %70, %69) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%72 = "triton_gpu.select"(%66, %71, %69) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%73 = arith.addi %60, %72 : tensor<1024xi32, #blocked0>
%74 = "triton_gpu.select"(%63, %73, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%75 = tt.addptr %52, %74 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%76 = triton_gpu.convert_layout %75 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%77 = tt.load %76 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%78 = "triton_gpu.cmpf"(%77, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%79 = "triton_gpu.cmpi"(%78, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%80 = arith.andi %79, %63 : tensor<1024xi1, #blocked0>
%81 = arith.addi %74, %cst_12 : tensor<1024xi32, #blocked0>
%82 = "triton_gpu.select"(%80, %81, %60) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%83 = arith.andi %78, %63 : tensor<1024xi1, #blocked0>
%84 = "triton_gpu.select"(%83, %74, %62) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%85 = "triton_gpu.cmpi"(%82, %84) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%86 = arith.subi %84, %82 : tensor<1024xi32, #blocked0>
%87 = "triton_gpu.cmpi"(%86, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%88 = "triton_gpu.cmpi"(%87, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%89 = arith.remsi %86, %cst_6 : tensor<1024xi32, #blocked0>
%90 = "triton_gpu.cmpi"(%89, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%91 = arith.divsi %86, %cst_6 : tensor<1024xi32, #blocked0>
%92 = arith.subi %91, %cst_12 : tensor<1024xi32, #blocked0>
%93 = "triton_gpu.select"(%90, %92, %91) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%94 = "triton_gpu.select"(%88, %93, %91) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%95 = arith.addi %82, %94 : tensor<1024xi32, #blocked0>
%96 = "triton_gpu.select"(%85, %95, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%97 = tt.addptr %52, %96 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%98 = triton_gpu.convert_layout %97 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%99 = tt.load %98 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%100 = "triton_gpu.cmpf"(%99, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%101 = "triton_gpu.cmpi"(%100, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%102 = arith.andi %101, %85 : tensor<1024xi1, #blocked0>
%103 = arith.addi %96, %cst_12 : tensor<1024xi32, #blocked0>
%104 = "triton_gpu.select"(%102, %103, %82) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%105 = arith.andi %100, %85 : tensor<1024xi1, #blocked0>
%106 = "triton_gpu.select"(%105, %96, %84) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%107 = "triton_gpu.cmpi"(%104, %106) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%108 = arith.subi %106, %104 : tensor<1024xi32, #blocked0>
%109 = "triton_gpu.cmpi"(%108, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%110 = "triton_gpu.cmpi"(%109, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%111 = arith.remsi %108, %cst_6 : tensor<1024xi32, #blocked0>
%112 = "triton_gpu.cmpi"(%111, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%113 = arith.divsi %108, %cst_6 : tensor<1024xi32, #blocked0>
%114 = arith.subi %113, %cst_12 : tensor<1024xi32, #blocked0>
%115 = "triton_gpu.select"(%112, %114, %113) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%116 = "triton_gpu.select"(%110, %115, %113) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%117 = arith.addi %104, %116 : tensor<1024xi32, #blocked0>
%118 = "triton_gpu.select"(%107, %117, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%119 = tt.addptr %52, %118 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%120 = triton_gpu.convert_layout %119 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%121 = tt.load %120 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%122 = "triton_gpu.cmpf"(%121, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%123 = "triton_gpu.cmpi"(%122, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%124 = arith.andi %123, %107 : tensor<1024xi1, #blocked0>
%125 = arith.addi %118, %cst_12 : tensor<1024xi32, #blocked0>
%126 = "triton_gpu.select"(%124, %125, %104) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%127 = arith.andi %122, %107 : tensor<1024xi1, #blocked0>
%128 = "triton_gpu.select"(%127, %118, %106) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%129 = "triton_gpu.cmpi"(%126, %128) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%130 = arith.subi %128, %126 : tensor<1024xi32, #blocked0>
%131 = "triton_gpu.cmpi"(%130, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%132 = "triton_gpu.cmpi"(%131, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%133 = arith.remsi %130, %cst_6 : tensor<1024xi32, #blocked0>
%134 = "triton_gpu.cmpi"(%133, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%135 = arith.divsi %130, %cst_6 : tensor<1024xi32, #blocked0>
%136 = arith.subi %135, %cst_12 : tensor<1024xi32, #blocked0>
%137 = "triton_gpu.select"(%134, %136, %135) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%138 = "triton_gpu.select"(%132, %137, %135) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%139 = arith.addi %126, %138 : tensor<1024xi32, #blocked0>
%140 = "triton_gpu.select"(%129, %139, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%141 = tt.addptr %52, %140 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%142 = triton_gpu.convert_layout %141 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%143 = tt.load %142 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%144 = "triton_gpu.cmpf"(%143, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%145 = "triton_gpu.cmpi"(%144, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%146 = arith.andi %145, %129 : tensor<1024xi1, #blocked0>
%147 = arith.addi %140, %cst_12 : tensor<1024xi32, #blocked0>
%148 = "triton_gpu.select"(%146, %147, %126) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%149 = arith.andi %144, %129 : tensor<1024xi1, #blocked0>
%150 = "triton_gpu.select"(%149, %140, %128) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%151 = "triton_gpu.cmpi"(%148, %150) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%152 = arith.subi %150, %148 : tensor<1024xi32, #blocked0>
%153 = "triton_gpu.cmpi"(%152, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%154 = "triton_gpu.cmpi"(%153, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%155 = arith.remsi %152, %cst_6 : tensor<1024xi32, #blocked0>
%156 = "triton_gpu.cmpi"(%155, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%157 = arith.divsi %152, %cst_6 : tensor<1024xi32, #blocked0>
%158 = arith.subi %157, %cst_12 : tensor<1024xi32, #blocked0>
%159 = "triton_gpu.select"(%156, %158, %157) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%160 = "triton_gpu.select"(%154, %159, %157) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%161 = arith.addi %148, %160 : tensor<1024xi32, #blocked0>
%162 = "triton_gpu.select"(%151, %161, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%163 = tt.addptr %52, %162 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%164 = triton_gpu.convert_layout %163 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%165 = tt.load %164 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%166 = "triton_gpu.cmpf"(%165, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%167 = "triton_gpu.cmpi"(%166, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%168 = arith.andi %167, %151 : tensor<1024xi1, #blocked0>
%169 = arith.addi %162, %cst_12 : tensor<1024xi32, #blocked0>
%170 = "triton_gpu.select"(%168, %169, %148) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%171 = arith.andi %166, %151 : tensor<1024xi1, #blocked0>
%172 = "triton_gpu.select"(%171, %162, %150) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%173 = "triton_gpu.cmpi"(%170, %172) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%174 = arith.subi %172, %170 : tensor<1024xi32, #blocked0>
%175 = "triton_gpu.cmpi"(%174, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%176 = "triton_gpu.cmpi"(%175, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%177 = arith.remsi %174, %cst_6 : tensor<1024xi32, #blocked0>
%178 = "triton_gpu.cmpi"(%177, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%179 = arith.divsi %174, %cst_6 : tensor<1024xi32, #blocked0>
%180 = arith.subi %179, %cst_12 : tensor<1024xi32, #blocked0>
%181 = "triton_gpu.select"(%178, %180, %179) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%182 = "triton_gpu.select"(%176, %181, %179) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%183 = arith.addi %170, %182 : tensor<1024xi32, #blocked0>
%184 = "triton_gpu.select"(%173, %183, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%185 = tt.addptr %52, %184 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%186 = triton_gpu.convert_layout %185 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%187 = tt.load %186 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%188 = "triton_gpu.cmpf"(%187, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%189 = "triton_gpu.cmpi"(%188, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%190 = arith.andi %189, %173 : tensor<1024xi1, #blocked0>
%191 = arith.addi %184, %cst_12 : tensor<1024xi32, #blocked0>
%192 = "triton_gpu.select"(%190, %191, %170) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%193 = arith.andi %188, %173 : tensor<1024xi1, #blocked0>
%194 = "triton_gpu.select"(%193, %184, %172) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%195 = "triton_gpu.cmpi"(%192, %194) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%196 = arith.subi %194, %192 : tensor<1024xi32, #blocked0>
%197 = "triton_gpu.cmpi"(%196, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%198 = "triton_gpu.cmpi"(%197, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%199 = arith.remsi %196, %cst_6 : tensor<1024xi32, #blocked0>
%200 = "triton_gpu.cmpi"(%199, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%201 = arith.divsi %196, %cst_6 : tensor<1024xi32, #blocked0>
%202 = arith.subi %201, %cst_12 : tensor<1024xi32, #blocked0>
%203 = "triton_gpu.select"(%200, %202, %201) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%204 = "triton_gpu.select"(%198, %203, %201) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%205 = arith.addi %192, %204 : tensor<1024xi32, #blocked0>
%206 = "triton_gpu.select"(%195, %205, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%207 = tt.addptr %52, %206 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%208 = triton_gpu.convert_layout %207 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%209 = tt.load %208 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%210 = "triton_gpu.cmpf"(%209, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%211 = "triton_gpu.cmpi"(%210, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%212 = arith.andi %211, %195 : tensor<1024xi1, #blocked0>
%213 = arith.addi %206, %cst_12 : tensor<1024xi32, #blocked0>
%214 = "triton_gpu.select"(%212, %213, %192) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%215 = arith.andi %210, %195 : tensor<1024xi1, #blocked0>
%216 = "triton_gpu.select"(%215, %206, %194) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%217 = "triton_gpu.cmpi"(%214, %216) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%218 = arith.subi %216, %214 : tensor<1024xi32, #blocked0>
%219 = "triton_gpu.cmpi"(%218, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%220 = "triton_gpu.cmpi"(%219, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%221 = arith.remsi %218, %cst_6 : tensor<1024xi32, #blocked0>
%222 = "triton_gpu.cmpi"(%221, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%223 = arith.divsi %218, %cst_6 : tensor<1024xi32, #blocked0>
%224 = arith.subi %223, %cst_12 : tensor<1024xi32, #blocked0>
%225 = "triton_gpu.select"(%222, %224, %223) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%226 = "triton_gpu.select"(%220, %225, %223) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%227 = arith.addi %214, %226 : tensor<1024xi32, #blocked0>
%228 = "triton_gpu.select"(%217, %227, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%229 = tt.addptr %52, %228 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%230 = triton_gpu.convert_layout %229 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%231 = tt.load %230 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%232 = "triton_gpu.cmpf"(%231, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%233 = "triton_gpu.cmpi"(%232, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%234 = arith.andi %233, %217 : tensor<1024xi1, #blocked0>
%235 = arith.addi %228, %cst_12 : tensor<1024xi32, #blocked0>
%236 = "triton_gpu.select"(%234, %235, %214) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%237 = arith.andi %232, %217 : tensor<1024xi1, #blocked0>
%238 = "triton_gpu.select"(%237, %228, %216) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%239 = "triton_gpu.cmpi"(%236, %238) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%240 = arith.subi %238, %236 : tensor<1024xi32, #blocked0>
%241 = "triton_gpu.cmpi"(%240, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%242 = "triton_gpu.cmpi"(%241, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%243 = arith.remsi %240, %cst_6 : tensor<1024xi32, #blocked0>
%244 = "triton_gpu.cmpi"(%243, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%245 = arith.divsi %240, %cst_6 : tensor<1024xi32, #blocked0>
%246 = arith.subi %245, %cst_12 : tensor<1024xi32, #blocked0>
%247 = "triton_gpu.select"(%244, %246, %245) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%248 = "triton_gpu.select"(%242, %247, %245) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%249 = arith.addi %236, %248 : tensor<1024xi32, #blocked0>
%250 = "triton_gpu.select"(%239, %249, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%251 = tt.addptr %52, %250 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%252 = triton_gpu.convert_layout %251 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%253 = tt.load %252 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%254 = "triton_gpu.cmpf"(%253, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%255 = "triton_gpu.cmpi"(%254, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%256 = arith.andi %255, %239 : tensor<1024xi1, #blocked0>
%257 = arith.addi %250, %cst_12 : tensor<1024xi32, #blocked0>
%258 = "triton_gpu.select"(%256, %257, %236) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%259 = arith.andi %254, %239 : tensor<1024xi1, #blocked0>
%260 = "triton_gpu.select"(%259, %250, %238) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%261 = "triton_gpu.cmpi"(%258, %260) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%262 = arith.subi %260, %258 : tensor<1024xi32, #blocked0>
%263 = "triton_gpu.cmpi"(%262, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%264 = "triton_gpu.cmpi"(%263, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%265 = arith.remsi %262, %cst_6 : tensor<1024xi32, #blocked0>
%266 = "triton_gpu.cmpi"(%265, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%267 = arith.divsi %262, %cst_6 : tensor<1024xi32, #blocked0>
%268 = arith.subi %267, %cst_12 : tensor<1024xi32, #blocked0>
%269 = "triton_gpu.select"(%266, %268, %267) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%270 = "triton_gpu.select"(%264, %269, %267) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%271 = arith.addi %258, %270 : tensor<1024xi32, #blocked0>
%272 = "triton_gpu.select"(%261, %271, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%273 = tt.addptr %52, %272 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%274 = triton_gpu.convert_layout %273 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%275 = tt.load %274 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%276 = "triton_gpu.cmpf"(%275, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%277 = "triton_gpu.cmpi"(%276, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%278 = arith.andi %277, %261 : tensor<1024xi1, #blocked0>
%279 = arith.addi %272, %cst_12 : tensor<1024xi32, #blocked0>
%280 = "triton_gpu.select"(%278, %279, %258) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%281 = arith.andi %276, %261 : tensor<1024xi1, #blocked0>
%282 = "triton_gpu.select"(%281, %272, %260) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%283 = "triton_gpu.cmpi"(%280, %282) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%284 = arith.subi %282, %280 : tensor<1024xi32, #blocked0>
%285 = "triton_gpu.cmpi"(%284, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%286 = "triton_gpu.cmpi"(%285, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%287 = arith.remsi %284, %cst_6 : tensor<1024xi32, #blocked0>
%288 = "triton_gpu.cmpi"(%287, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0>
%289 = arith.divsi %284, %cst_6 : tensor<1024xi32, #blocked0>
%290 = arith.subi %289, %cst_12 : tensor<1024xi32, #blocked0>
%291 = "triton_gpu.select"(%288, %290, %289) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%292 = "triton_gpu.select"(%286, %291, %289) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%293 = arith.addi %280, %292 : tensor<1024xi32, #blocked0>
%294 = "triton_gpu.select"(%283, %293, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%295 = tt.addptr %52, %294 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%296 = triton_gpu.convert_layout %295 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%297 = tt.load %296 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0>
%298 = "triton_gpu.cmpf"(%297, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0>
%299 = "triton_gpu.cmpi"(%298, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0>
%300 = arith.andi %299, %283 : tensor<1024xi1, #blocked0>
%301 = arith.addi %294, %cst_12 : tensor<1024xi32, #blocked0>
%302 = "triton_gpu.select"(%300, %301, %280) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0>
%303 = arith.extsi %cst_12 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0>
%304 = "triton_gpu.cmpi"(%17, %303) {predicate = 0 : i64} : (tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi1, #blocked0>
%305 = arith.fptosi %23 : tensor<1024xf32, #blocked0> to tensor<1024xi64, #blocked0>
%306 = arith.extsi %cst_14 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0>
%307 = "triton_gpu.cmpi"(%306, %305) {predicate = 4 : i64} : (tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi1, #blocked0>
%308 = arith.extsi %cst_4 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0>
%309 = "triton_gpu.cmpi"(%305, %308) {predicate = 4 : i64} : (tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi1, #blocked0>
%310 = "triton_gpu.select"(%309, %306, %305) : (tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi64, #blocked0>
%311 = "triton_gpu.select"(%307, %306, %310) : (tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi64, #blocked0>
%312 = "triton_gpu.select"(%304, %311, %306) : (tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi64, #blocked0>
%313 = arith.extsi %cst_3 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0>
%314 = arith.muli %312, %313 : tensor<1024xi64, #blocked0>
%315 = arith.extsi %302 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0>
%316 = arith.addi %315, %314 : tensor<1024xi64, #blocked0>
%317 = arith.trunci %316 : tensor<1024xi64, #blocked0> to tensor<1024xi32, #blocked0>
%318 = arith.extsi %317 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0>
%319 = tt.splat %arg9 : (!tt.ptr<f64>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%320 = tt.addptr %319, %318 : tensor<1024x!tt.ptr<f64>, #blocked0>, tensor<1024xi64, #blocked0>
%321 = triton_gpu.convert_layout %320 : (tensor<1024x!tt.ptr<f64>, #blocked0>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%322 = tt.load %321 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf64, #blocked0>
%323 = arith.extf %cst_2 : tensor<1024xf32, #blocked0> to tensor<1024xf64, #blocked0>
%324 = "triton_gpu.cmpf"(%322, %323) {predicate = 2 : i64} : (tensor<1024xf64, #blocked0>, tensor<1024xf64, #blocked0>) -> tensor<1024xi1, #blocked0>
%325 = tt.splat %arg10 : (!tt.ptr<f64>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%326 = tt.addptr %325, %318 : tensor<1024x!tt.ptr<f64>, #blocked0>, tensor<1024xi64, #blocked0>
%327 = triton_gpu.convert_layout %326 : (tensor<1024x!tt.ptr<f64>, #blocked0>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%328 = tt.load %327 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf64, #blocked0>
%329 = arith.divf %328, %322 : tensor<1024xf64, #blocked0>
%330 = arith.truncf %329 : tensor<1024xf64, #blocked0> to tensor<1024xf32, #blocked0>
%331 = arith.mulf %330, %cst_1 : tensor<1024xf32, #blocked0>
%332 = arith.mulf %35, %cst_0 : tensor<1024xf32, #blocked0>
%333 = arith.addf %331, %332 : tensor<1024xf32, #blocked0>
%334 = "triton_gpu.select"(%324, %333, %35) : (tensor<1024xi1, #blocked0>, tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xf32, #blocked0>
%335 = tt.addptr %319, %317 : tensor<1024x!tt.ptr<f64>, #blocked0>, tensor<1024xi32, #blocked0>
%336 = triton_gpu.convert_layout %335 : (tensor<1024x!tt.ptr<f64>, #blocked0>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%337 = tt.load %336 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf64, #blocked0>
%338 = arith.extf %cst : tensor<1024xf32, #blocked0> to tensor<1024xf64, #blocked0>
%339 = arith.mulf %337, %338 : tensor<1024xf64, #blocked0>
%340 = tt.addptr %325, %317 : tensor<1024x!tt.ptr<f64>, #blocked0>, tensor<1024xi32, #blocked0>
%341 = triton_gpu.convert_layout %340 : (tensor<1024x!tt.ptr<f64>, #blocked0>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%342 = tt.load %341 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf64, #blocked0>
%343 = arith.mulf %342, %338 : tensor<1024xf64, #blocked0>
%344 = tt.splat %arg11 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%345 = tt.addptr %344, %4 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%346 = triton_gpu.convert_layout %345 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked1>
%347 = triton_gpu.convert_layout %28 : (tensor<1024xf32, #blocked0>) -> tensor<1024xf32, #blocked1>
%348 = triton_gpu.convert_layout %5 : (tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked1>
tt.store %346, %347, %348 : tensor<1024xf32, #blocked1>
%349 = tt.splat %arg12 : (!tt.ptr<i32>) -> tensor<1024x!tt.ptr<i32>, #blocked0>
%350 = tt.addptr %349, %4 : tensor<1024x!tt.ptr<i32>, #blocked0>, tensor<1024xi32, #blocked0>
%351 = triton_gpu.convert_layout %350 : (tensor<1024x!tt.ptr<i32>, #blocked0>) -> tensor<1024x!tt.ptr<i32>, #blocked1>
%352 = triton_gpu.convert_layout %317 : (tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked1>
%353 = triton_gpu.convert_layout %5 : (tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked1>
tt.store %351, %352, %353 : tensor<1024xi32, #blocked1>
%354 = tt.splat %arg13 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked0>
%355 = tt.addptr %354, %4 : tensor<1024x!tt.ptr<f32>, #blocked0>, tensor<1024xi32, #blocked0>
%356 = triton_gpu.convert_layout %355 : (tensor<1024x!tt.ptr<f32>, #blocked0>) -> tensor<1024x!tt.ptr<f32>, #blocked1>
%357 = triton_gpu.convert_layout %334 : (tensor<1024xf32, #blocked0>) -> tensor<1024xf32, #blocked1>
%358 = triton_gpu.convert_layout %5 : (tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked1>
tt.store %356, %357, %358 : tensor<1024xf32, #blocked1>
%359 = tt.splat %arg14 : (!tt.ptr<f64>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%360 = tt.addptr %359, %318 : tensor<1024x!tt.ptr<f64>, #blocked0>, tensor<1024xi64, #blocked0>
%361 = triton_gpu.convert_layout %360 : (tensor<1024x!tt.ptr<f64>, #blocked0>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%362 = triton_gpu.convert_layout %339 : (tensor<1024xf64, #blocked0>) -> tensor<1024xf64, #blocked0>
tt.store %361, %362 : tensor<1024xf64, #blocked0>
%363 = tt.splat %arg15 : (!tt.ptr<f64>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%364 = tt.addptr %363, %318 : tensor<1024x!tt.ptr<f64>, #blocked0>, tensor<1024xi64, #blocked0>
%365 = triton_gpu.convert_layout %364 : (tensor<1024x!tt.ptr<f64>, #blocked0>) -> tensor<1024x!tt.ptr<f64>, #blocked0>
%366 = triton_gpu.convert_layout %343 : (tensor<1024xf64, #blocked0>) -> tensor<1024xf64, #blocked0>
tt.store %365, %366 : tensor<1024xf64, #blocked0>
return
}
// A mnist model from torch inductor.
// Check if topological sort is working correct and there's no unnecessary convert
// CHECK-LABEL: mnist
func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2>
%cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3>
%c16_i32 = arith.constant 16 : i32
%cst_1 = arith.constant dense<64> : tensor<16x1xi32, #blocked2>
%cst_2 = arith.constant dense<0xFF800000> : tensor<16x16xf32, #blocked2>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked2>
%cst_4 = arith.constant dense<0> : tensor<16x16xi32, #blocked2>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c16_i32 : i32
%2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked0>
%3 = triton_gpu.convert_layout %2 : (tensor<16xi32, #blocked0>) -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%4 = tt.expand_dims %3 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xi32, #blocked1>
%5 = triton_gpu.convert_layout %4 : (tensor<16x1xi32, #blocked1>) -> tensor<16x1xi32, #blocked2>
%6 = tt.splat %1 : (i32) -> tensor<16x1xi32, #blocked2>
%7 = arith.addi %6, %5 : tensor<16x1xi32, #blocked2>
%8 = "triton_gpu.cmpi"(%7, %cst_1) {predicate = 2 : i64} : (tensor<16x1xi32, #blocked2>, tensor<16x1xi32, #blocked2>) -> tensor<16x1xi1, #blocked2>
%9 = triton_gpu.convert_layout %2 : (tensor<16xi32, #blocked0>) -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
%10 = tt.expand_dims %9 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>) -> tensor<1x16xi32, #blocked3>
%11 = "triton_gpu.cmpi"(%10, %cst_0) {predicate = 2 : i64} : (tensor<1x16xi32, #blocked3>, tensor<1x16xi32, #blocked3>) -> tensor<1x16xi1, #blocked3>
%12 = arith.muli %7, %cst : tensor<16x1xi32, #blocked2>
%13 = tt.broadcast %10 : (tensor<1x16xi32, #blocked3>) -> tensor<16x16xi32, #blocked3>
%14 = triton_gpu.convert_layout %13 : (tensor<16x16xi32, #blocked3>) -> tensor<16x16xi32, #blocked2>
%15 = tt.broadcast %12 : (tensor<16x1xi32, #blocked2>) -> tensor<16x16xi32, #blocked2>
%16 = arith.addi %14, %15 : tensor<16x16xi32, #blocked2>
%17 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x16x!tt.ptr<f32>, #blocked2>
%18 = tt.addptr %17, %16 : tensor<16x16x!tt.ptr<f32>, #blocked2>, tensor<16x16xi32, #blocked2>
%19 = tt.broadcast %11 : (tensor<1x16xi1, #blocked3>) -> tensor<16x16xi1, #blocked3>
%20 = triton_gpu.convert_layout %19 : (tensor<16x16xi1, #blocked3>) -> tensor<16x16xi1, #blocked2>
%21 = tt.broadcast %8 : (tensor<16x1xi1, #blocked2>) -> tensor<16x16xi1, #blocked2>
%22 = arith.andi %20, %21 : tensor<16x16xi1, #blocked2>
%23 = triton_gpu.convert_layout %18 : (tensor<16x16x!tt.ptr<f32>, #blocked2>) -> tensor<16x16x!tt.ptr<f32>, #blocked4>
%24 = triton_gpu.convert_layout %22 : (tensor<16x16xi1, #blocked2>) -> tensor<16x16xi1, #blocked4>
%25 = tt.load %23, %24 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<16x16xf32, #blocked4>
%26 = triton_gpu.convert_layout %25 : (tensor<16x16xf32, #blocked4>) -> tensor<16x16xf32, #blocked2>
%27 = "triton_gpu.cmpf"(%cst_2, %26) {predicate = 4 : i64} : (tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xi1, #blocked2>
%28 = arith.andi %22, %27 : tensor<16x16xi1, #blocked2>
%29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
%30 = tt.reduce %29 {axis = 1 : i32, redOp = 12 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
%32 = triton_gpu.convert_layout %31 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%33 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1>
%34 = triton_gpu.convert_layout %33 : (tensor<16x1xf32, #blocked1>) -> tensor<16x1xf32, #blocked2>
%35 = arith.sitofp %cst_4 : tensor<16x16xi32, #blocked2> to tensor<16x16xf32, #blocked2>
%36 = arith.addf %35, %cst_3 : tensor<16x16xf32, #blocked2>
%37 = triton_gpu.convert_layout %18 : (tensor<16x16x!tt.ptr<f32>, #blocked2>) -> tensor<16x16x!tt.ptr<f32>, #blocked4>
%38 = triton_gpu.convert_layout %22 : (tensor<16x16xi1, #blocked2>) -> tensor<16x16xi1, #blocked4>
%39 = tt.load %37, %38 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<16x16xf32, #blocked4>
%40 = triton_gpu.convert_layout %39 : (tensor<16x16xf32, #blocked4>) -> tensor<16x16xf32, #blocked2>
%41 = tt.broadcast %34 : (tensor<16x1xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
%42 = arith.subf %40, %41 : tensor<16x16xf32, #blocked2>
%43 = math.exp %42 : tensor<16x16xf32, #blocked2>
%44 = arith.addf %36, %43 : tensor<16x16xf32, #blocked2>
%45 = "triton_gpu.select"(%22, %44, %36) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
%46 = tt.reduce %45 {axis = 1 : i32, redOp = 2 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%47 = triton_gpu.convert_layout %46 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
%48 = triton_gpu.convert_layout %47 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%49 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1>
%50 = triton_gpu.convert_layout %49 : (tensor<16x1xf32, #blocked1>) -> tensor<16x1xf32, #blocked2>
%51 = triton_gpu.convert_layout %18 : (tensor<16x16x!tt.ptr<f32>, #blocked2>) -> tensor<16x16x!tt.ptr<f32>, #blocked4>
%52 = triton_gpu.convert_layout %22 : (tensor<16x16xi1, #blocked2>) -> tensor<16x16xi1, #blocked4>
%53 = tt.load %51, %52 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<16x16xf32, #blocked4>
%54 = triton_gpu.convert_layout %53 : (tensor<16x16xf32, #blocked4>) -> tensor<16x16xf32, #blocked2>
%55 = arith.subf %54, %41 : tensor<16x16xf32, #blocked2>
%56 = math.log %50 : tensor<16x1xf32, #blocked2>
%57 = tt.broadcast %56 : (tensor<16x1xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
%58 = arith.subf %55, %57 : tensor<16x16xf32, #blocked2>
%59 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<16x16x!tt.ptr<f32>, #blocked2>
%60 = tt.addptr %59, %16 : tensor<16x16x!tt.ptr<f32>, #blocked2>, tensor<16x16xi32, #blocked2>
%61 = triton_gpu.convert_layout %60 : (tensor<16x16x!tt.ptr<f32>, #blocked2>) -> tensor<16x16x!tt.ptr<f32>, #blocked4>
%62 = triton_gpu.convert_layout %58 : (tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked4>
%63 = triton_gpu.convert_layout %22 : (tensor<16x16xi1, #blocked2>) -> tensor<16x16xi1, #blocked4>
tt.store %61, %62, %63 : tensor<16x16xf32, #blocked4>
return
}