Commit Graph

355 Commits

Michael Melesse
275fead8e3 fix lit test 2023-05-12 15:37:08 -05:00
Ingo Müller
b2a757d000 [BUILD] Add missing CMake link-time dependencies. (#1654) 2023-05-11 19:17:44 -07:00
Zahi Moudallal
9d7980fa3b [BACKEND] Updated predicate for atomic ops (#1619) 2023-05-05 16:45:11 -07:00
Zahi Moudallal
e2ae2c6c48 [BACKEND] Modified store op thread masking (#1605) 2023-05-04 17:15:05 -07:00
Michael Melesse
13facab95f fix lit tests
This is a combination of 3 commits.

fix build and test errors

fix lit test error

fix lit tests
2023-05-01 12:48:20 -05:00
Philippe Tillet
39e751bd4e [OPTIMIZER][BACKEND] Enabled elementwise ops (including casts) between ldmatrix and mma.sync (#1595) 2023-04-30 18:05:26 -07:00
Keren Zhou
ee864048b3 [FRONTEND][BACKEND] Add the noinline annotation for triton.jit (#1568)
# Introducing the `noinline` Parameter for the Triton `jit` Decorator

We're excited to introduce a new parameter, `noinline`, that can be
added to the `jit` decorator in Triton. This parameter allows developers
to specify that a particular Triton function should not be inlined into
its callers. In this post, we'll dive into the syntax, purpose, and
implementation details of this new feature.

## Syntax

To use the `noinline` parameter, simply add `noinline=True` to the `jit`
decorator for the function that you don't want to be inlined. Here's an
example:

```python
import triton
import triton.language as tl

@triton.jit(noinline=True)
def device_fn(x, y, Z):
    z = x + y
    tl.store(Z, z)

def test_noinline():
    @triton.jit
    def kernel(X, Y, Z):
        x = tl.load(X)
        y = tl.load(Y)
        device_fn(x, y, Z)
```

In this example, the `device_fn` function is decorated with
`@triton.jit(noinline=True)`, indicating that it should not be inlined
into its caller, `kernel`.

## Purpose

The `noinline` parameter serves several key purposes:

- Reducing code size: by preventing inlining, a function called from
several sites is compiled once instead of duplicated at each call site
(see the sketch after this list).
- Facilitating debugging: Keeping functions separate can make it easier
to debug the code.
- Avoiding common subexpression elimination (CSE) in certain cases:
keeping a function out of its callers prevents CSE from merging
expressions across the boundary, which can reduce register pressure.
- Enabling dynamic linking: This parameter makes it possible to
dynamically link Triton functions.
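
For illustration, here is a minimal sketch of the code-size effect, using a hypothetical helper `_scale_store`: with `noinline=True` its body is compiled once and called from both sites rather than duplicated.

```python
import triton
import triton.language as tl

# compiled once, called twice (illustrative sketch)
@triton.jit(noinline=True)
def _scale_store(x, s, Out):
    tl.store(Out, x * s)

@triton.jit
def kernel(X, Out1, Out2):
    x = tl.load(X)
    _scale_store(x, 2.0, Out1)
    _scale_store(x, 3.0, Out2)
```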

## Implementation

The implementation of the `noinline` parameter involves significant
changes to three analysis modules in Triton: *Allocation*, *Membar*, and
*AxisInfo*. Prior to this update, these modules assumed that all Triton
functions had been inlined into the root kernel function. With the
introduction of non-inlined functions, we've had to rework these
assumptions and make corresponding changes to the analyses.

### Call Graph and Limitations

![figure 1](https://user-images.githubusercontent.com/2306281/234663904-12864247-3412-4405-987b-6991cdf053bb.png)

To support this, we build a call graph and perform all the
analyses on the call graph instead of a single function. The call graph
is constructed by traversing the call edges and storing them in an edge
map; the roots are the nodes with no incoming edges.
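
As a rough illustration, the construction amounts to the following Python sketch (toy data structures; the actual implementation lives in the C++ analyses):

```python
from collections import defaultdict

def build_call_graph(functions):
    """`functions` maps each function to the callees appearing in its
    body (toy input). Returns the edge map and the root functions."""
    edges = defaultdict(list)  # caller -> callees
    has_incoming = set()       # every function that appears as a callee
    for caller, callees in functions.items():
        for callee in callees:
            edges[caller].append(callee)
            has_incoming.add(callee)
    # roots are the nodes with no incoming call edges
    roots = [f for f in functions if f not in has_incoming]
    return edges, roots

# Example: kernel calls B and C; B also calls C
edges, roots = build_call_graph({"kernel": ["B", "C"], "B": ["C"], "C": []})
assert roots == ["kernel"]
```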

The call graph has certain limitations:

- It does not support recursive function calls, although this could be
implemented in the future.
- It does not support dynamic function calls, where the function name is
unknown at compilation time.

### Allocation

![figure 2](https://user-images.githubusercontent.com/2306281/234665110-bf6a2660-06fb-4648-85dc-16429439e72d.png)

In Triton, shared memory allocation is achieved through two operations:
`triton_gpu.convert_layout` and `triton_gpu.alloc_tensor`. The
`convert_layout` operation allocates an internal tensor, which we refer
to as a *scratch* buffer, while the `alloc_tensor` operation returns an
allocated tensor and is thus known as an *explicit* buffer.

To accommodate function calls, we introduce a third type of buffer,
called a *virtual* buffer. Similar to scratch
buffers, virtual buffers are allocated internally within the scope of a
function call, and the buffers allocated by the called functions remain
invisible to subsequent operations in the calling function. However,
virtual buffers are distinct from scratch buffers in that the call
operation itself does not allocate memory—instead, it specifies the
total amount of memory required by all the child functions being called.
The actual allocation of buffers is performed by individual operations
within these child functions. For example, when a call is made along
edge e1, no memory is allocated, but the total amount of memory needed
by function B is reserved. Notably, the amount of shared memory used by
function B remains fixed across its call sites, since each function's
requirement must account for any dynamic control flow it contains.

An additional challenge to address is the calculation of shared memory
offsets for functions within a call graph. While we can assume a shared
memory offset starting at 0 for a single root function, this is not the
case with a call graph, where we must determine each function's starting
offset based on the call path. Although each function has a fixed memory
consumption, the starting offset may vary. For instance, in Figure 2,
the starting offset of function C through edges e1->e2 differs from that
through edges e2->e4. To handle this, we accumulate the starting offset
at each call site and pass it as an argument to the called function.
Additionally, we amend both the function declaration and call sites by
appending an offset variable.
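
The offset-accumulation scheme can be sketched as follows (a toy Python model with hypothetical names; the real pass rewrites MLIR functions and call sites):

```python
def assign_offsets(call_graph, local_sizes, root="kernel"):
    """`call_graph` maps a function to its (edge, callee) pairs and
    `local_sizes` maps a function to the shared memory its own ops use
    (toy inputs). Each call site receives the accumulated starting
    offset, which would be appended as an argument to the callee."""
    offsets = {}

    def visit(fn, start):
        # the callee's buffers begin after this function's own usage;
        # sibling calls may reuse the same region (a simplification)
        child_start = start + local_sizes[fn]
        for edge, callee in call_graph.get(fn, []):
            offsets[edge] = child_start
            visit(callee, child_start)

    visit(root, 0)
    return offsets

# kernel (128B) calls B (64B) via e1; B calls C via e2
calls = {"kernel": [("e1", "B")], "B": [("e2", "C")], "C": []}
print(assign_offsets(calls, {"kernel": 128, "B": 64, "C": 32}))
# -> {'e1': 128, 'e2': 192}
```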

### Membar

![figure 3](https://user-images.githubusercontent.com/2306281/234665157-844dd66f-5028-4ef3-bca2-4ca74b8f969d.png)

The membar pass depends on the allocation analysis. Once the offset
and size of each buffer are known, we conduct a post-order traversal of
the call graph and analyze each function individually. Unlike
previous analyses, we now return buffers that remain unsynchronized at
the end of functions, allowing the calling function to perform
synchronization in cases of overlap.
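
In spirit, the traversal looks like this toy Python sketch (barrier placement itself is elided; names are hypothetical):

```python
def membar_analysis(call_graph, local_unsynced):
    """Post-order walk over the call graph. `local_unsynced` maps a
    function to the buffers its own body leaves unsynchronized (toy
    input). Each function reports what is still unsynchronized at its
    end so callers can insert a barrier on overlap."""
    reported = {}

    def visit(fn):
        if fn in reported:
            return reported[fn]
        unsynced = set(local_unsynced.get(fn, ()))
        for callee in call_graph.get(fn, []):
            # callees are analyzed first; what they leave open
            # propagates into the caller
            unsynced |= visit(callee)
        reported[fn] = unsynced
        return unsynced

    for fn in call_graph:
        visit(fn)
    return reported

# B leaves buffer "b0" unsynchronized; kernel must account for it
print(membar_analysis({"kernel": ["B"], "B": []}, {"B": {"b0"}}))
```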

### AxisInfo

![figure 4](https://user-images.githubusercontent.com/2306281/234665183-790a11ac-0ba1-47e1-98b1-e356220405a3.png)

The AxisInfo analysis operates differently from both membar and
allocation, as it traverses the call graph in topological order. This is
necessary because function arguments may contain axis information that
will be used by callee functions. Since we do not implement
optimizations like function cloning, each function has a single copy of
code, and the axis information of an argument is the conservative
combination of the axis information passed in by all of its callers.
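
A toy sketch of this conservative combination, using divisibility as the example lattice (the gcd of all incoming values; names are hypothetical):

```python
from math import gcd

def propagate_divisibility(call_sites, order):
    """`order` lists functions with callers before callees (topological
    order); `call_sites` maps a callee to the per-argument divisibility
    each call site passes in (toy inputs). With a single copy of code
    per function, an argument's divisibility is the gcd over all sites."""
    arg_div = {}
    for fn in order:
        incoming = call_sites.get(fn, [])
        if incoming:
            # combine per-argument info across every call site
            arg_div[fn] = [gcd(*vals) for vals in zip(*incoming)]
    return arg_div

# one caller passes a pointer divisible by 16, another by 4:
# the callee must assume divisibility gcd(16, 4) == 4
print(propagate_divisibility({"device_fn": [[16], [4]]}, ["kernel", "device_fn"]))
```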

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-04-28 14:59:04 -07:00
Zahi Moudallal
65fb36e34e [BACKEND] Updated slice layout semantics, updated vectorization logic used for load/store ops. (#1587) 2023-04-28 13:40:01 -07:00
Philippe Tillet
8f47bdcc92 [OPTIMIZER] Added kWidth attribute to DotOperandEncoding (#1584)
This is a prerequisite for efficient mixed-precision matmul
2023-04-26 23:03:18 -07:00
Michael Melesse
2784b804d9 Merge remote-tracking branch 'upstream/main' into ifu_4_26_2023 2023-04-26 12:04:21 -05:00
Michael Melesse
705d47d0dd fix lit test issues
This is a combination of 6 commits.

install lit

fix lit test

fix lit test

fix aot lit issues

fix final lit tests

add lit tests
2023-04-17 11:46:37 -05:00
peterbell10
e152183570 [FRONTEND][BACKEND] ReduceOp to support arbitrary reduce operations (#1305)
Fixes #1285

This changes `tt.reduce` to replace `redOp` by a region containing
arbitrary code. For example, `tl.sum` is now lowered as:
```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```
Support for index reductions at the MLIR level is also dropped in favor
of simultaneous reductions over multiple tensors, which generalizes the
code without loss of performance. For example, `argmin` is now lowered
as:
```mlir
  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
  %8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
  %9:2 = "tt.reduce"(%6, %8) ({
  ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
    %14 = arith.cmpf olt, %arg4, %arg6 : f32
    %15 = arith.cmpf ogt, %arg4, %arg6 : f32
    %16 = arith.cmpi slt, %arg5, %arg7 : i32
    %17 = arith.select %16, %arg5, %arg7 : i32
    %18 = arith.select %15, %arg7, %17 : i32
    %19 = arith.select %14, %arg5, %18 : i32
    %20 = arith.cmpf olt, %arg4, %arg6 : f32
    %21 = arith.select %20, %arg4, %arg6 : f32
    tt.reduce.return %21, %19 : f32, i32
  }) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
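
At the frontend level, this enables custom reductions by passing a `@triton.jit` combine function to `tl.reduce`; a minimal sketch (kernel and names are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def _combine_max(a, b):
    # becomes the body of the tt.reduce region shown above
    return tl.maximum(a, b)

@triton.jit
def max_kernel(X, Out, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    result = tl.reduce(x, 0, _combine_max)
    tl.store(Out, result)
```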
2023-04-13 01:37:39 +00:00
Keren Zhou
6d0ed41307 [BACKEND] Replace Func Dialect with custom triton ops (func, call, return) (#1502)
MLIR currently only supports a custom inlining interface per dialect, so
we cannot change the inlining decision of `func.func`.


https://discourse.llvm.org/t/avoid-inlining-some-functions-using-the-func-dialect/69830/3

We could revert this once a better inliner interface has been designed.

Inlining attributes will be implemented in the next PR since this PR is
already huge.
2023-04-10 21:08:40 -07:00
Philippe Tillet
640f3c3921 [OPTIMIZER] Tweaked layout removal conversion heuristics (#1501)
Loads are now considered cheap to rematerialize when there are more
threads than elements in the tensor.
2023-04-10 15:19:08 -07:00
Keren Zhou
032509384a [ANALYSIS] Fine-tune comments for shared memory allocation (#1492)
Also adds a new test to check multiple-color cases, which have never
been tested before.
2023-04-10 09:00:36 -07:00
Philippe Tillet
adc760dac1 [OPTIMIZER] enable loop pipelining using pointer increments from vector look-up tables (#1490) 2023-04-10 08:59:42 -07:00
Philippe Tillet
b86425a28e [TEST] made lut_bmm pipeline test more concise and specific (#1488) 2023-04-08 19:17:35 -07:00
Rahul Batra
a27b388df5 Merge remote-tracking branch 'upstream/main' into IFU_04-06-2023 2023-04-06 16:18:31 -05:00
Keren Zhou
801bb9d3b5 [ANALYSIS] Fix divisibility calculation for addptr (#1453) 2023-03-31 17:57:31 -07:00
Chenggang Zhao
72b071253e [FRONTEND] Support block pointer semantics (#1392)
This PR introduces a new semantics: **block pointer**, which makes it
easier and faster for users to load a block from a parent tensor.

Below is a detailed example of the API change:
```
# Make a block pointer, which points to a block in the parent tensor
# `base`: the parent tensor
# `shape`: the shape of the parent tensor
# `strides`: the strides of the parent tensor
# `offsets`: the offsets of the block in the parent tensor
# `block_shape`: the shape of the block
# `order`: the order of the data arrangement in memory
# Below is an example loading a 2D column-major matrix
block_ptr = tl.make_block_ptr(base=ptr, shape=(M, N), strides=(stride_m, stride_n), offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))

# Advance the offsets; note that the striding information is already saved in `block_ptr`
# `base`: the block pointer to be advanced
# `offsets`: the offsets for each dimension
block_ptr = tl.advance(base=block_ptr, offsets=(BLOCK_M, -BLOCK_N))
block_ptr = tl.advance(base=block_ptr, offsets=(-BLOCK_M, BLOCK_N))

# Load from a block pointer, the output type is the dereferenced type of `block_ptr`, e.g. ptr<tensor<32x32xf32>> -> tensor<32x32xf32>
# `ptr`: the block pointer to be loaded
# `boundary_check`: a tuple of dimensions to check the boundary
# `padding`: padding strategy for elements out of bound
val = tl.load(ptr=block_ptr, boundary_check=(0, 1), padding="zero")

# Store by a block pointer, in which the pointer and the value tensor should have the same shape
# `ptr`: the block pointer to be stored
# `boundary_check`: a tuple of dimensions to check the boundary (no-write if out of bound)
tl.store(ptr=block_ptr, value=val, boundary_check=(0, 1))
```
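
For a sense of typical use, the sketch below sweeps a block pointer across the K dimension of a 2D tensor, advancing it each iteration (a hypothetical kernel; names and parameters are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def sum_rows_kernel(A, Out, M, K, stride_am, stride_ak,
                    BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    a_block_ptr = tl.make_block_ptr(base=A, shape=(M, K),
                                    strides=(stride_am, stride_ak),
                                    offsets=(0, 0),
                                    block_shape=(BLOCK_M, BLOCK_K),
                                    order=(1, 0))
    acc = tl.zeros([BLOCK_M], dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_block_ptr, boundary_check=(0, 1), padding="zero")
        acc += tl.sum(a, axis=1)
        # strides are carried by the block pointer; only offsets advance
        a_block_ptr = tl.advance(a_block_ptr, offsets=(0, BLOCK_K))
    tl.store(Out + tl.arange(0, BLOCK_M), acc)
```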

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-03-27 16:46:49 -07:00
Xuehai Pan
5b36cb48ad [CI][TEST] update pre-commit hooks and use pre-commit for style tests in CI (#1409)
Ref issue:

- #1408

Changes:

- Add `.editorconfig`
- Add `pre-commit-hooks`:

    ```yaml
    - repo: https://github.com/pre-commit/pre-commit-hooks
      rev: v4.4.0
      hooks:
        - id: check-symlinks
        - id: destroyed-symlinks
        - id: trailing-whitespace
        - id: end-of-file-fixer
        - id: check-yaml
        - id: check-toml
        - id: check-ast
        - id: check-added-large-files
        - id: check-merge-conflict
        - id: check-executables-have-shebangs
        - id: check-shebang-scripts-are-executable
        - id: detect-private-key
        - id: debug-statements
    ```
- Add `flake8` to `pre-commit` config and add `.flake8` file
- Use `pre-commit` for style tests in CI
- Run `pre-commit` and fix existing violations:
    - fix trailing spaces
    - fix end-of-files
    - fix file modes with `chmod -x`
    - run `autopep8` on existing code
    - fix `flake8` violations
2023-03-25 14:52:16 -07:00
Keren Zhou
b7762bee2c [TEST] Cleanup SCF dialect in tests (#1402) 2023-03-24 09:21:40 -07:00
Keren Zhou
c9f47d9094 [BACKEND] Init values before load to avoid ptxas issues (#1396) 2023-03-23 17:24:03 -07:00
Keren Zhou
2ba77a9212 [OPTIMIZER] Fix a typo in SimplifyReduceCvt (#1385) 2023-03-21 22:45:58 -07:00
Keren Zhou
23fc647a3e [OPTIMIZER] Fix optimizer hanging caused by SimplifyReduceCvt (#1377)
https://github.com/openai/triton/issues/1328

Match the convert_layout operation in SimplifyReduceCvt
(convert_layout->reduce). This way we don't miss higher priority rewrite
patterns like RematerializeBackward and SimplifyConversion. We also need
to set SimplifyConversion's benefit = 4, RematerializeBackward's benefit
= 3, and RematerializeForward's benefit = 2.
2023-03-20 16:20:19 -07:00
Keren Zhou
78d5900467 [OPTIMIZER] Improve pipeline to handle general indirect access to matrices (#1291)
Differentiate between immediate and non-immediate block arguments. 
If we have a load that immediately depends on a block argument in the
current iteration, it is an immediate dependency. Otherwise, it is a
non-immediate dependency, which means the load depends on a block
argument in the previous iterations.

For example:
```
scf.for (%arg0, %arg1, %arg2) {
%0 = load %arg0  <--- immediate dep, this address is initialized at numStages-2
%1 = load %arg1
%2 = add %1, %arg2
%3 = load %2  <--- non-immediate dep, %arg1 must be an up-to-date value
}
```

The above code pattern is commonly seen in cases where we have indirect
memory accesses using a lookup table, such as PyTorch's `bsr_dense_bmm`.
This PR improves `bsr_dense_bmm` by about 20% on the unit test cases.
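
In Triton source, the pattern corresponds to a load whose address comes from another load; a hedged sketch (names are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def gather_kernel(X, Idx, Out, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # reading the lookup table: an immediate dependency
    idx = tl.load(Idx + offs)
    # addressed by the previous load's result: a non-immediate
    # dependency that the improved pipeliner can now handle
    x = tl.load(X + idx)
    tl.store(Out + offs, x)
```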
2023-03-20 14:39:47 -04:00
Rohit Santhanam
a84b4883e6 Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03192023 2023-03-19 13:46:50 +00:00
peterbell10
c9740f0870 [OPTIMIZER] Add canonicalize/fold for ExpandDimsOp, ViewOp and BroadcastOp (#1354)
These eliminate no-op reshapes and simplify some combinations of view
ops into a single view, e.g. viewing a splat becomes a single splat.
2023-03-16 21:13:58 -07:00
Berke Kocaoğlu
ba91f39dbf [DOC] Fix syntax errors, typos, formatting; increase consistency (#1357)
This PR:
- Fixes syntax errors like `.type values: dict[str,
Callable[[list[Any]], Any]]` to `:type values: dict[str,
Callable[[list[Any]], Any]]`,
- Fixes typos,
- Fixes formatting like `k ++` to ` k++`,
- Increases consistency (e.g. by transforming the minority `cd dir/` to
the majority `cd dir`).
2023-03-16 15:32:02 -07:00
Rohit Santhanam
77dcb667b5 Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03162023 2023-03-16 13:21:15 +00:00
Philippe Tillet
082828af47 [OPTIMIZER] Fixed up divisibility analysis in div operation (#1341) 2023-03-14 18:17:05 -07:00
Rohit Santhanam
6ff54b495c Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03122023 2023-03-13 18:09:12 +00:00
Philippe Tillet
9b7c65a3a9 [BACKEND][OPTIMIZER] Refactor MMAv1 codegen (#1322)
- Significant simplification of the optimizer pipeline. The right MMA
version is now set directly after the coalescing pass. DotOperand layouts
no longer hold state for the `isRow` argument, and instead query it from
their parent
- Moved a bunch of things from TritonGPUToLLVM/DotOpHelpers to
TritonGPUAttrDefs. All MMAv1 state is now queried from attributes.
- Logic for getElemsPerThread is no longer duplicated in TypeConverter
2023-03-12 19:54:38 -07:00
Philippe Tillet
53c216f506 [TESTS] Now using downloaded FileCheck directly without copying it (#1293) 2023-03-07 09:11:08 -08:00
Rohit Santhanam
6b35506291 Some warp64 related fixes.
Reduce the num_warps for layernorm and softmax.

Fix some lit unit tests.
2023-03-07 16:15:04 +00:00
rsanthanam-amd
70c1ffa77e Merge pull request #153 from B1tway/wave64-support
Adding wave64 support to Triton
2023-03-07 09:33:51 -06:00
rsanthanam-amd
8d42b7e8e2 Merge pull request #146 from ROCmSoftwarePlatform/fix_for_non_vectorized_float16_loads_stores
Fix for LLVM compiler's inability to auto-vectorize float16 loads and stores on AMDGPU.
2023-03-07 08:24:01 -06:00
Rohit Santhanam
38710d6623 Fix LIT tests. 2023-03-07 13:25:12 +00:00
Philippe Tillet
3db55c5f94 [OPTIMIZER][BACKEND] Some backend and optimization passes clean-up (#1284)
* Cleaned up pipeline pass. Now works when there are element-wise ops
between the load and the dot
* Made `splat` compatible with variables that have DotOperandLayout
* Moved rematerialization utils to a separate Transforms/Utility.cpp file.
2023-03-06 17:17:59 -08:00
Keren Zhou
73d55eb59c [OPTIMIZER] Check if the definition op is NULL in GPU combine passes (#1288) 2023-03-06 16:19:13 -08:00
Keren Zhou
4731f300d3 [BACKEND] Mask out wrapped threads in store ops (#1283) 2023-03-06 14:50:20 -08:00
B1tway
a931a50719 Fix offline tests 2023-03-06 20:12:19 +00:00
Rohit Santhanam
dadc09623b Replace hard coded ROCM paths with ROCM_PATH env var. 2023-03-06 03:20:38 +00:00
Philippe Tillet
ff94e34430 [TESTS][BUILD] now using llvm @ 8e5a41e8271f (#1282)
Now we also use the FileCheck utility packaged with llvm pre-built binaries
2023-03-05 17:23:00 -08:00
Keren Zhou
d54745538b [BACKEND][CI] Disable most backward rematerialization through load/store (#1260) 2023-03-02 09:45:50 -08:00
Philippe Tillet
fa0fbc937f [FRONTEND][BACKEND][OPTIMIZER] Loops now use 64-bit indices when necessary (#1261)
* Frontend:
  - `int` kernel arguments are always signed
- Loop induction variable is now determined by integer promotion on
lb/ub/step
* Optimizer:
  - Added new ExtractSliceOp that enforces 32-bit offsets
* Backend:
    - Use 64-bit indices when lowering functions and control flow
    - Removed `idx_val` macro and replaced it with `i32_val`
    - Cleaned up comments
- Added a new ArithToIndex pass to make sure operations on indices are
done with the `index` dialect, which is then converted to LLVM
separately using a 64-bit target
2023-03-01 23:09:48 -08:00
Rohit Santhanam
cd9ae1cd36 Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-02232023 2023-02-23 21:41:54 +00:00
Alexander Efimov
c3f24143d2 fix basic_load atomic_add_f32 tests 2023-02-23 12:03:56 +01:00
Alexander Efimov
33667cd106 Specialize checks for async slice tests:
- basic_insert_slice_async_v4
- basic_insert_slice_async_v1
- basic_insert_slice_async_v1_multictas
2023-02-23 12:03:56 +01:00
Alexander Efimov
9ad7fec871 [Test][LIT] Fix Conversion/tritongpu_to_llvm.mlir crash
This PR disables the following subtests because they are PTX-specific:
- basic_async_wait
- convert_dot
- matmul_kernel_dot_operand_layout
- matmul884_kernel_dot_operand_layout
- matmul_tf32dot
2023-02-23 12:03:56 +01:00