Commit Graph

229 Commits

Author SHA1 Message Date
Philippe Tillet
4c0e3d907e [TOOLS] improved ahead-of-time compiler (#1805)
This is a revival of @gaxler initial ahead-of-time compiler proposal.
Code was simplified and some constraints were relaxed (i.e., we now
execute the entire file provided vs just the kernel AST) to promote
maintainability. A basic unit test was added, though it does not test
specialization right now.

co-authored by: Gregory Axler, thexler <g.axler@gmail.com>
2023-06-21 01:02:58 -07:00
Keren Zhou
1851c8ca99 [FRONTEND] Fix binary compare op on constexprs (#1801)
Example:

```
if static_a == 0 and static_b == 1:
    ...
```

The return value of `static_a == 0` should be `constexpr(True)` but not
`True`, otherwise the bool object (True/False) doesn't have the
`logical_and` method.
2023-06-18 20:27:56 -07:00
Christopher Hesse
981e98a213 [FRONTEND] update assert_helper.py (#1789) 2023-06-15 16:24:30 -07:00
Philippe Tillet
9a2580de13 [CI] Added H100 node (#1779) 2023-06-15 14:21:47 -07:00
Philippe Tillet
b24dc19741 [FRONTEND] cleaned up symbol names (#1782) 2023-06-14 18:55:32 -07:00
Zahi Moudallal
ac15d00ef4 [TEST] Added f8xf16 tests (#1771) 2023-06-12 16:14:17 -07:00
Wang Weihan
b27a91a113 [FRONTEND] Enable triton to support register thirdparty backend at runtime (#1643)
This PR intends to provide a mechanism to support a third-party backend
at runtime to generate the backend-specific code.

The mechanism provided a common class to abstract the third-party
backend logic and two essential functions to register and get the
third-party backend at runtime.

- `BaseBackend`: A common class to abstract the third-party backend
logic
- `register_backend`: Register a third-party backend with a given device
type
- `get_backend`: Get the third-party backend with a given device type

Generally, a third-party backend must inherit from `BaseBackend` and
implement all the member functions according to the backend
characteristics. As long as the backend implementation is ready, the
third-party backend can invoke `register_backend` to register it under a
given device. During the kernel compilation and execution, the mechanism
will get the registered backend to generate the kernel and launcher code
for a given device.

This PR added a dummy backend to simulate a third-party backend and
demonstrate the usage.

-
[test_device_backend.py](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1):
To define a third-party backend and register the backend
-
[ExtensionBackend](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R123):
Inherit from the `BaseBackend` and implement some specific logic like
[filter out some compile
stages](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R129-R135)
- [Register the `ExtensionBackend` for
`CPU`](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R279)
  
-
[extension_backend.c](https://github.com/openai/triton/pull/1643/files#diff-169c1d08b3a0a7b343cfa3258fbc32b47e0f6c46305a112652fa1bdaaec89d29):
To provide the utility function to load kernel binary and get the
backend properties.
2023-06-09 09:09:59 -07:00
Keren Zhou
4fbadf6f6f [BACKEND] Fix tl.cat when the number of threads > the size of a tensor (#1751)
`tl.cat(tensor<64>, tensor<64>) -> tensor(128)`, because it concatenates
elements into a single thread, if number of threads is 128, each thread
should own at least 2 elements.
With this PR, we also disable remat of the cat op in some cases.
2023-06-07 15:42:38 -07:00
Philippe Tillet
c52a91231a [FRONTEND][BACKEND] Add acquire/release semantics for atomics (#1739) 2023-06-05 19:09:13 -07:00
Philippe Tillet
6c1992cb38 [FRONTEND] min/max now accept return_indices argument (#1731)
Not just syntactic sugar for successive max + argmax but also avoids
computing the max twice
2023-06-02 22:01:02 -07:00
Keren Zhou
1e171bf270 [BACKEND] Pipeline pass rewrite part 1: functionality fixes (#1716)
Support the following three cases:
1. Operands of `load` depend on induction variables before `load`s.
2. Mixed use of induction variables and offset to update the `ptr`.
3. Cross iteration (>1) dependency values.
2023-06-01 12:07:43 -07:00
Mehdi Amini
440fd1bf20 [TESTS] Increase the paylog of the globaltimer kernel to reduce chances of fakiness (#1726)
If the kernel is too small, on a very fast GPU we may get 0 because the
resolution of the timer seems too coarse.

Fixes #1725
2023-06-01 02:53:07 -07:00
Mehdi Amini
b0c893cdc5 [FRONTEND][BACKEND] Hardened get_program_id axis by making it an enum attribute (#1721)
Also catch out-of-bounds indices at constructio and throw a proper error
in the frontend.
Finally, let's make the IR a bit prettier:

  %0 = tt.get_program_id {axis = 0 : i32} : i32

becomes:

  %0 = tt.get_program_id x : i32

Fixes #1718
2023-05-31 22:49:46 -07:00
Mehdi Amini
19c65d6007 [FRONTEND] fix checks for valid slice and avoid hitting an obscure exception. (#1720)
When comparing to the expected slides, using the `==` operator will
dispatch to the component of the slice. If the user writes `a[10:20]`
these are `triton.constexpr` instances, and the `__eq__` operator which
is implemented as: `return constexpr(self.value == other.value)`. At
this point the access to `.value` on the provided `None` yields an
exception that isn't very friendly to the user.

I am not sure if the implementation of `constexpr` should be hardened
instead?

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-05-31 16:37:19 +00:00
Philippe Tillet
4e2f57add5 [FRONTEND] Added default axis=None for reduction, which reduces across all the axes. (#1712) 2023-05-28 16:13:21 -07:00
Philippe Tillet
420e4acecc [TEST] Added flash attention tests for D_HEAD in {16, 32, 128}. (#1709) 2023-05-27 22:48:22 -07:00
Keren Zhou
0341953466 [FRONTEND] Correct the debug syntax (#1705)
- If `TRITON_DEBUG=True`, all triton functions will be compiled in the
debug mode.
- Otherwise, a triton function `f`'s debug flag is either `True`,
`False` or `None` (default).
    - If `True`, `f` is compiled in the debug mode.
    - If `False`, `f` is compiled in the normal mode.
- If `None`, `f` is compiled based on its caller's debug flag. The root
(kernel) function's debug flag can also be set through the `compile`
function.

cc @ngimel , @Chillee
2023-05-24 23:24:29 -07:00
Philippe Tillet
b5ba639bae [FRONTEND] fixed issue for fp64 literals and added tests (#1698)
fixes #1686
2023-05-20 18:36:28 -07:00
Keren Zhou
fb30d84069 [FRONTEND] Refactor contains_return_op into an independent AST (#1694)
https://github.com/openai/triton/issues/1690
2023-05-20 11:18:40 -07:00
Zahi Moudallal
34817ecc95 [BACKEND] Added support to convert shared to distributed layouts (#1682) 2023-05-17 17:20:29 -07:00
Keren Zhou
3baab48eaf [FRONTEND] Differentiate between bool and int in the frontend (#1678)
`bool` is a subclass of `int`, so `isinstance(bool_var, int) == True`,
and a `bool` constant will be converted to an `int` constant.

In triton specifically, if a bool var is treated as an integer, it
prevents us using the `logical_and` operator which requires both
operands have the same bit length.

> Cannot bitcast data-type of size 32 to data-type of size 1

By differentiating int and bool, it allows us to make the syntax more
close to native python. We can now use `if bool_var and condition` to
check the truthiness, and `if bool_var is True` to check identity.
2023-05-16 18:24:16 +00:00
Keren Zhou
674f9bf7a6 [FRONTEND] Better error messages for noinline functions (#1657)
```
at 10:18:def val_multiplier_noinline(val, i):
    return val * i

           ^
Function val_multiplier_noinline is marked noinline, but was called with non-scalar argument val:fp32[constexpr[128]]
```
2023-05-11 12:46:25 -07:00
Benjamin Chetioui
115964b780 [TESTS] Add regression test for issue #1601. (#1611)
Following up on #1603, I am adding a new file meant to contain
functional regression tests to the repository.
Let me know if another folder would be a more appropriate place for
these tests.

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-05-10 23:30:36 -07:00
Natalia Gimelshein
0daee68d71 [FRONTEND] Don't call set_device in tl.dot (#1646)
This breaks multiprocess compilation
2023-05-10 20:39:27 -04:00
Zahi Moudallal
fb40bf1954 [TEST] Fixed and re-enabled reduce test (#1644)
Re-enabled reduce test after fixing the %cst stride in the ttgir, and
modifying the sweep parameters to make sure the shape per CTA to be less
than or equal to the tensor shape.
2023-05-10 15:15:11 -07:00
Mario Lezcano Casado
6b1af5fe37 [FRONTEND] Add support for scalar conditions in device_assert (#1641)
This sometimes happens in TorchInductor. See
https://github.com/pytorch/pytorch/pull/100880.
More generally, it's useful to be able to write `tl.device_assert(False,
msg)`.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-05-09 23:05:00 -07:00
Keren Zhou
b19b274d93 [FRONTEND] Fix return op related control flow issues (#1637)
- Case 1: Return after static control flow is taken. Peel off
instructions after the first `return` for each basic block.

```python
if static_condition:
    tl.store(...)
    return
return
```

- Case 2: Return exists in both `if` and `else` branches of an inlined
`JITFunction` function

```python
def foo():
    if dynamic_condition:
        return a
    else:
        return b
```

- Case 3: Return exists in a `JITFunction` from another module

```python
import module
if cond:
    a = module.func()
```

- Case 4: A chain of calls through undefined local variables

```python
import module
if cond:
    a = x
    a = a.to(tl.int32).to(tl.int32)
```

- Case 5: Call a function `func` without returning variables. `func` is
recognized as an `Expr` first instead of a `Call`.

```python
if cond:
    foo()
else:
    bar()
```

- Case 6: Call a `noinline` function. We don't need to check if the
function contains any return op.
2023-05-09 12:51:14 -04:00
Michaël Benesty
858a2f0a5e [FRONTEND] Added interpreter mode (#1573)
Simple mechanism to run Triton kernels on PyTorch for debugging purpose
(upstream from Kernl).

Todo:
- random grid iteration
- support of atomic ops
- more unit tests
- cover new APIs?
2023-05-08 14:28:20 -07:00
Zahi Moudallal
125d9d1cc7 [TEST] Added convert layout test from/to sliced blocked/mma (#1620) 2023-05-06 00:20:52 +00:00
Keren Zhou
fd381e2336 [BACKEND] Allow noinline functions to return multiple values of primitive types (#1623)
Fix https://github.com/openai/triton/issues/1621
2023-05-05 19:25:58 +00:00
Zahi Moudallal
e2ae2c6c48 [BACKEND] Modified store op thread masking (#1605) 2023-05-04 17:15:05 -07:00
peterbell10
deb2c71fb4 [FRONTEND] Add tl.expand_dims (#1614)
This exposes `semantic.expand_dims` in the public API and builds upon it
with support for expanding multiple dimensions at once. e.g.
```python
tl.expand_dims(tl.arange(0, N), (0, -1))  # shape = [1, N, 1]
```

Compared to indexing with `None`, this API is useful because the
dimensions can be constexpr values rather than hard-coded into the
source. As a basic example
```python
@triton.jit
def max_keepdim(value, dim):
    res = tl.max(value, dim)
    return tl.expand_dims(res, dim)
```
2023-05-04 09:46:24 -07:00
Zahi Moudallal
3449a9d40d Zahi/slice reduce rebased (#1594)
[BACKEND] Enable slice layout support for reduce op
2023-05-01 18:00:23 -07:00
Keren Zhou
3aff0102a3 [FRONTEND] Fix calling local variables’ attribute functions in the if statement (#1597)
If `node.func` is an `ast.Attribute`, it won't cause an early return.
(Not sure if I interpret it correctly)

https://github.com/openai/triton/issues/1591
2023-04-30 15:41:16 -07:00
Keren Zhou
ee864048b3 [FRONTEND][BACKEND] Add the noinline annotation for triton.jit (#1568)
# Introducing the `noinline` Parameter for Triton JIT Decorator

We're excited to introduce a new parameter, `noinline`, that can be
added to the `jit` decorator in Triton. This parameter allows developers
to specify that a particular Triton function should not be inlined into
its callers. In this post, we'll dive into the syntax, purpose, and
implementation details of this new feature.

## Syntax

To use the `noinline` parameter, simply add `noinline=True` to the `jit`
decorator for the function that you don't want to be inlined. Here's an
example:

```python
@triton.jit(noinline=True)
def device_fn(x, y, Z):
    z = x + y
    tl.store(Z, z)

def test_noinline():
    @triton.jit
    def kernel(X, Y, Z):
        x = tl.load(X)
        y = tl.load(Y)
        device_fn(x, y, Z)
```

In this example, the `device_fn` function is decorated with
`@triton.jit(noinline=True)`, indicating that it should not be inlined
into its caller, `kernel`.

## Purpose

The `noinline` parameter serves several key purposes:

- Reducing code size: By preventing inlining, we can reduce the size of
the compiled code.
- Facilitating debugging: Keeping functions separate can make it easier
to debug the code.
- Avoiding common subexpression elimination (CSE) in certain cases: CSE
can sometimes be avoided by using the `noinline` parameter to reduce
register pressure.
- Enabling dynamic linking: This parameter makes it possible to
dynamically link Triton functions.

## Implementation

The implementation of the `noinline` parameter involves significant
changes to three analysis modules in Triton: *Allocation*, *Membar*, and
*AxisInfo*. Prior to this update, these modules assumed that all Triton
functions had been inlined into the root kernel function. With the
introduction of non-inlined functions, we've had to rework these
assumptions and make corresponding changes to the analyses.

### Call Graph and Limitations

<div style="text-align: center;">
<img
src="https://user-images.githubusercontent.com/2306281/234663904-12864247-3412-4405-987b-6991cdf053bb.png"
alt="figure 1" width="200" height="auto">
</div>

To address the changes, we build a call graph and perform all the
analyses on the call graph instead of a single function. The call graph
is constructed by traversing the call edges and storing them in an edge
map. Roots are extracted by checking nodes with no incoming edges.

The call graph has certain limitations:

- It does not support recursive function calls, although this could be
implemented in the future.
- It does not support dynamic function calls, where the function name is
unknown at compilation time.

### Allocation

<div style="text-align: center;">
<img
src="https://user-images.githubusercontent.com/2306281/234665110-bf6a2660-06fb-4648-85dc-16429439e72d.png"
alt="figure 2" width="400" height="auto">
</div>

In Triton, shared memory allocation is achieved through two operations:
`triton_gpu.convert_layout` and `triton_gpu.alloc_tensor`. The
`convert_layout` operation allocates an internal tensor, which we refer
to as a *scratch* buffer, while the `alloc_tensor` operation returns an
allocated tensor and is thus known as an *explicit* buffer.

To accommodate the introduction of function calls, we are introducing a
third type of buffer called a *virtual* buffer. Similar to scratch
buffers, virtual buffers are allocated internally within the scope of a
function call, and the buffers allocated by the called functions remain
invisible to subsequent operations in the calling function. However,
virtual buffers are distinct from scratch buffers in that the call
operation itself does not allocate memory—instead, it specifies the
total amount of memory required by all the child functions being called.
The actual allocation of buffers is performed by individual operations
within these child functions. For example, when invoking edge e1, no
memory is allocated, but the total amount of memory needed by function B
is reserved. Notably, the amount of shared memory used by function B
remains fixed across its call sites due to the consideration of dynamic
control flows within each function.

An additional challenge to address is the calculation of shared memory
offsets for functions within a call graph. While we can assume a shared
memory offset starting at 0 for a single root function, this is not the
case with a call graph, where we must determine each function's starting
offset based on the call path. Although each function has a fixed memory
consumption, the starting offset may vary. For instance, in Figure 2,
the starting offset of function C through edges e1->e2 differs from that
through edges e2->e4. To handle this, we accumulate the starting offset
at each call site and pass it as an argument to the called function.
Additionally, we amend both the function declaration and call sites by
appending an offset variable.

### Membar

<div style="text-align: center;">
<img
src="https://user-images.githubusercontent.com/2306281/234665157-844dd66f-5028-4ef3-bca2-4ca74b8f969d.png"
alt="figure 3" width="300" height="auto">
</div>

The membar pass is dependent on the allocation analysis. Once the offset
and size of each buffer are known, we conduct a post-order traversal of
the call graph and analyze each function on an individual basis. Unlike
previous analyses, we now return buffers that remain unsynchronized at
the end of functions, allowing the calling function to perform
synchronization in cases of overlap.

### AxisInfo

<div style="text-align: center;">
<img
src="https://user-images.githubusercontent.com/2306281/234665183-790a11ac-0ba1-47e1-98b1-e356220405a3.png"
alt="figure 4" width="400" height="auto">
</div>

The AxisInfo analysis operates differently from both membar and
allocation, as it traverses the call graph in topological order. This is
necessary because function arguments may contain axis information that
will be utilized by callee functions. As we do not implement
optimizations like function cloning, each function has a single code
base, and the axis information for an argument is determined as a
conservative result of all axis information passed by the calling
functions.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-04-28 14:59:04 -07:00
Keren Zhou
e326ff74d1 [TEST] Fix test cache (#1588)
To avoid puzzling segment fault problems caused by multiprocessing, this
PR:

- Uses "spawn" instead of "fork".
- Define the `instance_descriptor` namedtuple globally.
- Make the `kernel_sub` JITFunction defined by the child process only.
2023-04-28 07:39:06 -07:00
Keren Zhou
8f7ec23401 [FRONTEND] Refine arithmetic checks and corresponding tests for extern_elementwise (#1577)
The current main would fail on `math.scalbn` because we implicitly cast
the first argument from `int32` to `float32`, while the function only
accepts `int32` as the first argument and `float32` as the second
argument.

So we update the type matching logic as follows:

1. Check if there's a type tuple that matches the types of the input
arguments
2. If yes, we don't allow arithmetic check.
3. If not, we will do arithmetic check to implicitly cast types among
arguments.
4. If we still don't find a corresponding function that accepts the
casted types, throwing an error.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-04-25 14:25:45 -07:00
Natalia Gimelshein
d5969b81fe [FRONTEND] Test pow with mixed dtypes (#1575)
Also reverts #1541 that breaks this test.
2023-04-24 21:38:40 -04:00
Philippe Tillet
a359b62ef3 [RUNTIME] Lazy driver initialization (#1571) 2023-04-24 15:16:09 -07:00
Natalia Gimelshein
7d1a95b046 [TESTS] Added test for avg_pool_bwd kernel (#1540)
This kernel was briefly broken on main, prevent future regressions.

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-04-17 21:20:34 -07:00
peterbell10
a3c3e5a3a1 [TESTS][OPTIMIZER] enable tests for argmin/max and fix some bugs (#1537)
`argmin`/`argmax` is currently only tested in 1d and when we enable the
tests for 2d it reveals a few bugs.
2023-04-17 18:47:31 -07:00
Philippe Tillet
608ec061c1 [TESTING] Added more tests for annotations and autotuner (#1533)
Essentially identical to #538, but it fails formatting tests and I don't
want to ping the author on a weekend.
2023-04-15 19:44:08 -07:00
Philippe Tillet
f367647b38 [FRONTEND] Added tl.extra.cuda.smid (#1532) 2023-04-15 14:42:59 -07:00
Philippe Tillet
e5c7d2a83c [FRONTEND] cleaned up language; added frontend function for globaltimer special register (#1525) 2023-04-14 15:29:27 -07:00
peterbell10
0d76c4ca95 [FRONTEND] Rename tl.reduction -> tl.reduce and improve testing (#1521)
`tl.reduction` is currently tested indirectly through the existing
reduction operators, but it's good to have a direct test for the
function itself.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-04-14 14:35:31 -07:00
peterbell10
e152183570 [FRONTEND][BACKEND] ReduceOp to support arbitrary reduce operations (#1305)
Fixes #1285

This changes `tt.reduce` to replace `redOp` by a region containing
arbitrary code. For example, `tl.sum` is now lowered as:
```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```
Support for index reductions at the MLIR level are also dropped in favor
of simultaneous reductions over multiple tensors. Which generalizes the
code without loss of performance. So for example `argmin` gets lowered
as:
```mlir
  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
  %8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
  %9:2 = "tt.reduce"(%6, %8) ({
  ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
    %14 = arith.cmpf olt, %arg4, %arg6 : f32
    %15 = arith.cmpf ogt, %arg4, %arg6 : f32
    %16 = arith.cmpi slt, %arg5, %arg7 : i32
    %17 = arith.select %16, %arg5, %arg7 : i32
    %18 = arith.select %15, %arg7, %17 : i32
    %19 = arith.select %14, %arg5, %18 : i32
    %20 = arith.cmpf olt, %arg4, %arg6 : f32
    %21 = arith.select %20, %arg4, %arg6 : f32
    tt.reduce.return %21, %19 : f32, i32
  }) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
2023-04-13 01:37:39 +00:00
Philippe Tillet
5b9119117b [CI] No longer install triton in editable mode to run tests (#1476) 2023-04-12 17:55:44 -07:00
Philippe Tillet
02e3c18f04 [TESTING] clean up testing.do_bench (#1513) 2023-04-11 20:05:58 -07:00
zahimoud
fd34b20fba [BACKEND] Fixed bug in reduce; add tests 2023-04-11 18:09:18 -07:00
Philippe Tillet
0fedf6b79a [TESTS] disable launch latency test (#1510) 2023-04-11 10:31:16 -07:00