[FRONTEND][BACKEND] Hardened get_program_id axis by making it an enum attribute (#1721)

Also catch out-of-bounds indices at construction and throw a proper error
in the frontend.
Finally, let's make the IR a bit prettier:

  %0 = tt.get_program_id {axis = 0 : i32} : i32

becomes:

  %0 = tt.get_program_id x : i32

Fixes #1718
This commit is contained in:
Mehdi Amini
2023-05-31 22:49:46 -07:00
committed by GitHub
parent 19c65d6007
commit b0c893cdc5
11 changed files with 69 additions and 33 deletions

View File

@@ -52,4 +52,16 @@ def TT_AtomicRMWAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}
// Program ID dimensions.
def TT_ProgramDim : I32EnumAttr<
"ProgramIDDim", "",
[
I32EnumAttrCase<"X", 0, "x">,
I32EnumAttrCase<"Y", 1, "y">,
I32EnumAttrCase<"Z", 2, "z">,
]> {
let cppNamespace = "::mlir::triton";
}
#endif

View File

@@ -351,11 +351,17 @@ def TT_TransOp : TT_Op<"trans", [Pure,
// SPMD Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> {
let arguments = (ins I32Attr:$axis);
let arguments = (ins TT_ProgramDim:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
let assemblyFormat = "$axis attr-dict `:` type($result)";
let extraClassDeclaration = [{
int32_t getAxisAsInt() {
return static_cast<int32_t>(getAxis());
}
}];
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {

View File

@@ -384,10 +384,10 @@ struct GetProgramIdOpConversion
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
assert(op.getAxis() < 3);
assert(op.getAxisAsInt() < 3);
Value blockId =
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxis()]);
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
return success();
}

View File

