Commit Graph

1330 Commits

Author SHA1 Message Date
Philippe Tillet
eea0718445 [TESTING] better cudagraph-based benchmarking (#2394) 2023-09-25 21:41:26 -07:00
Philippe Tillet
80adbbb87b [OPTIMIZER] fix-up acceleratematmul (#2392) 2023-09-25 17:05:04 -07:00
Hongtao Yu
00c089d897 [DOCS] tweak install instructions for custom llvm build. (#2381) 2023-09-25 12:18:28 -07:00
ben-zhang-609
d040b58547 [HOPPER] fix ref check failure of flash attention with mma v3 (#2384) 2023-09-25 11:29:49 -07:00
Thomas Raoux
6bc1d9e1be [BACKEND] Support MMA V3 with register operand (#2375)
MMA V3 supports taking operand A from registers. This helps chained
matmul operations like in attention.
Add an optimization to use this mode when it helps, and add the
lowering for it.
2023-09-25 10:43:54 -07:00
kshama-msft
8ae2ae4f40 [DOCS] update README.md (#2391)
Remove conference details.
2023-09-25 10:42:04 -07:00
kshama-msft
e7abafe4b4 [DOCS] create tritonconf2023.md (#2390)
File and video location.
2023-09-25 10:41:45 -07:00
Thomas Raoux
a4dbdefe3b [BACKEND] Use shuffle intrinsics instead of inline asm (#2378)
This will ensure we get the proper "convergent" semantics for those
instructions.
2023-09-23 11:50:37 -07:00
Keren Zhou
57fc6d1f13 [BACKEND] shfl ptx insts should have side effects (#2376)
Otherwise, LLVM passes could generate a very weird CFG structure and
yield incorrect results.

https://github.com/openai/triton/issues/2361
2023-09-23 10:05:20 -07:00
edimetia3d
cb83b42ed6 [FRONTEND] using closure to create jit launcher (#2289)
Hi,

I'm adding some features to
`triton.runtime.jit.JITFunction_make_launcher` and found it hard to
debug:
1. The inlined Python code is hard to inspect in my editor.
2. My debugger fails to step into this inlined code.

In response, I've introduced some code to solve these issues. My
modifications include:
~~1. Refactoring the launcher's inline Python code, ensuring it only
relies on the "self" object.~~
~~2. Adding a utility method that generates a temporary file to create a
launcher when debugging a kernel in the main module~~
Using a closure to hold the launcher's body

Because this feature might be useful to others, I have initiated this
pull request.

~~Tests are yet to be added; if this submission might be accepted, I
will add them later.~~
Since this change is a refactor, no new test was added.
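The core idea can be shown with a minimal, hypothetical Python sketch; `run_kernel` and the `make_launcher_*` names are illustrative placeholders, not Triton's actual internals:

```python
# Hypothetical sketch of the two launcher-creation styles.
# `run_kernel` is a placeholder standing in for the real kernel launch.

def run_kernel(*args):
    return sum(args)  # placeholder behavior

# Before: launcher source is assembled as a string and exec()'d.
# Editors cannot easily inspect, and debuggers cannot step into,
# code created this way.
def make_launcher_exec(arg_names):
    params = ", ".join(arg_names)
    src = f"def launcher({params}):\n    return run_kernel({params})"
    scope = {"run_kernel": run_kernel}
    exec(src, scope)
    return scope["launcher"]

# After: the launcher body is an ordinary closure, so it is plain
# Python that tooling and debuggers handle normally.
def make_launcher_closure(arg_names):
    n = len(arg_names)
    def launcher(*args):
        if len(args) != n:
            raise TypeError(f"expected {n} arguments, got {len(args)}")
        return run_kernel(*args)
    return launcher
```

Both factories produce an equivalent callable; only the closure version is debuggable line by line.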
2023-09-22 17:01:54 -07:00
ian Bearman
215b2e77a1 Add Shared Middle Layer to Triton via Plug-In (#2374)
This PR leverages the plug-in system to add a shared middle-layer to
Triton.

Currently the middle layer is not complete but has enough functionality
to demonstrate how it can work. The general idea is that Triton IR is
lowered into an MLIR core dialect to allow it to be both shared across
Triton targets as well as allow back-ends to be shared with other
languages.

The basic intended architecture looks like this:

[Triton IR] -> [Middle Layer] -> [HW specific IR]

The middle layer uses MLIR's Linalg and Tensor dialects for operations
on Triton block values. Operations on Triton pointers use the Memref
dialect.

## Usage
To include the shared middle-layer in your Triton build do `export
TRITON_CODEGEN_TRITON_SHARED=1` before invoking your build. Once it is
part of the build it can be leveraged in two ways:

### Stand-Alone
The middle layer can be used as a stand-alone component to convert
Triton dialect to the middle layer dialects.

Stand-alone example:
```shell
triton-shared-opt --triton-to-linalg %file
```

### Backend Component
The middle layer can also be used as a component in a Triton back-end by
adding the cmake targets it produces and its headers files to that
back-end. An example back-end will be published at a later date.

## Implementation details

Even though a valid Triton program can perform loads and stores at
arbitrary memory locations, the prototype only supports lowering
programs that have structured memory access patterns.

### Analyses

As part of the conversion process, there are three important analyses:

1. Pointer analysis:
+ This analysis is responsible for extracting structured memory access
patterns from a `triton` program during load and store; it walks the IR
and visits relevant instructions to build strided memory accesses in the
`memref` dialect. The analysis is still in its early stage and does not
support all scenarios.

2. Use analysis:
+ After "Pointer analysis", instructions that are part of memory address
calculation will no longer be necessary in a triton program because
their semantics have now been captured by `memref` operations
representing strided memory accesses. To aid with removing these
instructions safely, we perform `Use analysis` to mark which
instructions are used *only* in address calculation (called `MetaUse`)
or used in *both* address calculation and data manipulation (called
`MixedUse`) operations. Those that are `MixedUse` are cloned and have
their users adjusted accordingly with the goal of separating out the
`MetaUse` ops so that they can be safely deleted.

3. Mask analysis:
    + This analysis is responsible for handling masked loads and stores.
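The `Use analysis` marking described in step 2 can be sketched roughly as follows (illustrative data structures and names, not triton-shared's real implementation):

```python
# Hypothetical sketch of MetaUse/MixedUse classification: inspect how
# each op's results are consumed and mark it accordingly.
def classify_uses(users_of, addr_consumers):
    """users_of: dict mapping an op name to the set of ops consuming its result.
    addr_consumers: the set of ops that represent address calculation."""
    marks = {}
    for op, users in users_of.items():
        used_for_addr = any(u in addr_consumers for u in users)
        used_for_data = any(u not in addr_consumers for u in users)
        if used_for_addr and used_for_data:
            # clone the op and split its users so the address-only copy
            # can later be deleted
            marks[op] = "MixedUse"
        elif used_for_addr:
            marks[op] = "MetaUse"  # safe to delete after conversion
        else:
            marks[op] = "DataUse"
    return marks
```

For example, an op feeding both `tt.addptr` (address calculation) and an arithmetic op (data manipulation) would be marked `MixedUse`.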

### Conversion strategy

We introduce the `TritonToLinalg` pass that converts the `triton`
dialect to the `linalg` dialect on *tensors*. This means the resulting
IR is fully compatible with `linalg` tiling and fusion transformation
passes. As mentioned in the `Pointer analysis`'s description, we do
however have to deal with memref instructions at the load and store
boundaries and have to convert them to tensors using
`bufferization.to_tensor`. Here's a simple example of what the IR looks
like:

```mlir
tt.func @kernel(%afloat : !tt.ptr<bf16>, %res : !tt.ptr<bf16>) {
  %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
  %1 = tt.splat %afloat : (!tt.ptr<bf16>) -> tensor<128x!tt.ptr<bf16>>
  %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<bf16>>, tensor<128xi32>
  %afm = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xbf16>
  %3 = "tt.reduce"(%afm) ({
  ^bb0(%arg5: bf16, %arg6: bf16):
    %21 = arith.addf %arg5, %arg6 : bf16
    tt.reduce.return %21 : bf16
  }) {axis = 0 : i32} : (tensor<128xbf16>) -> bf16
  tt.store %res, %3 : bf16
  tt.return
}
```

after conversion:

```mlir
func.func @kernel(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: i32, %arg3: i32, %arg4: i32) {
    %cst = arith.constant 0.000000e+00 : f32
    %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] :
        memref<*xbf16> to memref<128xbf16, strided<[1]>>
    %alloc = memref.alloc() : memref<128xbf16>
    memref.copy %reinterpret_cast, %alloc : memref<128xbf16, strided<[1]>> to memref<128xbf16>
    %0 = bufferization.to_tensor %alloc restrict writable : memref<128xbf16>
    %1 = bufferization.alloc_tensor() : tensor<f32>
    %inserted = tensor.insert %cst into %1[] : tensor<f32>
    %reduced = linalg.reduce ins(%0 : tensor<128xbf16>) outs(%inserted : tensor<f32>) dimensions = [0]
      (%in: bf16, %init: f32) {
        %3 = arith.extf %in : bf16 to f32
        %4 = arith.addf %3, %init : f32
        linalg.yield %4 : f32
      }
    %extracted = tensor.extract %reduced[] : tensor<f32>
    %2 = arith.truncf %extracted : f32 to bf16
    %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1], strides: [1] :
        memref<*xbf16> to memref<1xbf16, strided<[1]>>
    affine.store %2, %reinterpret_cast_0[0] : memref<1xbf16, strided<[1]>>
    return
}

```

Important details to note:

+ `tt.load` (together with all of its related address calculation
instructions such as `tt.addptr` and `tt.splat`) are lowered to a
combination of `memref.reinterpret_cast`, `memref.alloc`, and
`memref.copy`. After the initialization of the local buffer, we convert
the memref back to a tensor using `bufferization.to_tensor`; this op is
automatically removed during bufferization.

+ `tt.store` lowers to a combination of `memref.reinterpret_cast` and
either `affine.store` or `memref.tensor_store`:

```mlir
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [...] memref<*xf32> to memref<1024xf32>
%extracted_slice = tensor.extract_slice %15[0] [%21] [1] : tensor<1024xf32> to tensor<?xf32>
%subview = memref.subview %reinterpret_cast[0] [%21] [1] : memref<1024xf32> to memref<?xf32>
memref.tensor_store %extracted_slice, %subview : memref<?xf32>
```

+ element-wise `arith` and `math` operators are converted to their
corresponding `linalg.generic` version.
+ `tt.dot` becomes `linalg.matmul`.
+ `tt.reduce` becomes `linalg.reduce`; known limitation: only `addf`
and `maxf` reductions are supported in the reduction body for now.

### Testing

The prototype was tested on the following triton kernel examples:

1. vector addition
2. fused softmax
3. matrix multiplication
4. layer normalization
5. fused attention

In addition to testing on the tutorial kernels, I have also added many
lit tests covering various scenarios.

## Recognition

The work here represents contributions from myself as well as many of my
colleagues at Microsoft. I especially want to call out @nhat-nguyen and
@haishanzzz who were major contributors to this work.
2023-09-22 15:29:31 -07:00
Thomas Raoux
840e7e7b53 [BACKEND] Improve decision of MMA dimension on H100 (#2373)
When there is a chain of mma ops, we want to pick the same shape to
avoid conversions. This improves the detection by looking through for
loops. This fixes a crash in the backward-attention tutorial.

We might want to change this logic and convert the format to allow more
efficient MMA at some point.
2023-09-22 15:21:56 -07:00
Bin Fan
1724604bd9 [DOCS] Add a tutorial example of grouped gemm (#2326) 2023-09-22 11:16:35 -07:00
q.yao
413b18eb73 [FRONTEND] fix core.dtype.__repr__ (#2372)
`function_type` does not have a `name` field, which leads to an error
when debugging with gdb.
2023-09-22 08:34:20 -07:00
Zahi Moudallal
293b7fd592 [TESTING] cleanup (#2293)
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-09-22 05:37:14 +00:00
Philippe Tillet
c71ec14f31 [TEST] only test 4 configs without TF32 (#2370) 2023-09-21 21:23:19 -07:00
Alexander Zinoviev
d543eb1a36 [BACKEND] implement dot for INT8 on Turing (#2364)
Replace the single
`mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32` instruction
used on Ampere with 4 x
`mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32` instructions
for Turing.

Extracted the Turing-int8, Turing-fp16, and Ampere paths into separate
functions.

Somehow I messed up my previous PR, so I'm just opening a new one.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-09-21 16:40:53 -07:00
Philippe Tillet
32c9d2bb8f [FRONTEND] improved error messages (#2363)
this is a combination of #1774 and #2006, which I cannot edit but
which fail the CI pre-commit hook
2023-09-21 15:05:57 -07:00
Philippe Tillet
c4bc3fd92f [BACKEND] Fix-up memory leak (#2365) 2023-09-21 13:46:30 -07:00
Thomas Raoux
e36c99b588 [BACKEND] Handle scan of function non commutative (#2362)
Make sure we accumulate in the right order for scans so that
non-commutative operations are handled correctly.
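The ordering problem can be illustrated in plain Python with a non-commutative combine function (this mirrors the bug, not the backend fix):

```python
from itertools import accumulate
import operator

# String concatenation is associative but not commutative, so a scan
# must combine partial results in a fixed operand order. Swapping the
# operands of the combine function changes every partial result.
xs = ["a", "b", "c", "d"]
in_order = list(accumulate(xs, operator.add))
swapped = list(accumulate(xs, lambda acc, x: x + acc))
```

`in_order` yields the expected prefixes, while `swapped` reverses each prefix: the same elements combined in the wrong order.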
2023-09-21 12:00:41 -07:00
Tori Baker
be9849bda9 [BACKEND] Set min bitwidth of shared store&load (#2358)
Using `i1` results in st.shared.b1, which does not exist. Set a min bit
width here to handle this case.

Resolves issue https://github.com/openai/triton/issues/2351

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-09-21 06:42:29 -07:00
peterbell10
8094f46632 [FRONTEND][BACKEND] Fix various atomic_rmw bugs (#2355)
This fixes a few bugs I've encountered
- `atomic_add` with int64/uint64 `Operation .add requires .u32 or .s32
or .u64 [...] for instruction 'atom'`
- `atomic_min/max` with float64 -> `ValueError('Cannot bitcast data-type
of size 64 to data-type of size 32')`
- `atomic_min/max` with float32 returns the old value as int32
2023-09-21 03:31:20 +00:00
ben-zhang-609
bcaf14755a [HOPPER] enable flash attention with tma (#2336) 2023-09-20 14:06:56 -07:00
Thomas Raoux
9cab885dff [BACKEND] Optimize wgmma with accumulator source equal to 0 (#2343)
Also add a test for MMA v3 reduction.
2023-09-20 14:05:12 -07:00
Keren Zhou
ed5a53057d [BACKEND] Handle repetitive threads in scan op when the tensor dim is small (#2345)
https://github.com/openai/triton/issues/2298
2023-09-20 12:25:52 -04:00
Dongdong Li
e5eda098b3 [TESTS] fix flash attention (#2086)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2023-09-20 14:23:46 +08:00
Justin Lebar
363182928c Add instructions for building with custom LLVM (#2344)
I tested these locally; they seem to work for me.
2023-09-20 06:05:47 +00:00
Justin Lebar
5491707093 Switch pre-commit clang-format to v16.0.6. (#2342)
Google uses clang-format at LLVM HEAD.  clang-format's formatting is not
stable, so we want to minimize the difference between the pre-commit
clang-format and HEAD to minimize differences with Google's formatter.

In practice, it appears that there are no relevant changes to the
formatting, so this is a nop.  🤷

Tested by running `pre-commit run --all-files`.
2023-09-19 23:14:55 +00:00
Matthias Springer
ae07b7b3d3 Integration fixes for llvm/llvm-project#66754 (#2338)
llvm/llvm-project#66754 extends the `LoopLikeOpInterface`: the signature
of `getLoopBody` has changed. `ForOp::getRegion` can be used instead.

This change works with and without llvm/llvm-project#66754.
2023-09-19 23:00:33 +00:00
Shantanu
8e75e392ae [FRONTEND] Fix Python error handling in launch (#2334)
This was regressed by #2185 because we didn't realize the CUDA_CHECK
macro could make Python calls (similar to what led to #2225). I think
the PyErr_Occurred check got removed in that PR because error handling
was missing before the call to _launch, so it looked like it was just
in the wrong place.

It looks like there are also potentially a couple of places in cuda.c
that can return with an error set, e.g. getDeviceProperties, memAlloc,
memcpyHtoD, memFree, tensorMapEncodeTiled, etc., but those are all
pre-existing and not affected by recent changes.
2023-09-19 00:12:00 -07:00
Thomas Raoux
3a848e2729 [BACKEND] Relax patterns to move sink broadcast and hoist convert (#2331)
Improve patterns that sink broadcast to reduce the arithmetic density,
and also hoist convert above expand_dims to do less work.

This address comments in https://github.com/openai/triton/pull/2274
2023-09-18 15:08:19 -07:00
Philippe Tillet
73dae775df [DOCS] improved fused attention tutorial (bwd pass) (#2332) 2023-09-18 15:07:41 -07:00
Keren Zhou
307b5caa49 [BACKEND] Fix scan issues on repetitive warps and improve perf when there's a single warp on the axis (#2330)
1. On the axis, use `getAxisNumWarpsWithUniqueData` instead of the raw
number of warps to avoid communication among warps that handle the same
piece of data.
2. When there's a single warp on the axis, use warp intrinsics for
communication and skip shared memory.

Need a follow up PR for code clean up.
2023-09-18 17:45:05 -04:00
Matthias Springer
a9ae9886dc Integration fixes for llvm/llvm-project#66512 (#2328)
Some duplicate functions on `scf.for` have been removed in
llvm/llvm-project#66512. This PR works with and without
llvm/llvm-project#66512.
2023-09-18 15:22:06 +00:00
Philippe Tillet
894fa9e943 [RUNTIME][INTERPRETER] now also override __str__ method for tensors (#2325) 2023-09-17 16:49:30 -07:00
Philippe Tillet
e686b4d6d4 [FRONTEND] interpreter rewrite (#2321)
This is a new interpreter mode that shares semantic analysis with the
JIT'ed codepath and that the Triton core team is committed to
maintaining
2023-09-17 14:58:50 -07:00
Myeonghwan Ahn
2b066000aa [FRONTEND] fix matmul int8 overflow issue (#2297)
Previously, if the matmul inputs were int8, the output was also int8.
This commit fixes the overflow problem by making the output int32.
#2296
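The overflow is easy to reproduce outside Triton; here is a NumPy illustration of the same issue (an analogy, not the actual kernel code):

```python
import numpy as np

# Accumulating an int8 x int8 matmul in int8 wraps around, while an
# int32 output holds the exact result.
a = np.full((4, 64), 3, dtype=np.int8)
b = np.full((64, 4), 3, dtype=np.int8)

wrapped = a @ b  # each entry is 64 * 9 = 576, which wraps modulo 256
exact = a.astype(np.int32) @ b.astype(np.int32)  # 576, no overflow
```

With int8 accumulation each dot product of 576 silently wraps to 64; widening to int32 preserves the true value.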
2023-09-17 16:41:02 +00:00
Stonepia
68e1bd162c [FRONTEND] fix xpu stages logic (#2305) 2023-09-17 09:19:14 -07:00
jon-chuang
4f2d995fad [FRONTEND] Explicitly forbid dot(.., out_dtype=bfloat16) (#2308)
Fixes: https://github.com/openai/triton/issues/2302
2023-09-17 09:15:06 +00:00
Justin Lebar
073aa16379 [BUILD] use ninja (#2318) 2023-09-17 02:08:04 -07:00
Philippe Tillet
c98671cf7c Revert "Update integration-tests.yml" (#2323)
reverts #2310 as recent changes to Triton-IR have broken third-party backends
2023-09-17 01:16:00 -07:00
Justin Lebar
0015611c17 [DOCS] Add build instrs for running in a virtualenv. (#2320)
On my machine, when I try to `pip install cmake` outside a virtualenv,
it gets mad at me and tells me to use apt, which doesn't quite work for
some reason. Anyway, maybe this is simple to Python people, but it
seems worth mentioning, especially because we already have `.venv` in
gitignore.
2023-09-17 00:14:33 -07:00
Justin Lebar
41584c71a6 Add cuobjdump and nvsisasm to gitignore. (#2319)
Otherwise, these files show up in `git status` under
python/triton/third_party/cuda/bin/.
2023-09-17 01:53:46 +00:00
Thomas Raoux
bb949d1141 [BACKEND] Move struct optimization down the LLVM pipeline (#2312)
Move the optimization that removes phi of struct later in the LLVM
pipeline to avoid interfering with CFG optimization.
2023-09-16 12:28:53 -07:00
Thomas Raoux
31b0c52142 [FRONTEND][BACKEND] Add flag to control accumulation for fp8 (#2300)
Change dot to allow taking an initial accumulator, and add a flag that
allows the compiler to accumulate in a lower precision than the output
type.
On Hopper this flag is on by default, which allows accumulating with
lower precision.
This only affects the Hopper fp8 dot.
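Why accumulation precision matters can be illustrated with NumPy fp16 (an analogy only; the flag itself concerns fp8 dot on Hopper hardware):

```python
import numpy as np

# Summing 4096 ones: fp16 accumulation stalls once the running sum
# reaches 2048, because 2048 + 1 = 2049 rounds back to 2048 in fp16.
# A wider fp32 accumulator stays exact.
acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
for _ in range(4096):
    acc16 = np.float16(acc16 + np.float16(1.0))
    acc32 = acc32 + np.float32(1.0)
```

After the loop, `acc16` is 2048.0 while `acc32` is the exact 4096.0, which is the kind of error the accumulation-precision flag trades against speed.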
2023-09-15 18:42:54 -07:00
Michael Melesse
78a0b5dc2a [CI] update integration-tests.yml (#2310) 2023-09-15 18:38:15 -07:00
Zahi Moudallal
db5c793f82 [FRONTEND] Add sass to asm dict with lazy evaluation (#2309) 2023-09-15 15:31:43 -07:00
kshama-msft
ac1c216110 [DOCS] update README.md (#2311)
Triton conf registration closed.
2023-09-15 15:07:38 -07:00
Thomas Raoux
976aabdeb2 [BUILD] Fix few dependencies and layering issues to make lld work (#2307)
This fixes a few problems that were preventing me from using the lld linker.
2023-09-15 17:00:58 +00:00
Keren Zhou
08c1658957 [FRONTEND] Accommodate new triton IR format (#2294)
- Support memory space for pointers (e.g., `!tt.ptr<f32, 1>`).
- Support parsing function attribute, though not used yet.
2023-09-14 09:03:23 -07:00