Commit Graph

1120 Commits

Author SHA1 Message Date
Philippe Tillet
54f1ac950e [CI] disable AMD CI (#2045) 2023-08-07 12:03:26 -07:00
Philippe Tillet
223c2d32a2 [CI] disable XPU tests (not compiling) (#2044)
cc @EikanWang . I'm disabling this for now since it broke with the H100
merge, but please feel free to fix the compilation errors and submit a
PR.
2023-08-07 11:56:16 -07:00
goostavz
f1512bded1 Initial code merge of Hopper support (#2036)
The initial code merge of Nvidia Hopper features support. Please be
aware that the code merge is not finished yet and the trouble-shooting
is still ongoing. The new hardware features (GMMA, TMA, STMATRIX etc.)
and automatic warp-specialization are experimental for now and turned
off by default. It is recommended for a trial when version 3.0 is
released.

The work is contributed by:
ben-zhang-609, bealwang, donproc, qliu93, jsh20, allatit23, LyricZhao,
ivanyinwz, goostavz & yangjunpro
from Nvidia, in cooperation with:
ptillet, Jokeren, ThomasRaoux & zahimoud
from OpenAI.

Co-authored-by: Goostav Zhu <gzhu@nvidia.com>
2023-08-07 09:53:04 +08:00
Philippe Tillet
5df904233c [BACKEND] now clamping fp8e4b15 in conversion from fp16 (#2017) 2023-07-31 18:50:42 -07:00
Yan Chunwei
89b0b79d75 [FRONTEND] fix the silent return issue in AOT launcher (#2013)
In the current link.py, it produces the launcher code as below:

```python
CUresult matmul_fp16xfp16_16x16x16(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t stride_cm, int32_t stride_am, int32_t stride_bk){
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0) && (stride_cm % 16 == 0))
    return matmul_fp16xfp16_16x16x16_688cc413_0d1d2d3d45d(stream, gX, gY, gZ, C, A, B, stride_cm, stride_am, stride_bk);
  // ...
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0))
    return matmul_fp16xfp16_16x16x16_7c0255bf_0d1d2d345(stream, gX, gY, gZ, C, A, B, stride_cm, stride_am, stride_bk);
}
```
Note that, when the input does not match any of the if branches, it will
do nothing, and the compiler should make it return 0 as a default
behavior, which equals to `CUDA_SUCCESS`, this doesn't match the
expectation.

This PR adds a `return CUDA_VALUE_ERROR;` to the tail of launchers, and
it produces code like:

```c++
CUresult matmul_fp16xfp16_16x16x16(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t stride_cm, int32_t stride_cn, int32_t stride_am, int32_t stride_ak, int32_t stride_bk, int32_t stride_bn){
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0) && (stride_cm == 1) && (stride_cn == 1) && (stride_am == 1) && (stride_ak == 1) && (stride_bk % 16 == 0) && (stride_bn == 1))
    return matmul_fp16xfp16_16x16x16_1f18a6da_0d1d2d3c4c5c6c7d8c(stream, gX, gY, gZ, C, A, B, stride_bk);

  return CUDA_ERROR_INVALID_VALUE;
}
```

And it requires users to check the result in their application, which I
think should match the initial AOT ideas.
2023-07-31 09:59:28 -07:00
Philippe Tillet
52c146f66b [OPTIMIZER][BACKEND] significantly cleaner handling of mixed-precision kernels (#1949)
we currently have a very janky approach to optimizing mixed-precision
matmul workloads, where some layout combinations (e.g., NT matmul) were
explicitly pattern-matched to take a more optimized codepath. Attempt at
unifying all the codepaths to codegen cp.async failed, due to bugs in
SharedToDotOperandMMAv2.cpp.

This PR fixes said bugs, add some assertions for SharedToDotOperandMMAv2
modes that aren't well supported, and greatly simplify our handling of
element-wise operations between load and conversions to DotOperand.
2023-07-28 10:29:42 -07:00
Bin Fan
2689f4a3b0 [TOOLS][AOT] some issues in equal_to_1 hint (#1998)
- Change test_aot.py to actually use equal_to_1 hint

- In the client function, equal_to_1 parameters are not specialized,
because AOT clients may not know the details of Triton argument
specialization, they still want to use the same parameter list as they
write the Triton kernel. The generated kernels has specialized argument
list, the generated dispatcher code will make sure the correct arguments
from the original full argument list are passed.

- Fixed a bug in _match_suffix in link.py. Previously it assumes each
parameter has a suffix of either ‘d’ or ‘c’, but in fact sometimes a
parameter doesn’t have a suffix, like 0d1d2d34c56c78c
2023-07-27 16:07:49 -07:00
Shantanu
4f1b2ea8d7 [FRONTEND] fix error with -> None return annotation (#1987)
None is not a type, so you get:
```
    self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
E   TypeError: argument of type 'NoneType' is not iterable
```

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-25 18:49:45 -07:00
Phil Tillet
db695c093f [TUTORIALS] fix format 2023-07-25 18:16:39 -07:00
janEbert
62a8afa403 [TUTORIALS] Support FlashAttention-2 reference (#1984)
Uses FlashAttention-2 if available, otherwise acts as before (if
FlashAttention-1 is available, that is used, otherwise the
FlashAttention reference benchmark is not run).

I decided to keep the same name for the imported function, but feel free
to make me change that.
2023-07-24 13:54:01 -07:00
Thomas
e6216047b8 [BACKEND] Upgrade the max PTX version allowed to 8.2 (#1982) 2023-07-23 19:56:01 -07:00
Izzy Putterman
de6f053c0f [TRITON][OPS] add Flash Attention v2 to Ops (#1970)
I also dropped the do_scaled as it is no longer needed (no scaling done
to the do in v2).

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-23 14:07:15 -07:00
youkaichao
c9ab44888e [FRONTEND] improve the process of finding libcuda.so and the error message (#1981)
`triton` uses `whereis` command to find `libcuda.so`, which is intended
to find binary, source, and manual page files. When `libcuda.so` is not
properly setup, the `whereis` command ends up with
`/usr/share/man/man7/libcuda.7`, which is not the place to look for.

This PR uses `ldconfig -p` to reliably find `libcuda.so`.

In my case, I find that I have a `libcuda.so.1` file, but it is not
linked to `libcuda.so`. Therefore `ld` cannot find the library to link.
After creating the linking, I was able to run `triton` successfully.

Therefore, I improve the code by first invoking `ldconfig -p`, and
checking `libcuda.so` strings first. These might be possible library to
link against. If the literal `libcuda.so` file is not found, then I
raise an error and tells the user that a possible fix is to create a
symlink file.
2023-07-23 10:31:07 -07:00
Philippe Tillet
66eda76e45 [FRONTEND][BACKEND] no longer serialize float8e4b15 (#1979)
We had a number of complains that the previous packed format was
error-prone and may not yet be worth the 2 SASS instruction saved per 4
conversions
2023-07-21 22:44:55 -07:00
Philippe Tillet
1db3bdc52e [BACKEND] avoid code duplication for fully warp-synchronous reductions (#1978) 2023-07-21 16:06:00 -07:00
Philippe Tillet
07c346b948 [OPTIMIZER] Falls back to using RemSI in the pipeline pass for now (#1972)
This is strange. Using RemUI should be strictly better, but it can cause
up to 20% performance regression in some cases. I am reverting to RemSI
pending investigation
2023-07-19 22:06:51 -07:00
Phil Tillet
cfce82d715 [TUTORIALS] Flash Attention tutorial now properly tries fwd, bwd, causal, non-causal 2023-07-19 21:56:29 -07:00
Philippe Tillet
3452615d79 [BUILD] Reverted ptxas change and fixed bug in cache key computation (#1971) 2023-07-19 20:58:24 -07:00
Philippe Tillet
28a61484bc [FRONTEND] more leniency when converting to/from fp8e4b15 (#1969) 2023-07-19 18:26:21 -07:00
David Berard
9c422e260b [OPTIMIZER] AxisInfoVisitor for LoadOp constancy calculation (#1968)
If you call `result = load(x, mask)` where `x` and `mask` have some
constancy properties, then you can infer some constancy properties for
`result`.
2023-07-19 17:40:46 -07:00
Philippe Tillet
68124676c9 [FRONTEND][BACKEND] Fix trans for float8e4b15 (#1964)
float8e4b15 is a packed type; it is incompatible with most of our layout conversions. For now, we just convert to float16.
2023-07-19 11:30:39 -07:00
nccx
15ab48d407 [TESTS] remove unnecessary lit command from combine.mlir (#1961)
The only difference between the two RUNs is `FileCheck`, which should be
needed.
2023-07-19 11:14:56 -07:00
nccx
cd61f99fb5 [DOCS] remove empty README (#1963) 2023-07-19 10:51:38 -07:00
nccx
4c3eabff67 [DOCS] update CONTRIBUTING.md (#1962) 2023-07-19 10:51:10 -07:00
kshama-msft
10d7078fab [DOCS] update 07-18-2023.md (#1959)
added minutes and recording link.
2023-07-19 10:24:29 -07:00
Philippe Tillet
c46a842b6f [TUTORIAL] more attention cleanup (#1958) 2023-07-18 12:36:15 -07:00
kshama-msft
fab5068503 [DOCS] update 07-18-2023.md (#1953) 2023-07-18 09:50:46 -07:00
Philippe Tillet
9e3e10c5ed [OPTIMIZER][TUTORIAL] flash attention v2 (#1952) 2023-07-17 12:23:02 -07:00
David Berard
7202c6cff0 [FRONTEND] expose tl.max_constancy hint (#1951)
Similar to `tl.multiple_of` and `tl.max_contiguous`, `tl.max_constancy`
will expose a compiler hint indicating that all the values are equal in
a block of a certain size.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-17 18:30:25 +00:00
Philippe Tillet
f6c4e8de76 [OPTIMIZER] RemSI -> RemUI in pipeline pass (#1950) 2023-07-17 00:54:46 -07:00
Phil Tillet
c7757fae71 [GITHUB] tweak CODEOWNERS 2023-07-17 00:41:11 -07:00
Christian Sigg
80c6e39716 [BACKEND] Fix enable_debug implementation. (#1876)
Print before every pass and after failures if MLIR_ENABLE_DUMP is set.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-07-16 21:50:30 -04:00
Keren Zhou
bcfd990a88 [TESTS] Fix autopep8 error (#1948) 2023-07-16 16:55:12 -07:00
Mehdi Amini
51fc42a568 [FRONTEND] fix AST IR generation for while loop nested inside other SCF (#1947)
The process of visiting twice the body of the while didn't restore
properly the insertion point, and was leaking the dummy block.
2023-07-15 10:17:29 -07:00
Philippe Tillet
8207eabd7b [FRONTEND][OPTIMIZER] small perf improvements (#1945) 2023-07-14 15:11:36 -07:00
Alex Collins
80163a9c1e [FRONTEND] Add support for default args in kernel wrappers (#1943)
Fixes the case where setting default values for arguments in a kernel
function signature results in a generated kernel wrapper function
without these default values.

For example:
```
@triton.jit
def kernel(x, y, z=3):
    ...

...
kernel[grid](x,y)
```

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-14 21:32:47 +00:00
Kojo Acquah
4042bd57a0 Fix assignment of BroadcastOp in conditional error (#1944)
Fix for current build error on main by assigning before if conditional: 
```shell
      /usr/local/home/kooljblack/Code/triton/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp:146:23: error: using the result of an assignment as a condition without parentheses [-Werror,-Wparentheses]
            if (broadcastOp = operand.getDefiningOp<triton::BroadcastOp>()) {
                ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
      /usr/local/home/kooljblack/Code/triton/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp:146:23: note: place parentheses around the assignment to silence this warning
            if (broadcastOp = operand.getDefiningOp<triton::BroadcastOp>()) {
                            ^
                (                                                         )
      /usr/local/home/kooljblack/Code/triton/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp:146:23: note: use '==' to turn this assignment into an equality comparison
            if (broadcastOp = operand.getDefiningOp<triton::BroadcastOp>()) {
                            ^
                            ==
      1 error generated.
```
2023-07-14 16:52:35 +00:00
Yan Chunwei
d0c35b3b7d Hot fix for AOT (#1939)
This PR addresses the following issues encountered when using AOT
kernels in our project:

1. When different signatures are set for the same Triton kernel, it can
result in C functions with the same name. This is problematic because C
does not support function overloading.

2. Currently, the AOT kernel always compiles with `num_warps=1`, as
indicated
[here](https://github.com/openai/triton/pull/1939/files#diff-293af646f671d3a895c453a8b175754e9d4ec4fc855bb939ffa4d6e9e91b07c6L83).
However, the generated function includes a `numWarps` argument, which
can cause errors when the specified value does not match.

To resolve these issues, this PR does the following modifications:

1. Adds an 8-char hash key as a suffix to the generated function's
signature. This ensures that different function names are generated in C
when the argument dtype or constexpr value or even hint differs since we
hope these kernels could be used in one C/C++ library.

2. Introduces a new flag called `num-warps` that allows manual
specification of the `numWarps` value for AOT. This change hardcodes the
specified value into the generated kernel.c and removes the `numWarps`
argument from the generated function.
2023-07-14 09:16:43 +08:00
Keren Zhou
9cd970ebaa [OPTIMIZER] Skip checking ops without any operand (#1940) 2023-07-13 10:11:02 -07:00
Thomas
77ceb4ddb5 [BACKEND] Avoid circular dependency between Triton and TritonGPU (#1938) 2023-07-13 08:40:05 +02:00
Keren Zhou
fd89aa1d2b [ANALYSIS] Fix allocation algorithm (#1929) 2023-07-12 12:20:36 -07:00
Keren Zhou
571c92f2a8 [CI] Fix CI kernel compare (#1931)
With this PR, we find the latest merged PR that successfully passed
"Integration Tests".
2023-07-12 10:06:34 -07:00
Izzy Putterman
c615ce944c [FRONTEND] use local bindings in triton.cc (#1932)
Another follow up with the relative imports this time dealing with the
bindings.
2023-07-12 02:19:48 +00:00
Keren Zhou
4795820014 [TESTS] Fix unmatched test names (#1933) 2023-07-11 19:08:28 -07:00
Stonepia
d50e32fab7 [FRONTEND] fix the hard code builder.arch that could block third_party tests (#1859)
For CUDA devices, the `builder.arch` is an int.
For third_party devices, this line would be a TypeError. For example:

```
TypeError: '<' not supported between instances of 'dict' and 'int'
```

Co-authored-by: Wang Weihan <eikan.wang@intel.com>
2023-07-11 19:06:35 -07:00
Philippe Tillet
bf5acf46e2 [OPS] improved pointer arithmetic in attention (#1926)
this provides an additional 3-4% speed-up in non-causal attention, which
now tops at 155TFLOPS
2023-07-11 12:04:00 -07:00
Daniyal khan
b70d07aafe [BUILD][DOCS] updated setup.py and documentation (#1930) 2023-07-11 11:46:28 -07:00
Phil Tillet
041f1144e8 [DOCS] fixed flash_attn causal argument in tutorial 2023-07-11 09:28:20 -07:00
Goran Flegar
bbc1ad16d8 [BACKEND] Vectorize s8 to bf16 casts (#1879)
The code generated by LLVM ends up using 15 SASS instructions, while the
inline PTX added here only uses 8. It might be possible to reduce this
down to 6 if NVIDIA optimizes ptxas to use the byte selector in I2F for
all bytes (right now, we still have some bit manipulation code generated
for 2 out of 4 bytes).

This change improves the performance of mixed precision matmul kernel
with M=N=K=4096, where one operand is casted from s8 to bf16 from 140
TFlop/s to 165 TFlop/s on A100-40GB.

Also refactors the ElementwiseOpConversionBase template to support
vectorized operations, reducing the boilerplate needed for existing, and
this new vectorized cast; and extends the casting test to process more
than one element (so vectorized casts can be properly tested).
2023-07-11 09:24:05 -07:00
Philippe Tillet
8fe5524c75 [BACKEND] no longer uses shared mem or barriers for single-warp reductions (#1915)
0-bytes shared mem buffers don't materialize empty allocation buffers;
this could lead to unnecessary barriers.

note: reduceop code has become quite messy and will require some cleanup
2023-07-11 00:23:26 -07:00