Commit Graph

874 Commits

Alexander Efimov
af05f01218 [Tests] Fix some tests in test_core_amd.py (#288)
This PR:
- enables test_dot_mfma_vector_load for the fast path in the MFMA dot op pipeline
- fixes kernel execution for MFMA-enabled GPUs
- disables MFMA layout conversion tests on architectures that cannot run them
2023-08-08 20:12:32 +02:00
allatit23
6dee55c912 [HOPPER][WS] fix TMA store hang in ws mode (#2056) 2023-08-08 19:53:52 +08:00
danny.jang
bb47f894f7 [FRONTEND] improve error message for shape mismatch (#2031)
Improve error messaging for block shape and value shape mismatches.
2023-08-08 01:13:16 -07:00
Philippe Tillet
658747feff [FRONTEND] remove ptxas from git (#2055) 2023-08-08 01:11:32 -07:00
Bin Fan
a76ecd74e7 add num_stages parameter to aot compile.py (#2000)
This allows the AOT client to tune the number of stages for the
generated kernel. The default is set to 3 to match the Triton
compiler.
2023-08-08 06:04:57 +00:00
Qingyi Liu
341f5b61be [BACKEND] Add BarrierOp after AllocMBarrierOp when numCTAs == 1 (#2040)
Make sure that other threads within the CTA do not operate on the mbarrier
until it is initialized by thread 0.

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-08-07 20:11:00 -07:00
danny.jang
6a1ac65043 [FRONTEND] improve error message for type mismatch (#2038) 2023-08-07 19:34:09 -07:00
Keren Zhou
30a331e628 [FRONTEND] Support jit functions without arguments (#2043)
Issue https://github.com/openai/triton/issues/1973

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-08-07 19:05:56 -07:00
goostavz
f1512bded1 Initial code merge of Hopper support (#2036)
The initial code merge of Nvidia Hopper feature support. Please be
aware that the code merge is not finished yet and troubleshooting
is still ongoing. The new hardware features (GMMA, TMA, STMATRIX, etc.)
and automatic warp specialization are experimental for now and turned
off by default. We recommend trying them out when version 3.0 is
released.

The work is contributed by:
ben-zhang-609, bealwang, donproc, qliu93, jsh20, allatit23, LyricZhao,
ivanyinwz, goostavz & yangjunpro
from Nvidia, in cooperation with:
ptillet, Jokeren, ThomasRaoux & zahimoud
from OpenAI.

Co-authored-by: Goostav Zhu <gzhu@nvidia.com>
2023-08-07 09:53:04 +08:00
Alexander Efimov
86f8b64ae0 [Dot] [MFMA] [FMA] Update Dot implementation to support upstream tests (#260)
* [Dot] [MFMA] Support FP16 output of MFMA dot

This PR adds cast of output tensor to requested data type.

* add tests

* fix test for FMA implementation

* loose fp16xfp16->fp16 tolerance

* enable FMA fallback for unsupported sizes of dot operation

* rework granularity check

* add constant modifier to granularity
2023-08-03 13:47:18 -05:00
Alexander Efimov
a318bc5777 Disable pipelining (#276)
This PR sets the default number of pipeline stages to 1 for AMD and
adds an explicit num_stages to the test_gemm test.
2023-08-01 14:54:17 -05:00
Shucai Xiao
82db24cde7 add gfx940/941 to archs that support mfma (#277) 2023-08-01 10:57:16 -05:00
Yan Chunwei
89b0b79d75 [FRONTEND] fix the silent return issue in AOT launcher (#2013)
In the current link.py, the produced launcher code looks like this:

```c
CUresult matmul_fp16xfp16_16x16x16(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t stride_cm, int32_t stride_am, int32_t stride_bk){
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0) && (stride_cm % 16 == 0))
    return matmul_fp16xfp16_16x16x16_688cc413_0d1d2d3d45d(stream, gX, gY, gZ, C, A, B, stride_cm, stride_am, stride_bk);
  // ...
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0))
    return matmul_fp16xfp16_16x16x16_7c0255bf_0d1d2d345(stream, gX, gY, gZ, C, A, B, stride_cm, stride_am, stride_bk);
}
```
Note that when the input does not match any of the if branches, the
function does nothing; the compiler will make it return 0 by default,
which equals `CUDA_SUCCESS` and does not match the expectation.

This PR adds a `return CUDA_ERROR_INVALID_VALUE;` to the tail of the
launchers, producing code like:

```c++
CUresult matmul_fp16xfp16_16x16x16(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t stride_cm, int32_t stride_cn, int32_t stride_am, int32_t stride_ak, int32_t stride_bk, int32_t stride_bn){
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0) && (stride_cm == 1) && (stride_cn == 1) && (stride_am == 1) && (stride_ak == 1) && (stride_bk % 16 == 0) && (stride_bn == 1))
    return matmul_fp16xfp16_16x16x16_1f18a6da_0d1d2d3c4c5c6c7d8c(stream, gX, gY, gZ, C, A, B, stride_bk);

  return CUDA_ERROR_INVALID_VALUE;
}
```

This requires users to check the result in their application, which I
think matches the original AOT design.
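The new fallback behavior can be sketched in plain Python. This is a minimal sketch: the guard and the kernel call are stand-ins, and only the two error-code values follow the CUDA driver API (`CUDA_SUCCESS` is 0, `CUDA_ERROR_INVALID_VALUE` is 1).

```python
CUDA_SUCCESS = 0
CUDA_ERROR_INVALID_VALUE = 1

def dispatch(C, A, B):
    """Mimics the generated launcher: try each specialization guard,
    then fall through to an explicit error instead of an implicit 0."""
    if C % 16 == 0 and A % 16 == 0 and B % 16 == 0:
        return CUDA_SUCCESS  # stand-in for calling the specialized kernel
    return CUDA_ERROR_INVALID_VALUE  # the fix: no silent success
```

A caller that checks the returned code will now see the mismatch instead of a silent no-op.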
2023-07-31 09:59:28 -07:00
Bin Fan
2689f4a3b0 [TOOLS][AOT] some issues in equal_to_1 hint (#1998)
- Change test_aot.py to actually use equal_to_1 hint

- In the client function, equal_to_1 parameters are not specialized:
AOT clients may not know the details of Triton argument
specialization, and they still want to use the same parameter list as
when they wrote the Triton kernel. The generated kernels have a
specialized argument list, and the generated dispatcher code makes sure
the correct arguments from the original full argument list are passed.

- Fixed a bug in _match_suffix in link.py. Previously it assumed each
parameter has a suffix of either 'd' or 'c', but in fact a parameter
sometimes has no suffix, as in 0d1d2d34c56c78c
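The tolerant parsing can be sketched as follows. The helper name is hypothetical, and single-digit parameter indices are assumed for illustration; the real logic lives in link.py.

```python
import re

def parse_suffixes(mangled: str):
    """Split a mangled signature like '0d1d2d34c56c78c' into
    (param_index, suffix) pairs; the 'd'/'c' suffix may be absent.
    Assumes single-digit indices for illustration."""
    return [(int(idx), suf or None)
            for idx, suf in re.findall(r"(\d)([dc]?)", mangled)]
```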
2023-07-27 16:07:49 -07:00
Alexander Efimov
0073bb98f4 [MFMA] Introduce dot operand loading fast path (#269)
* [MFMA] Introduce dot operand loading fast path

This PR introduces fast path for code generation of MFMA dot operand
loading from LDS.

The fast path is used when the operand is not swizzled and is not a
slice of some bigger LDS object (i.e. it is not a slice of a tensor).
This is the case for the current FA and GEMM kernels compiled with
num_stages=1, i.e. with software pipelining disabled.

* cleanup swizzle info
2023-07-27 13:46:50 -05:00
Shantanu
4f1b2ea8d7 [FRONTEND] fix error with -> None return annotation (#1987)
None is not a type, so you get:
```
    self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
E   TypeError: argument of type 'NoneType' is not iterable
```
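A minimal reproduction and fix sketch (the helper name is illustrative, not the actual code): `-> None` stores the value `None`, not a string, in `__annotations__`, so the membership test must coerce to `str` first.

```python
def is_constexpr(annotation) -> bool:
    # `'constexpr' in None` raises TypeError because None is not
    # iterable; coercing to str makes the check safe for any annotation.
    return "constexpr" in str(annotation)
```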

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-25 18:49:45 -07:00
Shucai Xiao
356595f3dc revert PR#245 and #241 for gemm performance regression (#265)
* Revert "[MFMA][Dot] Swizzle mfma dot operands  (#245)"

This reverts commit 4d0deef45f.

* revert the swizzle and slice PRs for gemm performance regression

* fix build error after reverting the PR#245 (commit 4d0deef45f) Swizzle mfma dot operands
2023-07-24 14:09:46 -05:00
Izzy Putterman
de6f053c0f [TRITON][OPS] add Flash Attention v2 to Ops (#1970)
I also dropped `do_scaled` as it is no longer needed (no scaling is done
to `do` in v2).

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-23 14:07:15 -07:00
youkaichao
c9ab44888e [FRONTEND] improve the process of finding libcuda.so and the error message (#1981)
`triton` uses the `whereis` command to find `libcuda.so`, but `whereis`
is intended to find binary, source, and manual page files. When
`libcuda.so` is not properly set up, the `whereis` command ends up with
`/usr/share/man/man7/libcuda.7`, which is not the place to look.

This PR uses `ldconfig -p` to reliably find `libcuda.so`.

In my case, I had a `libcuda.so.1` file, but it was not linked to
`libcuda.so`, so `ld` could not find the library to link against.
After creating the symlink, I was able to run `triton` successfully.

Therefore, I improve the code by first invoking `ldconfig -p` and
checking for `libcuda.so` strings; these are possible libraries to
link against. If no literal `libcuda.so` file is found, an error is
raised telling the user that a possible fix is to create a
symlink.
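The candidate scan can be sketched like this. The function name is hypothetical and the sample output in the test is illustrative, not real system output; the actual implementation is in Triton's driver code.

```python
def libcuda_candidates(ldconfig_output: str):
    """Pick libcuda.so entries out of `ldconfig -p` output, which lists
    lines of the form 'libname (flags) => /path/to/lib'."""
    return [
        line.split("=>")[-1].strip()
        for line in ldconfig_output.splitlines()
        if "libcuda.so" in line and "=>" in line
    ]
```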
2023-07-23 10:31:07 -07:00
Philippe Tillet
3452615d79 [BUILD] Reverted ptxas change and fixed bug in cache key computation (#1971) 2023-07-19 20:58:24 -07:00
Philippe Tillet
28a61484bc [FRONTEND] more leniency when converting to/from fp8e4b15 (#1969) 2023-07-19 18:26:21 -07:00
Shucai Xiao
1c86e3238a remove adding multiple architectures to isa head and add gemm tuning scripts (#261)
* Remove adding multiple architectures to isa head

* Add mask for gpu memory load in scripts for tuning gemm 'script/amd/gemm/matmul.py'

* Move the scripts to a better place 'scripts/amd/gemm/'
2023-07-18 14:21:16 -05:00
David Berard
7202c6cff0 [FRONTEND] expose tl.max_constancy hint (#1951)
Similar to `tl.multiple_of` and `tl.max_contiguous`, `tl.max_constancy`
will expose a compiler hint indicating that all the values are equal in
a block of a certain size.
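The property the hint asserts can be checked in plain Python (a hedged sketch; the function name is illustrative and the input length is assumed to be a multiple of the constancy):

```python
def constancy_holds(values, constancy):
    """True if every aligned block of `constancy` elements contains a
    single repeated value, i.e. what tl.max_constancy promises."""
    return all(
        len(set(values[i:i + constancy])) == 1
        for i in range(0, len(values), constancy)
    )
```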

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-17 18:30:25 +00:00
Mehdi Amini
51fc42a568 [FRONTEND] fix AST IR generation for while loop nested inside other SCF (#1947)
Visiting the body of the while loop twice didn't properly restore the
insertion point and leaked the dummy block.
2023-07-15 10:17:29 -07:00
Philippe Tillet
8207eabd7b [FRONTEND][OPTIMIZER] small perf improvements (#1945) 2023-07-14 15:11:36 -07:00
Alex Collins
80163a9c1e [FRONTEND] Add support for default args in kernel wrappers (#1943)
Fixes the case where setting default values for arguments in a kernel
function signature results in a generated kernel wrapper function
without these default values.

For example:
```
@triton.jit
def kernel(x, y, z=3):
    ...

...
kernel[grid](x,y)
```
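What the fix preserves can be sketched with `inspect` (the helper name is hypothetical; the real wrapper generation is more involved):

```python
import inspect

def wrapper_params(fn):
    """Rebuild the parameter list of fn, defaults included, as the
    generated wrapper function's signature should look."""
    parts = []
    for p in inspect.signature(fn).parameters.values():
        if p.default is inspect.Parameter.empty:
            parts.append(p.name)
        else:
            parts.append(f"{p.name}={p.default!r}")
    return ", ".join(parts)

def kernel(x, y, z=3):
    ...
```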

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-14 21:32:47 +00:00
Yan Chunwei
d0c35b3b7d Hot fix for AOT (#1939)
This PR addresses the following issues encountered when using AOT
kernels in our project:

1. When different signatures are set for the same Triton kernel, it can
result in C functions with the same name. This is problematic because C
does not support function overloading.

2. Currently, the AOT kernel always compiles with `num_warps=1`, as
indicated
[here](https://github.com/openai/triton/pull/1939/files#diff-293af646f671d3a895c453a8b175754e9d4ec4fc855bb939ffa4d6e9e91b07c6L83).
However, the generated function includes a `numWarps` argument, which
can cause errors when the specified value does not match.

To resolve these issues, this PR does the following modifications:

1. Adds an 8-char hash key as a suffix to the generated function's
name. This ensures that distinct function names are generated in C
when the argument dtypes, constexpr values, or even hints differ, since
these kernels may all be used in one C/C++ library.

2. Introduces a new flag called `num-warps` that allows manual
specification of the `numWarps` value for AOT. This change hardcodes the
specified value into the generated kernel.c and removes the `numWarps`
argument from the generated function.
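The naming scheme in point 1 can be sketched as follows. This is illustrative: the function name and the choice of hash are assumptions, only the 8-char-suffix idea comes from the commit.

```python
import hashlib

def mangle(base: str, signature: str) -> str:
    """Append an 8-char digest of the signature so distinct kernel
    instantiations get distinct C symbol names."""
    return f"{base}_{hashlib.sha256(signature.encode()).hexdigest()[:8]}"
```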
2023-07-14 09:16:43 +08:00
oplavsic
d6e51fd221 [FA OPTIMIZATION] Keep results of FA dot operations in registers (#247)
* [WIP][FA OPTIMIZATION] Optimize chain dot

This commit optimizes the chain dot operation by keeping the
results of the first dot operation in registers.

* [FA OPTIMIZATION] Enable lowering pipeline for keeping result of chain dot in registers

* Move operand swapping in ttgir -> llir lowering phase

* Refactor emitMfmaOffsetForCTA function to be more readable

* Fix accidental change in 06-fused-attention.py

* Address review comments

* Fix rebase errors
2023-07-12 15:25:55 -05:00
jayfurmanek
21c757564a Merge branch 'triton-mlir' into ifu230620 2023-07-12 11:29:23 -05:00
Xinya Zhang
788b283a9a Make configurable warp size compatible with PyTorch's inductor (#256)
* CMakeLists: Fix the typo

* Remove warp size argument from python function.

The additional warp_size argument breaks compatibility with ToT, but
after examination this argument is not necessary, because the value can be retrieved from the hipStream_t.

Fixes #255

* Replace hipGetStreamDevice with hipGetStreamDeviceId.

hipGetStreamDevice is not available in ROCM < 5.6.
Fixes #255 for older ROCM releases.
2023-07-12 11:23:00 -05:00
Stonepia
d50e32fab7 [FRONTEND] fix the hard code builder.arch that could block third_party tests (#1859)
For CUDA devices, `builder.arch` is an int.
For third-party devices, this line would raise a TypeError. For example:

```
TypeError: '<' not supported between instances of 'dict' and 'int'
```

Co-authored-by: Wang Weihan <eikan.wang@intel.com>
2023-07-11 19:06:35 -07:00
Philippe Tillet
bf5acf46e2 [OPS] improved pointer arithmetic in attention (#1926)
this provides an additional 3-4% speed-up in non-causal attention, which
now tops out at 155 TFLOPS
2023-07-11 12:04:00 -07:00
Philippe Tillet
7e3ebbc4c8 [TESTING] now using cuda graphs for perf regression tests (#1925) 2023-07-10 22:49:25 -07:00
danny.jang
4a20d5010b [FRONTEND] Fix an inspection warning (#1914)
The inspection warning "Expected type 'SupportsIndex', got 'constexpr'
instead" is no longer reported.
2023-07-10 21:30:59 -07:00
Izzy Putterman
d39d78fa08 [OPS] Add more perf-tests, new features to FA (#1849)
Adding new tests across the board for float32, bfloat16, and
non-power-of-2 shapes (to test masks), plus tests of sequence
parallelism for atomics. This also adds the sequence parallel features from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py.
I am not sure about the best way to grab the baseline benchmarking
numbers. I have access to V100s and A100s, but I saw the tests mention
"A100 in the CI server is slow-ish for some reason. On some other
servers, we are getting about 90% peak for 8kx8x8k float16". The
current plan is to run CI here and use those numbers for the baseline,
then match against my GPUs as a sanity check.

---------

Co-authored-by: Phil Tillet <phil@openai.com>
2023-07-10 18:52:59 -07:00
peterbell10
e3d9478d31 [OPTIMIZER] Add pass to move broadcasts after elementwise operations (#1811)
This adds a pass that tries to reduce the shape of tensor arguments to
element-wise operations by moving splat and broadcast operations later
in the graph. For example, say we have:

```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset  + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp2 = 0.017453292519943295
    tmp3 = tmp1 * tmp2
    tmp4 = tl.sin(tmp3)
    tl.store(out_ptr0 + (x0), tmp4, None)
```

Today this results in duplicate `sin` calls:
```
    %27 = llvm.fmul %26, %3  : f32
    %28 = llvm.call @__nv_sinf(%27) : (f32) -> f32
    %29 = llvm.call @__nv_sinf(%27) : (f32) -> f32
```

The duplicate `llvm.fmul` calls are eliminated via CSE, but `llvm.call`
doesn't get CSE'd because it might be impure.

After this change, the sin is done on a scalar value in the triton IR
and splatted at the very end, so no duplicate calculation happens within
a thread.
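The effect can be checked numerically in plain Python (a minimal sketch; XBLOCK and the loaded value are arbitrary, the scale factor is the constant from the kernel above):

```python
import math

XBLOCK = 8
tmp0 = 0.5                     # the single scalar loaded from in_ptr0
scale = 0.017453292519943295   # degrees-to-radians constant from the kernel

# Before the pass: broadcast first, so sin is evaluated per element.
before = [math.sin(v * scale) for v in [tmp0] * XBLOCK]
# After the pass: sin is evaluated once on the scalar, then splatted.
after = [math.sin(tmp0 * scale)] * XBLOCK

assert before == after  # same result, one sin call instead of XBLOCK
```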

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-10 11:44:38 -07:00
Philippe Tillet
5a722b5f74 [OPS][TESTS] Added float8 support in triton.ops.matmul (#1918)
this also adds rather extensive testing for mixed precision mode,
including `float8e4b15 x float8e5` and `float8e5 x float16`
2023-07-10 09:31:12 -07:00
Natalia Gimelshein
778ed64a66 [BACKEND] make sure we always bind to primary context in loadBinary (#1912) 2023-07-07 14:28:03 -07:00
Bert Maher
38d767ea93 [FRONTEND] fix memory leak caused by retaining args to autotuned kernel (#1911) 2023-07-07 20:58:29 +00:00
Keren Zhou
cc5a7ed52f [FRONTEND][BACKEND] Materialize line info for triton kernels (#1902)
`export TRITON_DISABLE_LINE_INFO=1` to disable the feature.
2023-07-07 16:03:44 -04:00
Philippe Tillet
6d1285e1ae [FRONTEND][BACKEND] improved fp8 specs (#1906)
This un-reverts commit
d4c941177e.
2023-07-06 13:03:53 -07:00
Thomas
ae0ee5248f [FRONTEND] Add cumprod scan op (#1894)
Add and test cumprod. This also allows testing a case of accumulation
where 0 is not the identity element.
Also add documentation for the scan functions.
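A 1-D reference for the new op can be written with the stdlib (a sketch of the semantics, not the kernel implementation); note how a 0 in the input zeroes every later prefix, which is exactly the non-identity case mentioned above:

```python
from itertools import accumulate
from operator import mul

def cumprod(xs):
    """Reference cumulative product along a 1-D sequence."""
    return list(accumulate(xs, mul))
```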
2023-07-05 10:09:06 -07:00
Philippe Tillet
d57dcd9994 [FRONTEND] pattern-match ExpandDims + Mul + Reduce into DotOp (#1889)
transforms e.g.:
```
x1 = tl.expand_dims(x0, axis=2)
y1 = tl.expand_dims(y0, axis=0)
z = tl.sum(x1 * y1, axis=1)
```

into
```
z = tl.dot(x0, y0) 
```

uses allowTF32 = True by default.
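The equivalence the pattern relies on can be verified in plain Python (a sketch with list-of-lists tensors; both helper names are illustrative):

```python
def matmul(x0, y0):
    """Plain dot: z[i][j] = sum_k x0[i][k] * y0[k][j]."""
    K, N = len(y0), len(y0[0])
    return [[sum(x0[i][k] * y0[k][j] for k in range(K)) for j in range(N)]
            for i in range(len(x0))]

def expand_mul_reduce(x0, y0):
    """The ExpandDims + Mul + Reduce pattern, written out with an
    explicit (M, K, N) broadcasted product, then summed over axis 1."""
    M, K, N = len(x0), len(y0), len(y0[0])
    prod = [[[x0[i][k] * y0[k][j] for j in range(N)]  # x1 * y1 broadcast
             for k in range(K)] for i in range(M)]
    return [[sum(prod[i][k][j] for k in range(K))     # tl.sum(..., axis=1)
             for j in range(N)] for i in range(M)]
```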
2023-07-05 00:31:21 -07:00
Yongjik Kim
387cdc8fe9 [FRONTEND] improve error message. (#1887) 2023-07-05 00:19:28 -07:00
Christopher Hesse
cc93356c4b [DOCS] update print docs (#1884) 2023-07-05 00:19:12 -07:00
Keren Zhou
aa55d3b5bf [FRONTEND] Remove unnecessary replace in while op's after block (#1886)
We've already updated the mapping between names and tensors before
visiting each compound statement in the while op. As a result, any
overwritten name holds an up-to-date value in the while loop, and
any unchanged live-in names hold the original tensors.
2023-07-05 00:16:38 -07:00
Ying Hang Eng
55eb32dff9 Fix tl.device_assert compilation error (#1875)
As mentioned in #1769, we set the file name and function name to
'unknown' and the lineno to 0 if the frame is None.
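The fallback can be sketched as follows (the helper name is hypothetical; real frames expose these attributes via `f_code` and `f_lineno`):

```python
def frame_location(frame):
    """Return (filename, function name, lineno), with the fallbacks
    applied when no Python frame is available."""
    if frame is None:
        return "unknown", "unknown", 0
    return frame.f_code.co_filename, frame.f_code.co_name, frame.f_lineno
```

With a real frame, e.g. `frame_location(sys._getframe())`, the actual file, function, and line are returned.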
2023-06-30 22:19:25 +00:00
Jason Furmanek
12005a82f2 Initial commit to resolve merge conflicts 2023-06-30 19:53:53 +00:00
Oleg Shyshkov
66ed53d19d [FRONTEND] Support mixed-precision inputs in triton.ops.matmul. (#1754)
Supports only combinations of float32 with float16 or bfloat16 for now.
This shouldn't change anything when the input types match.

That's a follow-up to the comment in my other PR:
https://github.com/openai/triton/pull/1746#issuecomment-1579630016.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-06-30 09:22:27 -07:00
Philippe Tillet
f77015967d Revert "[FRONTEND][BACKEND] improved fp8 specs (#1841)" (#1865)
This reverts commit d4c941177e.
2023-06-29 21:07:01 -04:00