[FRONTEND][BACKEND] Add the noinline annotation for triton.jit (#1568)

# Introducing the `noinline` Parameter for Triton JIT Decorator

We're excited to introduce a new parameter, `noinline`, that can be
added to the `jit` decorator in Triton. This parameter allows developers
to specify that a particular Triton function should not be inlined into
its callers. In this post, we'll dive into the syntax, purpose, and
implementation details of this new feature.

## Syntax

To use the `noinline` parameter, simply add `noinline=True` to the `jit`
decorator for the function that you don't want to be inlined. Here's an
example:

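```python
import triton
import triton.language as tl


# noinline=True keeps device_fn as a separate function instead of
# inlining it into each caller.
@triton.jit(noinline=True)
def device_fn(x, y, Z):
    z = x + y
    tl.store(Z, z)

def test_noinline():
    # kernel calls device_fn; the call is not inlined.
    @triton.jit
    def kernel(X, Y, Z):
        x = tl.load(X)
        y = tl.load(Y)
        device_fn(x, y, Z)
```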

In this example, the `device_fn` function is decorated with
`@triton.jit(noinline=True)`, indicating that it should not be inlined
into its caller, `kernel`.

## Purpose

The `noinline` parameter serves several key purposes:

- Reducing code size: A function that is called from several places is
compiled once instead of being duplicated at every call site.
- Facilitating debugging: Keeping functions separate can make it easier
to debug the code.
- Avoiding common subexpression elimination (CSE) in certain cases:
Keeping a function out of line can stop CSE from being applied across
the call boundary, which can reduce register pressure.
- Enabling dynamic linking: This parameter makes it possible to
dynamically link Triton functions.

## Implementation

The implementation of the `noinline` parameter involves significant
changes to three analysis modules in Triton: *Allocation*, *Membar*, and
*AxisInfo*. Prior to this update, these modules assumed that all Triton
functions had been inlined into the root kernel function. With the
introduction of non-inlined functions, we've had to rework these
assumptions and make corresponding changes to the analyses.

### Call Graph and Limitations

<div style="text-align: center;">
<img
src="https://user-images.githubusercontent.com/2306281/234663904-12864247-3412-4405-987b-6991cdf053bb.png"
alt="figure 1" width="200" height="auto">
</div>

To address the changes, we build a call graph and perform all the
analyses on the call graph instead of a single function. The call graph
is constructed by traversing the call edges and storing them in an edge
map. Roots are extracted by checking nodes with no incoming edges.
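
The construction described above can be sketched in a few lines of Python. This is an illustration only, not the code from this commit: the `calls` encoding and the helper names `build_call_graph` and `post_order` are assumptions made for the example, and the function names are borrowed from the allocation tests further down.

```python
def build_call_graph(calls):
    """calls maps a function name to the list of functions it calls."""
    edges = {}            # edge map: caller -> [callee, ...]
    has_incoming = set()  # functions that are called by someone
    for caller, callees in calls.items():
        edges.setdefault(caller, [])
        for callee in callees:
            edges[caller].append(callee)
            edges.setdefault(callee, [])
            has_incoming.add(callee)
    # Roots are the nodes with no incoming edges.
    roots = [fn for fn in edges if fn not in has_incoming]
    return edges, roots


def post_order(edges, roots):
    """Return callees before callers; reject recursion, which is unsupported."""
    order, done, on_stack = [], set(), set()

    def visit(fn):
        if fn in done:
            return
        if fn in on_stack:
            raise ValueError(f"recursive call involving {fn!r}")
        on_stack.add(fn)
        for callee in edges[fn]:
            visit(callee)
        on_stack.discard(fn)
        done.add(fn)
        order.append(fn)

    for root in roots:
        visit(root)
    if len(order) != len(edges):
        # Some functions are unreachable from any root: a cycle without a root.
        raise ValueError("call graph contains a cycle with no root")
    return order


calls = {
    "call_graph_2": ["alloc4"],
    "alloc4": ["alloc3", "alloc1"],
    "alloc3": [],
    "alloc1": [],
}
edges, roots = build_call_graph(calls)
order = post_order(edges, roots)
```

Reversing `order` gives a topological order (callers before callees), which is the direction the AxisInfo analysis needs.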

The call graph has certain limitations:

- It does not support recursive function calls, although this could be
implemented in the future.
- It does not support dynamic function calls, where the function name is
unknown at compilation time.

### Allocation

<div style="text-align: center;">
<img
src="https://user-images.githubusercontent.com/2306281/234665110-bf6a2660-06fb-4648-85dc-16429439e72d.png"
alt="figure 2" width="400" height="auto">
</div>

In Triton, shared memory allocation is achieved through two operations:
`triton_gpu.convert_layout` and `triton_gpu.alloc_tensor`. The
`convert_layout` operation allocates an internal tensor, which we refer
to as a *scratch* buffer, while the `alloc_tensor` operation returns an
allocated tensor and is thus known as an *explicit* buffer.

To accommodate the introduction of function calls, we are introducing a
third type of buffer called a *virtual* buffer. Similar to scratch
buffers, virtual buffers are allocated internally within the scope of a
function call, and the buffers allocated by the called functions remain
invisible to subsequent operations in the calling function. However,
virtual buffers are distinct from scratch buffers in that the call
operation itself does not allocate memory—instead, it specifies the
total amount of memory required by all the child functions being called.
The actual allocation of buffers is performed by individual operations
within these child functions. For example, when invoking edge e1 in
Figure 2, no memory is allocated, but the total amount of memory needed
by function B is reserved. Notably, the amount of shared memory reserved
for function B is the same at every call site, because the analysis
accounts for all dynamic control-flow paths within each function.
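
A minimal sketch of how call operations could be sized, under a deliberately simplified model: it assumes that no two buffers of a function are live at the same time, so a function's total is simply its largest buffer, and it ignores the liveness-based packing that the real analysis performs. The `bodies` encoding and the helper name `function_sizes` are hypothetical; the byte sizes mirror the `alloc1`, `alloc3`, `alloc4`, and `call_graph_2` cases in the allocation test added by this commit.

```python
def function_sizes(bodies, post_order):
    """Size each function and each call op, visiting callees before callers."""
    total = {}    # function -> bytes of shared memory it needs
    virtual = {}  # (caller, op index) -> bytes reserved by that call op
    for fn in post_order:
        peak = 0
        for i, (kind, arg) in enumerate(bodies[fn]):
            if kind == "alloc":
                size = arg              # explicit or scratch buffer
            else:                       # kind == "call"
                size = total[arg]       # virtual buffer: callee's total
                virtual[(fn, i)] = size
            peak = max(peak, size)
        total[fn] = peak
    return total, virtual


bodies = {
    "alloc1": [("alloc", 512)],
    "alloc3": [("alloc", 512), ("alloc", 1024)],   # one buffer per branch
    "alloc4": [("call", "alloc3"), ("call", "alloc1")],
    "call_graph_2": [("alloc", 512), ("call", "alloc4")],
}
order = ["alloc1", "alloc3", "alloc4", "call_graph_2"]
total, virtual = function_sizes(bodies, order)
```

The call op itself allocates nothing here; it only records how much memory the callee needs, and that amount is the same wherever the callee is called.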

An additional challenge to address is the calculation of shared memory
offsets for functions within a call graph. While we can assume a shared
memory offset starting at 0 for a single root function, this is not the
case with a call graph, where we must determine each function's starting
offset based on the call path. Although each function has a fixed memory
consumption, the starting offset may vary. For instance, in Figure 2,
the starting offset of function C through edges e1->e2 differs from that
through edges e2->e4. To handle this, we accumulate the starting offset
at each call site and pass it as an argument to the called function.
Additionally, we amend both the function declaration and call sites by
appending an offset variable.
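
The offset accumulation can be illustrated as follows. This is a sketch under assumptions: the call graph, the edge labels, and the local offsets below are invented for the example (they are not the ones in Figure 2), and `callee_bases` is a hypothetical helper.

```python
def callee_bases(call_sites, fn, base=0, path=()):
    """Yield (call path, callee, starting offset) for every path below fn.

    call_sites maps a function to (edge label, local offset, callee) triples,
    where the local offset is where the call's virtual buffer sits inside
    the caller's own shared memory region.
    """
    for edge, local_offset, callee in call_sites.get(fn, []):
        # The callee starts at the caller's base plus the call site's offset.
        callee_base = base + local_offset
        new_path = path + (edge,)
        yield new_path, callee, callee_base
        yield from callee_bases(call_sites, callee, callee_base, new_path)


# Hypothetical graph: A calls B and C, and B also calls C.
call_sites = {
    "A": [("A->B", 512, "B"), ("A->C", 0, "C")],
    "B": [("B->C", 256, "C")],
}
results = list(callee_bases(call_sites, "A"))
bases_of_c = sorted(base for _, callee, base in results if callee == "C")
```

Function `C` has the same size on both paths but starts at a different offset on each, which is why the offset has to be passed as an argument rather than fixed per function.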

### Membar

<div style="text-align: center;">
<img
src="https://user-images.githubusercontent.com/2306281/234665157-844dd66f-5028-4ef3-bca2-4ca74b8f969d.png"
alt="figure 3" width="300" height="auto">
</div>

The membar pass depends on the allocation analysis. Once the offset
and size of each buffer are known, we traverse the call graph in
post-order, so that callees are analyzed before their callers, and
analyze each function individually. Unlike the previous single-function
analysis, each function now returns the buffers that remain
unsynchronized at its end, which allows the calling function to insert
synchronization when those buffers overlap with buffers it accesses.
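
The idea can be sketched as follows. This is a simplified model, not the pass itself: it treats every shared memory access as a potential hazard (the real analysis presumably distinguishes reads from writes), it uses absolute `(offset, size)` intervals, and the `bodies` encoding and the names `overlaps` and `analyze` are assumptions made for the example.

```python
def overlaps(a, b):
    """True if two (offset, size) intervals share at least one byte."""
    (a_off, a_size), (b_off, b_size) = a, b
    return a_off < b_off + b_size and b_off < a_off + a_size


def analyze(bodies, post_order):
    """Decide where barriers are needed, visiting callees before callers."""
    summaries = {}  # fn -> (touched before first barrier, pending at exit, has barrier)
    barriers = {}   # fn -> indices of ops that need a barrier before them
    for fn in post_order:
        pending, entry, synced, need = [], [], False, []
        for i, (kind, arg) in enumerate(bodies[fn]):
            if kind == "access":
                touched, leftover, callee_synced = [arg], [arg], False
            else:  # "call": reuse the summary returned by the callee
                touched, leftover, callee_synced = summaries[arg]
            if any(overlaps(t, p) for t in touched for p in pending):
                need.append(i)   # barrier before op i
                pending = []
                synced = True
            if not synced:
                entry.extend(touched)
            if callee_synced:
                # The callee's own barrier already covered earlier accesses.
                pending = list(leftover)
                synced = True
            else:
                pending.extend(leftover)
        summaries[fn] = (entry, pending, synced)
        barriers[fn] = need
    return barriers, summaries


bodies = {
    "producer": [("access", (0, 512))],
    "kernel": [("call", "producer"), ("access", (256, 512)), ("access", (1024, 128))],
}
barriers, summaries = analyze(bodies, ["producer", "kernel"])
```

In this example `producer` leaves the interval `(0, 512)` unsynchronized, so `kernel` needs a barrier before its access to `(256, 512)`, which overlaps it, but not before the access to `(1024, 128)`.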

### AxisInfo

<div style="text-align: center;">
<img
src="https://user-images.githubusercontent.com/2306281/234665183-790a11ac-0ba1-47e1-98b1-e356220405a3.png"
alt="figure 4" width="400" height="auto">
</div>

The AxisInfo analysis operates differently from both membar and
allocation, as it traverses the call graph in topological order, so that
callers are analyzed before their callees. This is necessary because
function arguments may carry axis information that the callee relies on.
As we do not implement optimizations like function cloning, each
function has a single body, and the axis information for an argument is
the conservative combination of the axis information passed at all of
its call sites. For example, the divisibility of an argument is the
greatest common divisor of the values passed by the callers.
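
A small sketch of the merge for the divisibility component, assuming single-argument functions. The helper name `propagate_divisibility` and the call-site encoding are hypothetical; the numbers follow the `@addptr_hints` and `@mul` tests added by this commit, whose comments state that the alignment info is the gcd of all call sites.

```python
from math import gcd


def propagate_divisibility(root_hints, call_sites, topo_order):
    """Merge argument divisibility over all call sites, callers first.

    call_sites maps a caller to (callee, passed) pairs, where passed is the
    divisibility of the value handed to the callee, or None when the caller
    forwards its own argument unchanged. Every root needs an entry in
    root_hints.
    """
    arg_div = dict(root_hints)
    for fn in topo_order:
        for callee, passed in call_sites.get(fn, []):
            value = arg_div[fn] if passed is None else passed
            # gcd(0, v) == v, so 0 acts as "no call site seen yet".
            arg_div[callee] = gcd(arg_div.get(callee, 0), value)
    return arg_div


# call_graph passes a multiple of 4 to foo and a multiple of 8 to bar;
# both forward their argument to mul.
mul_case = propagate_divisibility(
    {"call_graph": 1},
    {"call_graph": [("foo", 4), ("bar", 8)],
     "foo": [("mul", None)],
     "bar": [("mul", None)]},
    ["call_graph", "foo", "bar", "mul"],
)

# Three kernels with different pointer hints all call addptr_hints.
addptr_case = propagate_divisibility(
    {"kernel_div16": 16, "kernel_div8": 8, "kernel_div4": 4},
    {"kernel_div16": [("addptr_hints", None)],
     "kernel_div8": [("addptr_hints", None)],
     "kernel_div4": [("addptr_hints", None)]},
    ["kernel_div16", "kernel_div8", "kernel_div4", "addptr_hints"],
)
```

The topological order matters because `foo` and `bar` must have their final argument information before they forward it to `mul`.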

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
Keren Zhou
2023-04-28 17:59:04 -04:00
committed by GitHub
parent 65fb36e34e
commit ee864048b3
42 changed files with 1473 additions and 434 deletions

@@ -402,8 +402,6 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
// -----
module {
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
// CHECK-LABEL: @store_constant_align
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
@@ -433,8 +431,6 @@ tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
tt.return
}
}
// -----
// This IR is dumped from vecadd test.
@@ -491,3 +487,88 @@ tt.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
tt.store %15, %13, %10 : tensor<64xf32>
tt.return
}
// -----
module {
// We don't use function cloning here, so the alignment info is the gcd of all call sites.
// CHECK-LABEL: @addptr_hints
tt.func @addptr_hints(%arg0: !tt.ptr<i32>) {
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
%cst1 = arith.constant 1 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%1 = tt.addptr %arg0, %cst1 : !tt.ptr<i32>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 4
%cst4 = arith.constant 4 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%2 = tt.addptr %arg0, %cst4 : !tt.ptr<i32>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16
%cst16 = arith.constant 16 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%3 = tt.addptr %arg0, %cst4 : !tt.ptr<i32>, i32
tt.return
}
// CHECK-LABEL: @kernel_div16
tt.func @kernel_div16(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.call @addptr_hints(%arg0) : (!tt.ptr<i32>) -> ()
tt.return
}
// CHECK-LABEL: @kernel_div8
tt.func @kernel_div8(%arg0: !tt.ptr<i32> {tt.divisibility = 8 : i32}) {
tt.call @addptr_hints(%arg0) : (!tt.ptr<i32>) -> ()
tt.return
}
// CHECK-LABEL: @kernel_div4
tt.func @kernel_div4(%arg0: !tt.ptr<i32> {tt.divisibility = 4 : i32}) {
tt.call @addptr_hints(%arg0) : (!tt.ptr<i32>) -> ()
tt.return
}
}
// -----
module {
// We don't use function cloning here, so the alignment info is the gcd of all call sites.
// CHECK-LABEL: @mul
tt.func @mul(%arg0: i32) {
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
%cst1 = arith.constant 1 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%1 = arith.muli %arg0, %cst1 : i32
tt.return
}
// CHECK-LABEL: @bar
tt.func @bar(%arg0: i32) {
tt.call @mul(%arg0) : (i32) -> ()
tt.return
}
// CHECK-LABEL: @foo
tt.func @foo(%arg0: i32) {
tt.call @mul(%arg0) : (i32) -> ()
tt.return
}
// CHECK-LABEL: @call_graph
tt.func @call_graph(%arg0: i32) {
// CHECK: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 12
%cst12 = arith.constant 12 : i32
// CHECK: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%0 = arith.muli %arg0, %cst12 : i32
tt.call @foo(%0) : (i32) -> ()
// CHECK: contiguity = [1], divisibility = [8], constancy = [1], constant_value = 8
%cst8 = arith.constant 8 : i32
// CHECK: contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>
%1 = arith.muli %arg0, %cst8 : i32
tt.call @bar(%1) : (i32) -> ()
tt.return
}
}

@@ -28,10 +28,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-// CHECK: offset = 0, size = 4608
+// CHECK: scratch offset = 0, size = 4608
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
-// CHECK-NEXT: offset = 0, size = 4224
+// CHECK-NEXT: scratch offset = 0, size = 4224
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
@@ -56,17 +56,17 @@ tt.func @reusable(%A : !tt.ptr<f16>) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-// CHECK-NEXT: offset = 0, size = 4608
+// CHECK-NEXT: scratch offset = 0, size = 4608
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
%a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
-// CHECK-NEXT: offset = 0, size = 1152
+// CHECK-NEXT: scratch offset = 0, size = 1152
%a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT>
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-// CHECK-NEXT: offset = 0, size = 4608
+// CHECK-NEXT: scratch offset = 0, size = 4608
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
-// CHECK-NEXT: offset = 0, size = 1152
+// CHECK-NEXT: scratch offset = 0, size = 1152
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
tt.return
@@ -396,3 +396,127 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
}
}
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: alloc1
tt.func @alloc1(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
tt.return
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: alloc2
tt.func @alloc2(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 1024
%cst0 = triton_gpu.alloc_tensor : tensor<32x16xf16, #A_SHARED>
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: alloc3
tt.func @alloc3(%cond : i1) {
scf.if %cond {
// CHECK: offset = 0, size = 512
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
} else {
// CHECK-NEXT: offset = 0, size = 1024
%cst0 = triton_gpu.alloc_tensor : tensor<16x32xf16, #A_SHARED>
}
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: alloc4
tt.func @alloc4(%A : !tt.ptr<f16>, %cond : i1) {
scf.if %cond {
// CHECK: virtual offset = 0, size = 1024
tt.call @alloc3(%cond) : (i1) -> ()
} else {
// CHECK-NEXT: virtual offset = 0, size = 512
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: single_call
tt.func @single_call(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK-NEXT: virtual offset = 0, size = 512
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
tt.return
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: multiple_calls
tt.func @multiple_calls(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: virtual offset = 0, size = 512
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK-NEXT: virtual offset = 0, size = 1024
tt.call @alloc2(%A) : (!tt.ptr<f16>) -> ()
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: if_else_calls
tt.func @if_else_calls(%A : !tt.ptr<f16>, %cond : i1) {
scf.if %cond {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 0, size = 1024
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
// CHECK-NEXT: virtual offset = 0, size = 512
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
} else {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK-NEXT: virtual offset = 0, size = 1024
tt.call @alloc2(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: for_calls
tt.func @for_calls(%A : !tt.ptr<f16>, %cond : i1) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
%lb = arith.constant 0 : index
%ub = arith.constant 10 : index
%step = arith.constant 1 : index
scf.for %iv = %lb to %ub step %step {
// CHECK-NEXT: virtual offset = 0, size = 512
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: call_graph_1
tt.func @call_graph_1(%A : !tt.ptr<f16>, %cond : i1) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: virtual offset = 0, size = 1024
tt.call @alloc3(%cond) : (i1) -> ()
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: call_graph_2
tt.func @call_graph_2(%A : !tt.ptr<f16>, %cond : i1) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: virtual offset = 0, size = 1024
tt.call @alloc4(%A, %cond) : (!tt.ptr<f16>, i1) -> ()
tt.return
// CHECK-NEXT: size = 1024
}
}

@@ -503,3 +503,136 @@ tt.func @cf_if_else_return(%i1 : i1) {
}
}
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: convert_layout1
tt.func @convert_layout1(%A : !tt.ptr<f16>) {
// CHECK-NOT: gpu.barrier
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
tt.return
}
// CHECK-LABEL: convert_layout2
tt.func @convert_layout2(%A : !tt.ptr<f16>) {
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
%1 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
%2 = tt.cat %1, %1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: triton_gpu.convert_layout
// CHECK-NEXT: gpu.barrier
%3 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
%4 = tt.cat %2, %2 {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #AL>
tt.return
}
// CHECK-LABEL: convert_layout3
tt.func @convert_layout3(%cond : i1) {
scf.if %cond {
%0 = triton_gpu.alloc_tensor : tensor<16x64xf16, #A_SHARED>
// CHECK: triton_gpu.convert_layout
// CHECK-NOT: gpu.barrier
%1 = triton_gpu.convert_layout %0 : (tensor<16x64xf16, #A_SHARED>) -> tensor<16x64xf16, #AL>
} else {
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
// CHECK: triton_gpu.convert_layout
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
%2 = triton_gpu.convert_layout %1 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
}
tt.return
}
// CHECK-LABEL: convert_layout4
tt.func @convert_layout4(%A : !tt.ptr<f16>, %cond : i1) {
// CHECK-NOT: gpu.barrier
scf.if %cond {
tt.call @convert_layout3(%cond) : (i1) -> ()
} else {
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
}
// CHECK-LABEL: single_call_sync
tt.func @single_call_sync(%A : !tt.ptr<f16>) {
%0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK: tt.call
// CHECK-NEXT: gpu.barrier
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
%1 = triton_gpu.convert_layout %0 : (tensor<16x32xf16, #AL>) -> tensor<16x32xf16, #BL>
tt.return
}
// CHECK-LABEL: single_call_no_sync
// %1 can reuse %0 in convert_layout2, which has been synced
tt.func @single_call_no_sync(%A : !tt.ptr<f16>) {
// CHECK-NOT: gpu.barrier
%0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #BL>
tt.return
}
// CHECK-LABEL: multiple_calls
tt.func @multiple_calls(%A : !tt.ptr<f16>) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
tt.return
}
// CHECK-LABEL: if_else_calls
tt.func @if_else_calls(%A : !tt.ptr<f16>, %cond : i1) {
scf.if %cond {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.call
// CHECK-NEXT: gpu.barrier
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
} else {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK: tt.call
// CHECK-NOT: gpu.barrier
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
}
// CHECK-LABEL: for_calls
tt.func @for_calls(%A : !tt.ptr<f16>, %cond : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
%lb = arith.constant 0 : index
%ub = arith.constant 10 : index
%step = arith.constant 1 : index
scf.for %iv = %lb to %ub step %step {
// CHECK: gpu.barrier
// CHECK-NEXT: tt.call
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
}
// CHECK-LABEL: call_graph_1
tt.func @call_graph_1(%A : !tt.ptr<f16>, %cond : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.call
tt.call @convert_layout3(%cond) : (i1) -> ()
tt.return
}
// CHECK-LABEL: call_graph_2
tt.func @call_graph_2(%A : !tt.ptr<f16>, %cond : i1) {
tt.call @convert_layout4(%A, %cond) : (!tt.ptr<f16>, i1) -> ()
// CHECK: tt.call
// CHECK-NEXT: gpu.barrier
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
tt.return
}
}