Commit Graph

1160 Commits

Author SHA1 Message Date
Wen Chen
ffc230ebfe [ROCM] Fixed implementation of fp32 to bf16 conversion on ROCm. 2023-09-06 18:10:54 -05:00
Wen Chen
2d3e38e182 [ROCM] Added ROCm support for int8 to bfloat16 conversion. 2023-09-06 18:10:54 -05:00
Wen Chen
59a40d3f72 [ROCM] Added ROCm support for the conversions of following data types:
[float8e4m3, float8e4m3b15, float8e5m2] <-> [float16, bfloat16]
2023-09-06 18:10:54 -05:00
Jason Furmanek
320b1029da Temporarily disable F8 tests on ROCm 2023-09-01 04:02:14 +00:00
Jason Furmanek
df5c263a19 Fix merge conflicts 2023-09-01 04:01:32 +00:00
Jason Furmanek
3eaeb89d18 Merge commit '5df904233c11a65bd131ead7268f84cca7804275' into ifu230810-2
Conflicts:
	include/triton/Dialect/Triton/Transforms/Passes.h
	include/triton/Dialect/TritonGPU/IR/Dialect.h
	include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
	lib/Analysis/Allocation.cpp
	lib/Analysis/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/src/triton.cc
	python/triton/compiler/compiler.py
	python/triton/ops/flash_attention.py
	python/triton/runtime/autotuner.py
	python/triton/runtime/jit.py
	python/triton/tools/aot.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
	test/Target/tritongpu_to_llvmir.mlir
	test/Target/tritongpu_to_llvmir_noinline.mlir
2023-09-01 03:25:33 +00:00
Vinayak Gokhale
9cdf3a58c3 Enable split kernel in bwd pass (#303)
* Add fwd and bwd v2

Changes are largely from upstream.

* Split bwd kernel into dq and dk+dv

Only adds the split kernels. They are not enabled yet.

* Pull scalar multiplies out of the loop

* Enable split kernel for bwd pass

* Put back P_SEQ=128 in fwd test

Not used for bwd test

* Address review comments

* Address comments

Conditionally set causal/split kernel to False for bwd.

* Add block pointer semantics to bwd pass

This significantly increases perf for bwd, similar to fwd.
2023-08-29 13:51:29 -05:00
Lixun Zhang
b834f42ae4 [autotuner] Add an option to print best_config for each key 2023-08-28 14:45:54 -05:00
jayfurmanek
ff7e707f87 Enable usage of block pointer semantics for AMD gpus (#301)
* Enable usage of block pointer semantics for AMD gpus

This commit enables usage of block pointer semantics by enabling the
rewrite_tensor_pointer_pass, which rewrites block pointer loads/stores
into legacy loads/stores.

* Update FA fwd in tutorial to use the block pointers

* Use compute capability 90 for AMD GPUs in python/triton/compiler/compiler.py

Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>

---------

Co-authored-by: Ognjen Plavsic <ognjen.plavsic@dxc.com>
Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
Co-authored-by: Aleksandr Efimov <130555951+alefimov-amd@users.noreply.github.com>
Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>
2023-08-24 13:05:12 -05:00
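
For context, a minimal sketch of the block-pointer form this commit enables on AMD GPUs; the kernel name and shapes are illustrative and not taken from the repository. The rewrite_tensor_pointer_pass lowers these `tl.make_block_ptr`/`tl.load`/`tl.store` accesses into legacy pointer arithmetic.

```python
import triton
import triton.language as tl

@triton.jit
def copy_tile(in_ptr, out_ptr, M, N, stride_m, stride_n,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointer describing one BLOCK_M x BLOCK_N tile of the input.
    in_block = tl.make_block_ptr(base=in_ptr, shape=(M, N),
                                 strides=(stride_m, stride_n),
                                 offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                 block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    out_block = tl.make_block_ptr(base=out_ptr, shape=(M, N),
                                  strides=(stride_m, stride_n),
                                  offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                  block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    # These block-pointer loads/stores are what the pass rewrites into
    # legacy (per-element pointer arithmetic) loads/stores.
    tile = tl.load(in_block, boundary_check=(0, 1))
    tl.store(out_block, tile, boundary_check=(0, 1))
```
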
jayfurmanek
fa429316d4 Merge pull request #268 from ROCmSoftwarePlatform/improve_reduce_for_fa
[CHERRY-PICKED FROM UPSTREAM][BACKEND] no longer uses shared mem or barriers for single-warp reductions (openai#1915)
2023-08-21 13:29:11 -05:00
Alexander Efimov
d86b19f7a3 [CI] [Dot] Reduced test suite (#302)
Use the upstream list of tests for the dot op on machines with no MFMA support.
This is needed to reduce the time required for PR testing.
2023-08-18 07:47:14 -05:00
Alexander Efimov
23979098c8 [MFMA] MI200 bfloat16 support (#294)
This PR enables bfloat16 support in the MFMA dot on MI200,
using the mfma_f32_32x32x8bf16_1k instruction.
2023-08-18 07:28:18 -05:00
Shucai Xiao
f7cf2c032b Changes to the tutorial matmul scripts to get good performance (#297)
* Simple changes to the matmul scripts to get good performance. The specific reason for the performance boost needs further investigation and is being tracked.

* fix review comments

* Change num_warps in the autotuning config for HIP to work around an error, and change the rtol so the correctness check passes
2023-08-17 13:24:49 -05:00
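
A hedged sketch of the two knobs this commit mentions: num_warps/num_stages in the autotuning configs and a looser rtol in the correctness check. The block sizes and tolerance values below are placeholders, not the values from the script.

```python
import torch
import triton

# Illustrative HIP tuning space; num_warps and num_stages are the knobs
# being adjusted, the block sizes are placeholders.
hip_configs = [
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=8, num_stages=1),
    triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=1),
]

def check_result(triton_out: torch.Tensor, torch_out: torch.Tensor) -> bool:
    # A looser rtol so fp16 accumulation-order differences don't fail the check.
    return torch.allclose(triton_out, torch_out, atol=1e-2, rtol=1e-2)
```
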
Lixun Zhang
87e45cb011 Set vecSize and maxPhase more generically 2023-08-16 08:30:32 -05:00
Lixun Zhang
7156fcb0ef Set vecSize = 4 and maxPhase = BLOCK_K/4 2023-08-16 08:30:32 -05:00
Philippe Tillet
4215086931 [BACKEND] no longer uses shared mem or barriers for single-warp reductions (#1915)
0-byte shared mem buffers don't materialize empty allocation buffers;
this could lead to unnecessary barriers.

note: reduceop code has become quite messy and will require some cleanup
2023-08-15 11:51:20 +00:00
Alexander Efimov
af05f01218 [Tests] Fix some tests in test_core_amd.py (#288)
This PR:
- enables test_dot_mfma_vector_load for the fast path in the MFMA dot op pipeline
- fixes kernel execution for MFMA-enabled GPUs
- disables MFMA layout conversion tests on architectures that cannot run them
2023-08-08 20:12:32 +02:00
Alexander Efimov
ff95dddf18 [tutorial] Amd specific tuning configs (#287)
This PR adds AMD-specific tuning configs for the matmul tutorial with num_stages == 1.
2023-08-08 20:11:30 +02:00
jayfurmanek
32d7c6d646 Fix runtime/test_subproc.py for hip devices (#284)
* Fix runtime/test_subproc.py for hip devices

* address review comments
2023-08-07 10:30:36 -05:00
Alexander Efimov
7158ec286a [MFMA] [Dot] Support vector loads in normal path (#275)
* [MFMA] [Dot] Support vector loads in normal path

This PR adds generation of vector loads in the normal path of
MFMA dot operand loading.
This requires the shared layout to keep the elements loaded by a
single lane contiguous.

* remove redundant refactoring

* fix tests

* extend test with transposed A/B tensors
2023-08-03 14:57:39 -05:00
Alexander Efimov
86f8b64ae0 [Dot] [MFMA] [FMA] Update Dot implementation to support upstream tests (#260)
* [Dot] [MFMA] Support FP16 output of MFMA dot

This PR adds a cast of the output tensor to the requested data type.

* add tests

* fix test for FMA implementation

* loosen fp16xfp16->fp16 tolerance

* enable FMA fallback for unsupported sizes of dot operation

* rework granularity check

* add constant modifier to granularity
2023-08-03 13:47:18 -05:00
Vinayak Gokhale
f1063bb33c Enable backward pass in FA tutorial test (#282)
Enabled the backward pass in the fused attention tutorial.
The tolerance when comparing to the naive implementation
had to be changed. The block size is forced to be 64x64
due to the 64 KiB LDS; the default is a 128 block for the A100's
larger SMEM. This creates differences in the order of computation
and results in a larger gap between the naive and FA
implementations.
2023-08-03 10:12:46 -05:00
Shucai Xiao
31cfda8f0e enable more gemm tests corresponding to PR#273 (#279) 2023-08-02 16:45:31 -05:00
Alexander Efimov
a318bc5777 Disable pipelining (#276)
This PR sets the default number of pipeline stages to 1 for AMD
and adds an explicit num_stages to the test_gemm test.
2023-08-01 14:54:17 -05:00
Shucai Xiao
82db24cde7 add gfx940/941 to archs that support mfma (#277) 2023-08-01 10:57:16 -05:00
Yan Chunwei
89b0b79d75 [FRONTEND] fix the silent return issue in AOT launcher (#2013)
The current link.py produces launcher code like the following:

```c++
CUresult matmul_fp16xfp16_16x16x16(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t stride_cm, int32_t stride_am, int32_t stride_bk){
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0) && (stride_cm % 16 == 0))
    return matmul_fp16xfp16_16x16x16_688cc413_0d1d2d3d45d(stream, gX, gY, gZ, C, A, B, stride_cm, stride_am, stride_bk);
  // ...
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0))
    return matmul_fp16xfp16_16x16x16_7c0255bf_0d1d2d345(stream, gX, gY, gZ, C, A, B, stride_cm, stride_am, stride_bk);
}
```
Note that when the input does not match any of the if branches, the
launcher does nothing, and the compiler makes it return 0 as the default
behavior, which equals `CUDA_SUCCESS`; this doesn't match the
expectation.

This PR adds a `return CUDA_ERROR_INVALID_VALUE;` to the tail of the launchers, and
it produces code like:

```c++
CUresult matmul_fp16xfp16_16x16x16(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t stride_cm, int32_t stride_cn, int32_t stride_am, int32_t stride_ak, int32_t stride_bk, int32_t stride_bn){
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0) && (stride_cm == 1) && (stride_cn == 1) && (stride_am == 1) && (stride_ak == 1) && (stride_bk % 16 == 0) && (stride_bn == 1))
    return matmul_fp16xfp16_16x16x16_1f18a6da_0d1d2d3c4c5c6c7d8c(stream, gX, gY, gZ, C, A, B, stride_bk);

  return CUDA_ERROR_INVALID_VALUE;
}
```

It requires users to check the result in their application, which I
think matches the initial AOT ideas.
2023-07-31 09:59:28 -07:00
Philippe Tillet
52c146f66b [OPTIMIZER][BACKEND] significantly cleaner handling of mixed-precision kernels (#1949)
We currently have a very janky approach to optimizing mixed-precision
matmul workloads, where some layout combinations (e.g., NT matmul) were
explicitly pattern-matched to take a more optimized codepath. An attempt
at unifying all the codepaths to codegen cp.async failed due to bugs in
SharedToDotOperandMMAv2.cpp.

This PR fixes said bugs, adds some assertions for SharedToDotOperandMMAv2
modes that aren't well supported, and greatly simplifies our handling of
element-wise operations between loads and conversions to DotOperand.
2023-07-28 10:29:42 -07:00
Bin Fan
2689f4a3b0 [TOOLS][AOT] some issues in equal_to_1 hint (#1998)
- Change test_aot.py to actually use equal_to_1 hint

- In the client function, equal_to_1 parameters are not specialized:
AOT clients may not know the details of Triton argument
specialization and still want to use the same parameter list as when
they wrote the Triton kernel. The generated kernels have specialized
argument lists; the generated dispatcher code makes sure the correct
arguments from the original full argument list are passed.

- Fixed a bug in _match_suffix in link.py. Previously it assumed each
parameter has a suffix of either ‘d’ or ‘c’, but in fact sometimes a
parameter doesn’t have a suffix, as in 0d1d2d34c56c78c
2023-07-27 16:07:49 -07:00
Alexander Efimov
0073bb98f4 [MFMA] Introduce dot operand loading fast path (#269)
* [MFMA] Introduce dot operand loading fast path

This PR introduces a fast path for code generation of MFMA dot operand
loading from LDS.

The fast path is used when the operand is not swizzled and is not a
slice of some bigger LDS object (i.e., it is not a slice of a tensor).
This is the case for the current FA and GEMM kernels compiled with
num_stages=1, i.e. with software pipelining disabled.

* cleanup swizzle info
2023-07-27 13:46:50 -05:00
Lixun Zhang
2fbffe2784 Fix masked load (#262)
* Fix the issue with masked load

Cherry-picked from a0b60eb187

* Remove tests in test_gemm that use too much LDS

---------

Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
2023-07-26 10:45:21 -05:00
Shantanu
4f1b2ea8d7 [FRONTEND] fix error with -> None return annotation (#1987)
None is not a type, so you get:
```
    self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
E   TypeError: argument of type 'NoneType' is not iterable
```

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-25 18:49:45 -07:00
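
A minimal reproducer for the failure mode described above (the kernel name is illustrative): a jitted kernel with a `-> None` return annotation, which put `None` into `__annotations__` and broke the `'constexpr' in ty` membership test.

```python
import triton
import triton.language as tl

@triton.jit
def copy_block(src_ptr, dst_ptr, BLOCK: tl.constexpr) -> None:
    # The `-> None` annotation adds {'return': None} to __annotations__;
    # before the fix, scanning annotations for 'constexpr' raised
    # TypeError: argument of type 'NoneType' is not iterable.
    offs = tl.arange(0, BLOCK)
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs))
```
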
Phil Tillet
db695c093f [TUTORIALS] fix format 2023-07-25 18:16:39 -07:00
Shucai Xiao
1cccf14f62 workaround for num_programs() to get correct value (#271)
* Workaround for num_programs() to return the correct value for issue #259.
2023-07-25 15:46:19 -05:00
janEbert
62a8afa403 [TUTORIALS] Support FlashAttention-2 reference (#1984)
Uses FlashAttention-2 if available, otherwise acts as before (if
FlashAttention-1 is available, that is used, otherwise the
FlashAttention reference benchmark is not run).

I decided to keep the same name for the imported function, but feel free
to make me change that.
2023-07-24 13:54:01 -07:00
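
A sketch of the fallback pattern the commit describes: prefer FlashAttention-2, fall back to FlashAttention-1, otherwise skip the reference benchmark. The exact flash_attn import paths and symbol names below are assumptions about the package layout, not taken from the tutorial.

```python
# Import paths and symbols are assumed, not verified against the tutorial.
try:
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func  # FlashAttention-2
    HAS_FLASH = True
except ImportError:
    try:
        from flash_attn.flash_attn_interface import \
            flash_attn_unpadded_qkvpacked_func as flash_attn_func  # FlashAttention-1
        HAS_FLASH = True
    except ImportError:
        HAS_FLASH = False  # the FlashAttention reference benchmark is not run
```
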
Shucai Xiao
356595f3dc revert PR#245 and #241 for gemm performance regression (#265)
* Revert "[MFMA][Dot] Swizzle mfma dot operands  (#245)"

This reverts commit 4d0deef45f.

* revert the swizzle and slice PRs for gemm performance regression

* Fix build error after reverting PR #245 (commit 4d0deef45f), "Swizzle mfma dot operands"
2023-07-24 14:09:46 -05:00
Izzy Putterman
de6f053c0f [TRITON][OPS] add Flash Attention v2 to Ops (#1970)
I also dropped do_scaled as it is no longer needed (no scaling is done
to `do` in v2).

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-23 14:07:15 -07:00
youkaichao
c9ab44888e [FRONTEND] improve the process of finding libcuda.so and the error message (#1981)
`triton` uses the `whereis` command to find `libcuda.so`, but `whereis`
is intended to find binary, source, and manual page files. When
`libcuda.so` is not properly set up, the `whereis` command ends up with
`/usr/share/man/man7/libcuda.7`, which is not the right place to look.

This PR uses `ldconfig -p` to reliably find `libcuda.so`.

In my case, I found that I have a `libcuda.so.1` file, but it is not
linked to `libcuda.so`, so `ld` cannot find the library to link against.
After creating the link, I was able to run `triton` successfully.

Therefore, I improve the code by first invoking `ldconfig -p` and
checking for `libcuda.so` strings; these are the possible libraries to
link against. If a literal `libcuda.so` file is not found, I raise an
error telling the user that a possible fix is to create a symlink.
2023-07-23 10:31:07 -07:00
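
A sketch of the described approach (not the actual patch): parse `ldconfig -p`, collect `libcuda.so` entries, and complain with a symlink hint if no literal `libcuda.so` is present. The function name is made up.

```python
import subprocess

def find_libcuda_candidates():
    # Each `ldconfig -p` line looks roughly like:
    #   libcuda.so.1 (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libcuda.so.1
    out = subprocess.check_output(["ldconfig", "-p"]).decode()
    candidates = [line.split("=>")[-1].strip()
                  for line in out.splitlines() if "libcuda.so" in line]
    if not any(path.endswith("libcuda.so") for path in candidates):
        raise RuntimeError(
            "libcuda.so not found by ldconfig; if only libcuda.so.1 exists, "
            "creating a libcuda.so symlink to it is a possible fix.")
    return candidates
```
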
Philippe Tillet
66eda76e45 [FRONTEND][BACKEND] no longer serialize float8e4b15 (#1979)
We had a number of complaints that the previous packed format was
error-prone and may not yet be worth the 2 SASS instructions saved per 4
conversions.
2023-07-21 22:44:55 -07:00
Phil Tillet
cfce82d715 [TUTORIALS] Flash Attention tutorial now properly tries fwd, bwd, causal, non-causal 2023-07-19 21:56:29 -07:00
Philippe Tillet
3452615d79 [BUILD] Reverted ptxas change and fixed bug in cache key computation (#1971) 2023-07-19 20:58:24 -07:00
Philippe Tillet
28a61484bc [FRONTEND] more leniency when converting to/from fp8e4b15 (#1969) 2023-07-19 18:26:21 -07:00
Philippe Tillet
68124676c9 [FRONTEND][BACKEND] Fix trans for float8e4b15 (#1964)
float8e4b15 is a packed type; it is incompatible with most of our layout conversions. For now, we just convert to float16.
2023-07-19 11:30:39 -07:00
nccx
cd61f99fb5 [DOCS] remove empty README (#1963) 2023-07-19 10:51:38 -07:00
Philippe Tillet
c46a842b6f [TUTORIAL] more attention cleanup (#1958) 2023-07-18 12:36:15 -07:00
Shucai Xiao
1c86e3238a Remove adding multiple architectures to the ISA head and add gemm tuning scripts (#261)
* Remove adding multiple architectures to the ISA head

* Add a mask for GPU memory loads in the gemm tuning script 'script/amd/gemm/matmul.py'

* Move the scripts to a better place, 'scripts/amd/gemm/'
2023-07-18 14:21:16 -05:00
Philippe Tillet
9e3e10c5ed [OPTIMIZER][TUTORIAL] flash attention v2 (#1952) 2023-07-17 12:23:02 -07:00
David Berard
7202c6cff0 [FRONTEND] expose tl.max_constancy hint (#1951)
Similar to `tl.multiple_of` and `tl.max_contiguous`, `tl.max_constancy`
exposes a compiler hint indicating that all values are equal within a
block of a certain size.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-17 18:30:25 +00:00
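
A small usage sketch of the new hint; the kernel is illustrative and the group size of 16 is arbitrary.

```python
import triton
import triton.language as tl

@triton.jit
def hinted_gather(src_ptr, idx_ptr, dst_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = tl.load(idx_ptr + offs)
    # Promise the compiler that within every group of 16 consecutive
    # elements of `idx`, all values are equal (the constancy analogue of
    # tl.multiple_of / tl.max_contiguous).
    idx = tl.max_constancy(idx, 16)
    tl.store(dst_ptr + offs, tl.load(src_ptr + idx))
```
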
Christian Sigg
80c6e39716 [BACKEND] Fix enable_debug implementation. (#1876)
Print before every pass and after failures if MLIR_ENABLE_DUMP is set.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-07-16 21:50:30 -04:00
Keren Zhou
bcfd990a88 [TESTS] Fix autopep8 error (#1948) 2023-07-16 16:55:12 -07:00
Mehdi Amini
51fc42a568 [FRONTEND] fix AST IR generation for while loop nested inside other SCF (#1947)
The process of visiting the body of the while loop twice didn't properly
restore the insertion point and leaked the dummy block.
2023-07-15 10:17:29 -07:00
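
An illustrative shape of kernel that exercises the fixed path (names are made up): a `while` loop nested inside another structured-control-flow construct.

```python
import triton
import triton.language as tl

@triton.jit
def nested_while(out_ptr, N):
    total = 0
    for i in range(0, N):          # outer loop lowers to scf.for
        j = 0
        while j < i:               # while nested inside the enclosing loop
            total += 1
            j += 1
    tl.store(out_ptr, total)
```
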