Commit Graph

7 Commits

Author SHA1 Message Date
Ian Bearman
215b2e77a1 Add Shared Middle Layer to Triton via Plug-In (#2374)
This PR leverages the plug-in system to add a shared middle-layer to
Triton.

Currently the middle layer is not complete but has enough functionality
to demonstrate how it can work. The general idea is that Triton IR is
lowered into an MLIR core dialect to allow it to be both shared across
Triton targets as well as allow back-ends to be shared with other
languages.

The basic intended architecture looks like this:

[Triton IR] -> [Middle Layer] -> [HW specific IR]

The middle layer uses MLIR's Linalg and Tensor Dialects for operations on
Triton block values. Operations on Triton pointers use the Memref
Dialect.

## Usage
To include the shared middle layer in your Triton build, run `export
TRITON_CODEGEN_TRITON_SHARED=1` before invoking the build. Once it is
part of the build, it can be leveraged in two ways:

### Stand-Alone
The middle layer can be used as a stand-alone component to convert the
Triton dialect to the middle-layer dialects.

Stand-alone example:
```
triton-shared-opt --triton-to-linalg %file
```

### Backend Component
The middle layer can also be used as a component in a Triton back-end by
adding the CMake targets it produces and its header files to that
back-end. An example back-end will be published at a later date.

## Implementation details

Even though a valid Triton program can load from and store to arbitrary
memory locations, the prototype only supports lowering programs that
have structured memory access patterns.

### Analyses

As part of the conversion process, there are three important analyses:

1. Pointer analysis:
+ This analysis is responsible for extracting structured memory access
patterns from a `triton` program for loads and stores; it walks the IR
and visits relevant instructions to build strided memory accesses in the
`memref` dialect. The analysis is still in its early stages and does not
support all scenarios.

2. Use analysis:
+ After "Pointer analysis", instructions that are part of memory address
calculation will no longer be necessary in a triton program because
their semantics have now been captured by `memref` operations
representing strided memory accesses. To aid with removing these
instructions safely, we perform `Use analysis` to mark which
instructions are used *only* in address calculation (called `MetaUse`)
or used in *both* address calculation and data manipulation (called
`MixedUse`) operations. Those that are `MixedUse` are cloned and have
their users adjusted accordingly with the goal of separating out the
`MetaUse` ops so that they can be safely deleted.

3. Mask analysis:
    + This analysis is responsible for handling masked loads and stores; a
small hypothetical kernel exercising these analyses is sketched after
this list.
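
A hedged sketch of a kernel touching all three analyses (the kernel and its names are hypothetical, not taken from the PR's tests):

```python
import triton
import triton.language as tl

# Hypothetical kernel (not one of the PR's tests).
@triton.jit
def offset_add_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    # Used for address calculation *and* in the stored value below, so Use
    # analysis would mark it MixedUse and clone it to separate out the
    # MetaUse copy.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Boundary mask: the case Mask analysis is responsible for.
    mask = offsets < n_elements
    # Pointer analysis recovers a contiguous BLOCK_SIZE slice starting at
    # pid * BLOCK_SIZE from this masked load.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Data use of `offsets` (in addition to its addressing use above).
    tl.store(out_ptr + offsets, x + offsets.to(tl.float32), mask=mask)
```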

### Conversion strategy

We introduce the `TritonToLinalg` pass that converts the `triton`
dialect to the `linalg` dialect on *tensors*. This means the resulting
IR is fully compatible with `linalg` tiling and fusion transformation
passes. As mentioned in the `Pointer analysis` description, we do,
however, have to deal with `memref` instructions at the load and store
boundaries and convert them to tensors using
`bufferization.to_tensor`. Here's a simple example of what the IR looks
like:

```mlir
tt.func @kernel(%afloat : !tt.ptr<bf16>, %res : !tt.ptr<bf16>) {
  %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
  %1 = tt.splat %afloat : (!tt.ptr<bf16>) -> tensor<128x!tt.ptr<bf16>>
  %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<bf16>>, tensor<128xi32>
  %afm = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xbf16>
  %3 = "tt.reduce"(%afm) ({
  ^bb0(%arg5: bf16, %arg6: bf16):
    %21 = arith.addf %arg5, %arg6 : bf16
    tt.reduce.return %21 : bf16
  }) {axis = 0 : i32} : (tensor<128xbf16>) -> bf16
  tt.store %res, %3 : bf16
  tt.return
}
```

after conversion:

```mlir
func.func @kernel(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: i32, %arg3: i32, %arg4: i32) {
    %cst = arith.constant 0.000000e+00 : f32
    %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] :
        memref<*xbf16> to memref<128xbf16, strided<[1]>>
    %alloc = memref.alloc() : memref<128xbf16>
    memref.copy %reinterpret_cast, %alloc : memref<128xbf16, strided<[1]>> to memref<128xbf16>
    %0 = bufferization.to_tensor %alloc restrict writable : memref<128xbf16>
    %1 = bufferization.alloc_tensor() : tensor<f32>
    %inserted = tensor.insert %cst into %1[] : tensor<f32>
    %reduced = linalg.reduce ins(%0 : tensor<128xbf16>) outs(%inserted : tensor<f32>) dimensions = [0]
      (%in: bf16, %init: f32) {
        %3 = arith.extf %in : bf16 to f32
        %4 = arith.addf %3, %init : f32
        linalg.yield %4 : f32
      }
    %extracted = tensor.extract %reduced[] : tensor<f32>
    %2 = arith.truncf %extracted : f32 to bf16
    %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1], strides: [1] :
        memref<*xbf16> to memref<1xbf16, strided<[1]>>
    affine.store %2, %reinterpret_cast_0[0] : memref<1xbf16, strided<[1]>>
    return
}
```

Important details to note:

+ `tt.load` (together with all of its related address calculation
instructions such as `tt.addptr` and `tt.splat`) is lowered to a
combination of `memref.reinterpret_cast`, `memref.alloc`, and
`memref.copy`. After the initialization of the local buffer, we convert
the memref back to a tensor using `bufferization.to_tensor`; this op is
automatically removed during bufferization.

+ `tt.store` lowers to a combination of `memref.reinterpret_cast` and
either `affine.store` or `memref.tensor_store`:

```mlir
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [...] memref<*xf32> to memref<1024xf32>
%extracted_slice = tensor.extract_slice %15[0] [%21] [1] : tensor<1024xf32> to tensor<?xf32>
%subview = memref.subview %reinterpret_cast[0] [%21] [1] : memref<1024xf32> to memref<?xf32>
memref.tensor_store %extracted_slice, %subview : memref<?xf32>
```

+ element-wise `arith` and `math` operators are converted to their
corresponding `linalg.generic` versions.
+ `tt.dot` becomes `linalg.matmul`.
+ `tt.reduce` becomes `linalg.reduce`; known limitation: only `addf` and
`maxf` reductions are supported in the reduction body for now. A small
kernel touching each of these lowerings is sketched below.
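
For instance (a hedged sketch with hypothetical names, not one of the PR's test kernels), the ops in the kernel below would map to these lowerings:

```python
import triton
import triton.language as tl

# Hypothetical kernel (not one of the PR's tests) whose ops map onto the
# lowerings listed above.
@triton.jit
def exp_matmul_rowsum_kernel(a_ptr, b_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Load one BLOCK x BLOCK tile of A and B (row-major, leading dim BLOCK).
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    a = tl.exp(a)                 # element-wise math op -> linalg.generic
    c = tl.dot(a, b)              # tt.dot -> linalg.matmul
    row_sum = tl.sum(c, axis=1)   # tt.reduce with an addf body -> linalg.reduce
    tl.store(out_ptr + offs, row_sum)
```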

### Testing

The prototype was tested on the following Triton kernel examples:

1. vector addition
2. fused softmax
3. matrix multiplication
4. layer normalization
5. fused attention

In addition to testing on the tutorial kernels, I have also added many
lit tests covering various scenarios.

## Recognition

The work here represents contributions from myself as well as many of my
colleagues at Microsoft. I especially want to call out @nhat-nguyen and
@haishanzzz who were major contributors to this work.
2023-09-22 15:29:31 -07:00
Michael Melesse
c6d33dcebf [ROCM] Core Functionality for AMD (#1983)
* This PR adds a third-party backend for Triton that works on AMD
hardware.
* It exposes a lot of the work that has been done in our
[fork](https://github.com/ROCmSoftwarePlatform/triton).
* Most unit tests in `test_core.py` pass.
* Some unit tests are skipped for various reasons.
* We plan to follow up with more PRs improving functionality and
performance in the future.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-08-31 14:02:00 -07:00
Wang Weihan
b27a91a113 [FRONTEND] Enable triton to support register thirdparty backend at runtime (#1643)
This PR intends to provide a mechanism to support a third-party backend
at runtime to generate backend-specific code.

The mechanism provides a common class to abstract the third-party
backend logic and two essential functions to register and get the
third-party backend at runtime.

- `BaseBackend`: A common class to abstract the third-party backend
logic
- `register_backend`: Register a third-party backend with a given device
type
- `get_backend`: Get the third-party backend with a given device type

Generally, a third-party backend must inherit from `BaseBackend` and
implement all the member functions according to the backend's
characteristics. Once the backend implementation is ready, the
third-party backend can invoke `register_backend` to register itself
under a given device type. During kernel compilation and execution, the
mechanism will get the registered backend to generate the kernel and
launcher code for a given device.

This PR added a dummy backend to simulate a third-party backend and
demonstrate the usage.
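
A minimal sketch of how a backend might hook into this mechanism (the import path, device-type string, and class details below are assumptions for illustration; the real dummy backend lives in the files linked below):

```python
# Sketch only: the module path and exact signatures are assumed, not taken
# verbatim from the PR.
from triton.common.backend import BaseBackend, get_backend, register_backend


class MyAcceleratorBackend(BaseBackend):
    """Hypothetical third-party backend for a device type called "my_accel".

    A real backend implements the BaseBackend member functions (adding its
    compile stages, loading the kernel binary, generating the launcher, ...)
    according to the characteristics of its hardware.
    """
    # ... backend-specific member functions go here ...


# Register the backend under its device type so that kernel compilation and
# execution for that device type picks it up.
register_backend("my_accel", MyAcceleratorBackend)

# The compiler/runtime side then retrieves it by the same device type.
backend = get_backend("my_accel")
```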

- [test_device_backend.py](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1):
defines a third-party backend and registers it
- [ExtensionBackend](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R123):
inherits from `BaseBackend` and implements some backend-specific logic,
such as [filtering out some compile
stages](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R129-R135)
- [Register the `ExtensionBackend` for
`CPU`](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R279)
- [extension_backend.c](https://github.com/openai/triton/pull/1643/files#diff-169c1d08b3a0a7b343cfa3258fbc32b47e0f6c46305a112652fa1bdaaec89d29):
provides the utility functions to load the kernel binary and get the
backend properties
2023-06-09 09:09:59 -07:00
cloudhan
323843cde8 [BUILD] stop depending on dlfcn-win32 by implementing dladdr natively with WIN32 API (#1674)
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-05-16 07:19:36 +00:00
Philippe Tillet
25e1b36785 Revert "[pybind11] Use git-submodule for pybind11" (#701)
Reverts openai/triton#699
2022-09-23 12:25:38 -07:00
Shintaro Iwasaki
61d104ab3a [FRONTEND] Use git-submodule for pybind11 (#699)
This PR changes the `pybind11` source code management from copy-paste to
a package controlled by a git submodule.

See the discussion in #694 for details.
2022-09-23 09:55:03 -07:00
Victor
73b04d71b2 Fixes for building on Windows (#382)
* make C++ code compatible with Windows + MSVC

* added dlfcn-win32 for cross-platform dlopen

* fixed building and pip install on Windows

* fixed shared library file name under Windows
2021-12-07 14:10:58 -08:00