mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] Hardened get_program_id axis by making it an enum attribute (#1721)
Also catch out-of-bounds indices at constructio and throw a proper error
in the frontend.
Finally, let's make the IR a bit prettier:
%0 = tt.get_program_id {axis = 0 : i32} : i32
becomes:
%0 = tt.get_program_id x : i32
Fixes #1718
This commit is contained in:
@@ -52,4 +52,16 @@ def TT_AtomicRMWAttr : I32EnumAttr<
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
// Program ID dimensions.
|
||||
def TT_ProgramDim : I32EnumAttr<
|
||||
"ProgramIDDim", "",
|
||||
[
|
||||
I32EnumAttrCase<"X", 0, "x">,
|
||||
I32EnumAttrCase<"Y", 1, "y">,
|
||||
I32EnumAttrCase<"Z", 2, "z">,
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -351,11 +351,17 @@ def TT_TransOp : TT_Op<"trans", [Pure,
|
||||
// SPMD Ops
|
||||
//
|
||||
def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
let arguments = (ins TT_ProgramDim:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
let assemblyFormat = "$axis attr-dict `:` type($result)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
int32_t getAxisAsInt() {
|
||||
return static_cast<int32_t>(getAxis());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
|
||||
|
||||
@@ -384,10 +384,10 @@ struct GetProgramIdOpConversion
|
||||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxis() < 3);
|
||||
assert(op.getAxisAsInt() < 3);
|
||||
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxis()]);
|
||||
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
|
||||
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1325,8 +1325,12 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_get_program_id",
|
||||
[](mlir::OpBuilder &self, int axis) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (axis < 0 || axis > 3)
|
||||
throw std::runtime_error("program_id must be in [0,3]");
|
||||
return self.create<mlir::triton::GetProgramIdOp>(
|
||||
loc, self.getI32Type(), self.getI32IntegerAttr(axis));
|
||||
loc, self.getI32Type(),
|
||||
mlir::triton::ProgramIDDimAttr::get(
|
||||
loc.getContext(), mlir::triton::ProgramIDDim(axis)));
|
||||
})
|
||||
.def("create_get_num_programs",
|
||||
[](mlir::OpBuilder &self, int axis) -> mlir::Value {
|
||||
|
||||
@@ -563,6 +563,20 @@ def test_expand_dims_error_cases():
|
||||
duplicate_dim2[(1,)](dummy_tensor, N)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# test invalid program id axis
|
||||
# ----------------------------
|
||||
def test_invalid_pid_axis():
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst):
|
||||
pid = tl.program_id(20)
|
||||
|
||||
with pytest.raises(triton.CompilationError, match=r"program_id must be in \[0,3\]"):
|
||||
_kernel[(1,)](dst)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test where
|
||||
# ---------------
|
||||
|
||||
@@ -406,7 +406,7 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
|
||||
// CHECK-LABEL: @store_constant_align
|
||||
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%pid = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%pid = tt.get_program_id x : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = <none>
|
||||
@@ -438,7 +438,7 @@ tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
// CHECK-LABEL: @vecadd_mask_align_16
|
||||
tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
|
||||
@@ -467,7 +467,7 @@ tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
// CHECK-LABEL: @vecadd_mask_align_1
|
||||
tt.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
|
||||
|
||||
@@ -86,7 +86,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: global_load_store_no_vec
|
||||
tt.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
@@ -138,7 +138,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec4
|
||||
tt.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
@@ -175,7 +175,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
tt.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32, #blocked>
|
||||
@@ -205,7 +205,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec2
|
||||
tt.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
@@ -250,7 +250,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec8
|
||||
tt.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
@@ -357,7 +357,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_program_id
|
||||
tt.func @basic_program_id() {
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
@@ -1066,9 +1066,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
%blockidx = tt.get_program_id {axis=0:i32} : i32
|
||||
%blockidy = tt.get_program_id {axis=1:i32} : i32
|
||||
%blockidz = tt.get_program_id {axis=2:i32} : i32
|
||||
%blockidx = tt.get_program_id x : i32
|
||||
%blockidy = tt.get_program_id y : i32
|
||||
%blockidz = tt.get_program_id z : i32
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.x
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.y
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.z
|
||||
|
||||
@@ -10,8 +10,8 @@ tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
||||
%c32_i32 = arith.constant 32 : i32
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.get_program_id {axis = 1 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.get_program_id y : i32
|
||||
%2 = arith.addi %arg3, %c127_i32 : i32
|
||||
%3 = arith.divsi %2, %c128_i32 : i32
|
||||
%4 = arith.addi %arg4, %c31_i32 : i32
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
module {
|
||||
tt.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
|
||||
@@ -49,7 +49,7 @@ module {
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %cst = arith.constant 0.000000e+00 : f32
|
||||
// %c256_i32 = arith.constant 256 : i32
|
||||
// %0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
// %0 = tt.get_program_id x : i32
|
||||
// %1 = arith.muli %0, %c256_i32 : i32
|
||||
// %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
|
||||
@@ -86,7 +86,7 @@ tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
|
||||
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout1>
|
||||
%3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout1>
|
||||
@@ -102,7 +102,7 @@ tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-LABEL: if_convert_else_not
|
||||
tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
%9 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
|
||||
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
|
||||
@@ -123,7 +123,7 @@ tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
|
||||
// CHECK-LABEL: if_not_else_convert
|
||||
tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
%9 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
|
||||
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
|
||||
@@ -144,7 +144,7 @@ tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
|
||||
// CHECK-LABEL: if_else_both_convert
|
||||
tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
||||
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
|
||||
%3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout0>
|
||||
@@ -323,7 +323,7 @@ tt.func @loop_if(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i3
|
||||
tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1>
|
||||
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1>
|
||||
@@ -361,7 +361,7 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst_1 = arith.constant dense<2048> : tensor<1x1xi32, #blocked2>
|
||||
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked0>
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #blocked0>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<1x1xi32, #blocked1>
|
||||
@@ -422,7 +422,7 @@ tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg
|
||||
%cst_12 = arith.constant dense<1> : tensor<1024xi32, #blocked0>
|
||||
%cst_13 = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked0>
|
||||
%cst_14 = arith.constant dense<0> : tensor<1024xi32, #blocked0>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c1024_i32 : i32
|
||||
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked0>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked0>
|
||||
@@ -809,7 +809,7 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
|
||||
%cst_2 = arith.constant dense<0xFF800000> : tensor<16x16xf32, #blocked2>
|
||||
%cst_3 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked2>
|
||||
%cst_4 = arith.constant dense<0> : tensor<16x16xi32, #blocked2>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c16_i32 : i32
|
||||
%2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked0>
|
||||
%3 = triton_gpu.convert_layout %2 : (tensor<16xi32, #blocked0>) -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
@@ -908,7 +908,7 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
|
||||
%cst_4 = arith.constant dense<2048> : tensor<64x1xi32, #blocked2>
|
||||
%cst_5 = arith.constant dense<49152> : tensor<64x1xi32, #blocked2>
|
||||
%cst_6 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked2>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
|
||||
%3 = triton_gpu.convert_layout %2 : (tensor<64xi32, #blocked0>) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
@@ -1044,7 +1044,7 @@ tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
|
||||
%c-1_i64 = arith.constant -1 : i64
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%c-1_i32 = arith.constant -1 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.addptr %arg3, %0 : !tt.ptr<i64>, i32
|
||||
%2 = tt.load %1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i64
|
||||
%3 = arith.cmpi eq, %2, %c-1_i64 : i64
|
||||
@@ -1127,7 +1127,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%cst_3 = arith.constant dense<196> : tensor<1x256xi32, #blocked>
|
||||
%cst_4 = arith.constant dense<3136> : tensor<1x256xi32, #blocked>
|
||||
%cst_5 = arith.constant dense<256> : tensor<1x1xi32, #blocked>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1>
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2>
|
||||
|
||||
@@ -12,7 +12,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%c63_i32 = arith.constant 63 : i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = arith.addi %arg3, %c63_i32 : i32
|
||||
%2 = arith.divsi %1, %c64_i32 : i32
|
||||
%3 = arith.addi %arg4, %c63_i32 : i32
|
||||
|
||||
Reference in New Issue
Block a user