@@ -1325,8 +1325,12 @@ void init_triton_ir(py::module &&m) {
.def("create_get_program_id",
[](mlir::OpBuilder &self, int axis) -> mlir::Value {
auto loc = self.getUnknownLoc();
if (axis < 0 || axis > 3)
throw std::runtime_error("program_id must be in [0,3]");
return self.create<mlir::triton::GetProgramIdOp>(
loc, self.getI32Type(), self.getI32IntegerAttr(axis));
loc, self.getI32Type(),
mlir::triton::ProgramIDDimAttr::get(
loc.getContext(), mlir::triton::ProgramIDDim(axis)));
})
.def("create_get_num_programs",
[](mlir::OpBuilder &self, int axis) -> mlir::Value {

View File

@@ -563,6 +563,20 @@ def test_expand_dims_error_cases():
duplicate_dim2[(1,)](dummy_tensor, N)
# ----------------------------
# test invalid program id axis
# ----------------------------
def test_invalid_pid_axis():
dst = torch.empty(128, device='cuda')
@triton.jit
def _kernel(dst):
pid = tl.program_id(20)
with pytest.raises(triton.CompilationError, match=r"program_id must be in \[0,3\]"):
_kernel[(1,)](dst)
# ---------------
# test where
# ---------------

View File

@@ -406,7 +406,7 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
// CHECK-LABEL: @store_constant_align
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%pid = tt.get_program_id {axis = 0 : i32} : i32
%pid = tt.get_program_id x : i32
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
%c128_i32 = arith.constant 128 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = <none>
@@ -438,7 +438,7 @@ tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
// CHECK-LABEL: @vecadd_mask_align_16
tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c64_i32 : i32
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
@@ -467,7 +467,7 @@ tt.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
// CHECK-LABEL: @vecadd_mask_align_1
tt.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c64_i32 : i32
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%3 = tt.splat %1 : (i32) -> tensor<64xi32>

View File

@@ -86,7 +86,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_no_vec
tt.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
@@ -138,7 +138,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_vec4
tt.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
@@ -175,7 +175,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
module attributes {"triton_gpu.num-warps" = 2 : i32} {
tt.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c64_i32 : i32
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<64xi32, #blocked>
@@ -205,7 +205,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec2
tt.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
@@ -250,7 +250,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec8
tt.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
@@ -357,7 +357,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
tt.func @basic_program_id() {
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
tt.return
}
}
@@ -1066,9 +1066,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
%blockidx = tt.get_program_id {axis=0:i32} : i32
%blockidy = tt.get_program_id {axis=1:i32} : i32
%blockidz = tt.get_program_id {axis=2:i32} : i32
%blockidx = tt.get_program_id x : i32
%blockidy = tt.get_program_id y : i32
%blockidz = tt.get_program_id z : i32
// CHECK: nvvm.read.ptx.sreg.ctaid.x
// CHECK: nvvm.read.ptx.sreg.ctaid.y
// CHECK: nvvm.read.ptx.sreg.ctaid.z

View File

@@ -10,8 +10,8 @@ tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%c32_i32 = arith.constant 32 : i32
%c128_i32 = arith.constant 128 : i32
%c8_i32 = arith.constant 8 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.get_program_id {axis = 1 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = arith.addi %arg3, %c127_i32 : i32
%3 = arith.divsi %2, %c128_i32 : i32
%4 = arith.addi %arg4, %c31_i32 : i32

View File

@@ -2,7 +2,7 @@
module {
tt.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%c256_i32 = arith.constant 256 : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
@@ -49,7 +49,7 @@ module {
// %c0 = arith.constant 0 : index
// %cst = arith.constant 0.000000e+00 : f32
// %c256_i32 = arith.constant 256 : i32
// %0 = tt.get_program_id {axis = 0 : i32} : i32
// %0 = tt.get_program_id x : i32
// %1 = arith.muli %0, %c256_i32 : i32
// %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>

View File

@@ -86,7 +86,7 @@ tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout1>
%3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout1>
@@ -102,7 +102,7 @@ tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// CHECK-LABEL: if_convert_else_not
tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
%9 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
@@ -123,7 +123,7 @@ tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
// CHECK-LABEL: if_not_else_convert
tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
%9 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout1>
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
@@ -144,7 +144,7 @@ tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
// CHECK-LABEL: if_else_both_convert
tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
%2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0>
%3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout0>
@@ -323,7 +323,7 @@ tt.func @loop_if(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i3
tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1>
@@ -361,7 +361,7 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
%c0 = arith.constant 0 : index
%cst_1 = arith.constant dense<2048> : tensor<1x1xi32, #blocked2>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked0>
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #blocked0>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<1x1xi32, #blocked1>
@@ -422,7 +422,7 @@ tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg
%cst_12 = arith.constant dense<1> : tensor<1024xi32, #blocked0>
%cst_13 = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked0>
%cst_14 = arith.constant dense<0> : tensor<1024xi32, #blocked0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked0>
@@ -809,7 +809,7 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
%cst_2 = arith.constant dense<0xFF800000> : tensor<16x16xf32, #blocked2>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked2>
%cst_4 = arith.constant dense<0> : tensor<16x16xi32, #blocked2>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c16_i32 : i32
%2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked0>
%3 = triton_gpu.convert_layout %2 : (tensor<16xi32, #blocked0>) -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
@@ -908,7 +908,7 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
%cst_4 = arith.constant dense<2048> : tensor<64x1xi32, #blocked2>
%cst_5 = arith.constant dense<49152> : tensor<64x1xi32, #blocked2>
%cst_6 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked2>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c64_i32 : i32
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
%3 = triton_gpu.convert_layout %2 : (tensor<64xi32, #blocked0>) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
@@ -1044,7 +1044,7 @@ tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
%c-1_i64 = arith.constant -1 : i64
%cst = arith.constant 0.000000e+00 : f32
%c-1_i32 = arith.constant -1 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.addptr %arg3, %0 : !tt.ptr<i64>, i32
%2 = tt.load %1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i64
%3 = arith.cmpi eq, %2, %c-1_i64 : i64
@@ -1127,7 +1127,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%cst_3 = arith.constant dense<196> : tensor<1x256xi32, #blocked>
%cst_4 = arith.constant dense<3136> : tensor<1x256xi32, #blocked>
%cst_5 = arith.constant dense<256> : tensor<1x1xi32, #blocked>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1>
%2 = triton_gpu.convert_layout %1 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2>

View File

@@ -12,7 +12,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
%c8_i32 = arith.constant 8 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.addi %arg3, %c63_i32 : i32
%2 = arith.divsi %1, %c64_i32 : i32
%3 = arith.addi %arg4, %c63_i32 : i32