Commit Graph

874 Commits

Author SHA1 Message Date
Ognjen
171a67e837 Add scheduling pass 2024-01-25 18:07:45 +00:00
oplavsic
760ac8441a Dot slicing pass (#440)
* First commit

* Implement DotSlicing pass.

* small fixes

* Support chained dot in DotSlicingPass (second GEMM in FA)

* Add lit test for FA dot slicing

---------

Co-authored-by: Ognjen Plavsic <ognjen.plavsic@luxoft.com>
Co-authored-by: Ognjen <oplavsic@luxoft.com>
2024-01-16 14:25:10 -06:00
Jack Taylor
1e2fd0dd1a Update hip_backend to use libhsa-runtime for arch info (#411)
Brings in path changes for PyTorch Triton wheels

Co-authored-by: jayfurmanek <Jason.Furmanek@amd.com>
2023-12-21 15:40:57 +00:00
jayfurmanek
16281f02f4 [ROCM] drop GIL for launch, and set value=false upon pointer error (#426) 2023-12-19 08:34:51 -06:00
joviliast
af15da2f84 Support WMMA layout in TritonAMDGPUAccelerateMatmulPass
- Introduce WmmaEncodingAttr for WMMA output
- Introduce BlockedToWMMA rewrite pattern in TritonAMDGPUAccelerateMatmulPass
- Provide a flag to check whether WMMA instructions are supported by the target

Signed-off-by: joviliast <iveselov.nn@gmail.com>
2023-12-18 09:11:20 -06:00
jayfurmanek
29847e9bb1 Merge pull request #410 from ROCmSoftwarePlatform/ifu-231117
Ifu 231117
2023-12-15 09:09:40 -06:00
Shucai Xiao
521f425fbf add bitcode for gfx941 and gfx942 (#403)
Co-authored-by: Aleksandr Efimov <130555951+alefimov-amd@users.noreply.github.com>
2023-12-14 08:19:23 -06:00
jayfurmanek
a42ac260aa Merge branch 'triton-mlir' into ifu-231117 2023-12-12 14:24:11 -06:00
Alexander Efimov
605a90c58e [MFMA] Support tile size 4x4 version 1 (#413)
This PR enables a 4x4 tile size in MFMA-based dot operations.

The supported tiled dot is (4x64) x (64x4) -> (4x4) in the MFMA layout.
However, the actual dot operation must have at least 64 output elements; this is a limitation of other layouts that appear during result processing (i.e. the blocked layout cannot handle tensors smaller than the wavefront size).

For example, the following dots are supported: (4x64) x (64x16) -> (4x16), (16x64) x (64x4) -> (16x4), or (8x64) x (64x8) -> (8x8).
The following dots are not supported: (4x128) x (128x4) -> (4x4), (4x64) x (64x8) -> (4x8).

This is a first version of dot using MFMA 4x4 instructions, with redundancy and reductions.
2023-12-12 18:23:55 +01:00
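A small sketch of the shape rule described in the entry above, assuming the constraint is exactly as stated there (M and N may be as small as 4, but the output must hold at least 64 elements, and K covers a 64-wide wave); the helper name is hypothetical and not code from the PR:

```python
def mfma_4x4_dot_supported(m: int, n: int, k: int) -> bool:
    # Hypothetical helper illustrating the stated rule: the 4x4 MFMA path
    # allows M or N down to 4, but the (M x N) output must still hold at
    # least 64 elements (one wavefront), and K must cover a full wave.
    return m >= 4 and n >= 4 and k >= 64 and m * n >= 64

# Examples taken from the commit message above:
assert mfma_4x4_dot_supported(4, 16, 64)      # (4x64) x (64x16) -> (4x16)
assert mfma_4x4_dot_supported(16, 4, 64)      # (16x64) x (64x4) -> (16x4)
assert mfma_4x4_dot_supported(8, 8, 64)       # (8x64) x (64x8)  -> (8x8)
assert not mfma_4x4_dot_supported(4, 4, 128)  # (4x128) x (128x4) -> (4x4)
assert not mfma_4x4_dot_supported(4, 8, 64)   # (4x64) x (64x8)  -> (4x8)
```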
Jason Furmanek
f5f6b3c0a3 ROCM IFU: Add get_version_key for ROCM backend 2023-11-28 00:11:44 +00:00
Jason Furmanek
71547e4fdb ROCM IFU: Fixes for kwargs 2023-11-28 00:09:46 +00:00
jayfurmanek
99aa1f4f75 Merge branch 'triton-mlir' into ifu-231117 2023-11-27 07:44:04 -06:00
Jason Furmanek
968a35fbf0 Fix merge conflict error 2023-11-27 13:33:23 +00:00
Shucai Xiao
d9219e0eba use hw for fp8 type conversion (#386)
* use hardware instruction for type conversion between fp8 and fp32

* move gpu_matrix_core_version from semantics.py to hip_backend.py

---------

Co-authored-by: Aleksandr Efimov <efimov.alexander@gmail.com>
2023-11-24 10:26:40 -06:00
Jason Furmanek
a08dafe7fe Initial commit to resolve merge conflicts 2023-11-20 22:41:03 +00:00
Jason Furmanek
5c87f363e4 Merge commit 'cb3d79a185e40c9d8a579bea07747a8a8d157d52' into ifu-231117
Conflicts:
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Dialect/TritonGPU/IR/Dialect.cpp
	python/setup.py
	python/test/unit/language/assert_helper.py
	python/test/unit/operators/test_flash_attention.py
	python/test/unit/runtime/test_subproc.py
	python/triton/compiler/compiler.py
	python/triton/language/semantic.py
	python/triton/runtime/autotuner.py
	python/triton/runtime/jit.py
	python/tutorials/03-matrix-multiplication.py
	python/tutorials/05-layer-norm.py
	python/tutorials/06-fused-attention.py
	python/tutorials/11-grouped-gemm.py
	test/Conversion/tritongpu_to_llvm.mlir
2023-11-17 20:42:12 +00:00
Jason Furmanek
484852876e Resolve merge conflicts; AMD adjustments for new LLVM version 2023-11-09 19:00:49 +00:00
Jason Furmanek
977d5aa267 Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108
Conflicts:
	bin/triton-translate.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	python/triton/compiler/compiler.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
2023-11-08 18:51:23 +00:00
Lixun Zhang
1af893d8a2 [FRONTEND] Add input dtypes to autotuning key (#2534) (#374)
* [FRONTEND] Add input dtypes to autotuning key (#2534)

* Fix conflict in 06-fused-attention

* Fix get_best_config in FA-transV.py

* Fix leftover get_best_config()

---------

Co-authored-by: Adnan Akhundov <adnan.akhundov@gmail.com>
2023-11-07 19:36:57 -06:00
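A hedged sketch of what this change means for users: with input dtypes folded into the autotuning key, the same `key=[...]` now distinguishes, say, fp16 from fp32 inputs automatically. The kernel and config values below are illustrative only, not from the PR:

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4),
        triton.Config({"BLOCK": 256}, num_warps=8),
    ],
    # The key lists only the problem size; after this change the dtypes of the
    # input arguments are appended to the cache key automatically, so fp16 and
    # fp32 inputs no longer reuse each other's best config.
    key=["n_elements"],
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, alpha, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * alpha, mask=mask)
```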
Jason Furmanek
39e8901d7a ROCM IFU: Resolve merge conflicts in RemoveLayoutConversions.cpp
fix merge error

fix dot

fix make_range

additional fix
2023-11-07 04:29:38 +00:00
Jason Furmanek
c3132eeda8 ROCM IFU: Third-party fixes: preload ROCDL
Additional 3rd-party fix

Remove redundant mfma_supported defines
2023-11-07 03:29:16 +00:00
Jason Furmanek
3a6dc5ad8d resolve some merge conflicts
fix more conflicts

Resolve merge conflicts

Some more build and conflict fixes

Resolve conflicts for 06-fused-attention.py

resolve merge conflicts for the tutorial group gemm example

Fixes for some LIT tests

resolve remaining conflicts in tests

Fix empty kernel

set capability 0
2023-11-06 23:13:10 +00:00
Jason Furmanek
33151a860f Merge commit 'ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33' into ifu-rebase-again
Conflicts:
	.gitignore
	.gitmodules
	README.md
	bin/triton-translate.cpp
	include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
	include/triton/Target/AMDGCN/AMDGCNTranslation.h
	include/triton/Target/HSACO/HSACOTranslation.h
	lib/Analysis/Allocation.cpp
	lib/Analysis/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
	lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/Utility.h
	lib/Dialect/TritonGPU/IR/Dialect.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	lib/Target/HSACO/CMakeLists.txt
	lib/Target/HSACO/HSACOTranslation.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/src/triton.cc
	python/test/unit/language/test_core.py
	python/test/unit/operators/test_flash_attention.py
	python/triton/compiler/compiler.py
	python/triton/compiler/make_launcher.py
	python/triton/language/semantic.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	python/tutorials/11-grouped-gemm.py
	test/Conversion/tritongpu_to_llvm.mlir
2023-11-06 23:10:10 +00:00
oplavsic
c65f1e6211 Add OptimizeEpilogue pass. (#346)
* optimize_epilogue

* Add config

* Remove licenses

* Comment out Hopper specific parameters when printing out configs

* Add benchmark parameters from flash-attention repo

* Add Z and H in the key of autotuner

---------

Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
2023-11-03 16:46:24 -05:00
Thomas Raoux
6ac9d51ff0 [OPTIMIZATION] Enable pipelining for bwd flash attention (#2590)
This allows pipelining when a load is used by multiple dots in a loop.

Relax the condition to pipeline dot operands for the MMA v3 case. This
improves performance for the bwd pass from 260 TF to 275 TF. However, this
exposes a performance problem due to the wgmma pipelining, as ptxas will
now fall back to serial wgmma. A follow-up PR will fix a bug in how we
emit wgmma_wait during pipelining and will bring performance to 335 TF.
2023-11-03 11:46:51 -07:00
Shucai Xiao
cb02a0b346 remove unnecessary arch names (#388) 2023-11-03 12:04:02 -05:00
Justin Lebar
df08301e76 Reformat Python code with yapf. (#2589)
I've added an option to yapf to do what we want for long lines; see
https://github.com/google/yapf/pull/1177.  We can now have a real Python
formatter, yay!

To make this PR, I ran my modified yapf over the repository, then looked
over the full diff.  Where yapf was mangling the param list of long
function decls/calls (mostly kernels), I manually added `#` to put
linebreaks where we want.  I fixed up other formatting too -- mostly
adding or removing a trailing comma from lists.

Overall, trailing `#` was sufficient to get formatting similar to our
current code.  I didn't have to disable yapf anywhere.

---------

Co-authored-by: Phil Tillet <phil@openai.com>
2023-11-02 20:44:17 -07:00
Shucai Xiao
79bebc4ffe fp8 type support (#357)
* add two fp8 data types, `tl.float8e4b8` and `tl.float8e5b16`, to Triton.
* add software type conversion between `tl.float8e4b8`/`tl.float8e5b16` and `fp16`.
* change flash attention to support fp8 in q/k.
2023-11-02 15:51:23 -05:00
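A minimal sketch of how the new types might be exercised, assuming the usual `tensor.to(dtype)` cast path; only `tl.float8e4b8` and `tl.float8e5b16` come from the commit above, the rest of the kernel is illustrative:

```python
import triton
import triton.language as tl

@triton.jit
def fp8_roundtrip_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)   # fp16 input
    x_fp8 = x.to(tl.float8e4b8)            # down-cast to the new AMD fp8 type
    y = x_fp8.to(tl.float16)               # cast back up for storage
    tl.store(out_ptr + offs, y, mask=mask)
```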
jeromeku
37cd3d5339 [FIX][AOT Compiler] Fix No Specialization Edge Case (#2584)
The [hints
dispatching](218492cd65/python/triton/tools/link.py (L161))
logic currently fails for the edge case where a single kernel with no
specializations is to be linked in the [AOT
compiler](https://github.com/openai/triton/blob/main/python/triton/tools/link.py).

Since the dispatcher inserts a conditional branch for each
specialization case, this results in an `if ()` being inserted into the
`C` source, which clearly breaks downstream artifacts.

Fix:
- Added simple check for this edge case
- Added unit test that mirrors the existing
[`test_compile_link_matmul`](218492cd65/python/test/unit/tools/test_aot.py (L224))
test case save for the aforementioned condition.
2023-11-02 10:02:41 -07:00
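A rough sketch (not the actual `link.py` code) of the kind of guard the fix adds: when a kernel has no specialization hints, there is no condition to emit, so the generated C dispatcher must call the kernel unconditionally instead of producing an empty `if ()`. All names below are hypothetical:

```python
def emit_dispatcher_body(kernel_name, specializations):
    """specializations: list of (condition, suffix) pairs; may be empty."""
    if not specializations:
        # Edge case handled by the fix: no hints at all -> a plain call,
        # never an `if () { ... }` in the generated C source.
        return f"  return {kernel_name}_default(stream, gX, gY, gZ, args);"
    lines = []
    for cond, suffix in specializations:
        lines.append(
            f"  if ({cond})\n    return {kernel_name}_{suffix}(stream, gX, gY, gZ, args);"
        )
    return "\n".join(lines)
```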
Vedant Roy
702cde0d6f [FRONTEND] Implement ternary operator for dynamic values (#2560) 2023-11-01 20:21:32 -04:00
Chenggang Zhao
e7fdfd76fb [FRONTEND] Add value restoration for autotuner (#2549)
For in-place kernels, neither `reset_to_zero` nor `Config.prehook`
provided in the autotuner can restore the values changed during the
tuning process, so I propose a recovery mechanism here.

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-10-31 21:37:44 -04:00
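A hedged usage sketch, assuming the recovery mechanism is exposed as a `restore_value` argument on `triton.autotune`; the in-place kernel below is illustrative:

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({"BLOCK": 64}), triton.Config({"BLOCK": 128})],
    key=["n_elements"],
    # Assumed name of the new knob: buffers listed here are snapshotted before
    # tuning and restored afterwards, so benchmarking an in-place kernel does
    # not corrupt the caller's data.
    restore_value=["x_ptr"],
)
@triton.jit
def inplace_add_one(x_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x + 1, mask=mask)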
Zahi Moudallal
3650213218 [OPTIMIZER] Thread local reduction optimization (#2542)
Co-authored-by: Phil Tillet <phil@openai.com>
2023-10-31 16:13:36 -07:00
Justin Lebar
258399c114 Enable ruff linter instead of flake8 (#2574)
[FRONTEND] Enable ruff linter instead of flake8.
    
This fixes a few issues automatically; it also flagged two issues to fix
manually in test_core.py: we had two duplicate function names! One of the
function bodies was a duplicate, so I deleted it. The other was not, so I
gave it a new name.

AIUI all of these errors should have been picked up by flake8.  I'm
confused why it wasn't working.  Anyway this is working, and it's faster
than flake8, so it seems like an improvement in all dimensions.
2023-10-31 21:28:24 +00:00
Zahi Moudallal
943330790a [FRONTEND] add do_not_specialize property back to JITFunction (#2573) 2023-10-31 12:02:45 -07:00
Chris Jones
2398b82f18 [FRONTEND][BACKEND] Add memory synchronization scope parameter to atomic ops. (#2562)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-10-30 19:18:27 -07:00
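A small sketch of how the new parameter surfaces in the language, assuming it is the `scope` keyword on the atomic builtins; the histogram kernel and the particular `sem`/`scope` values are illustrative:

```python
import triton
import triton.language as tl

@triton.jit
def histogram_kernel(counts_ptr, x_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    bin_idx = tl.load(x_ptr + offs, mask=mask, other=0)
    # Assumed API: `scope` limits which observers the synchronization applies
    # to (e.g. the whole device), while `sem` picks the memory ordering.
    tl.atomic_add(counts_ptr + bin_idx, 1, mask=mask, sem="relaxed", scope="gpu")
```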
Keren Zhou
70fca00b67 [BACKEND] Fix device_print without arguments (#2566) 2023-10-30 20:04:44 -04:00
Keren Zhou
492886fcde [FRONTEND] Add reverse eq and ne (#2563) 2023-10-30 16:56:43 -04:00
Justin Lebar
12f906287f [FRONTEND] Refactor jit.py. (#2556)
[FRONTEND] Refactor jit.py.

The goal is to simplify the code and make it more flexible before we
change the kernel launch syntax to
`kernel[grid, compiler_flags(...)](...)`.

The main changes here are:

 - Get rid of the eval'ed code in make_launcher.  We can do everything
   using bind().
 - Add KernelParam and KernelArg classes, letting us get rid of the
   parallel arrays/dicts indexed by parameter index.
 - Get rid of duplicated kernel launch code in the cache-hit/cache-miss
   branches.
2023-10-30 13:14:51 -07:00
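The bind() approach mentioned above can be illustrated with plain `inspect`; this is a sketch of the idea, not the jit.py code, and the signature is made up:

```python
import inspect

def kernel_sig(x_ptr, y_ptr, n_elements, BLOCK: int = 128):
    pass

sig = inspect.signature(kernel_sig)
bound = sig.bind(0x7F00, 0x7F80, 1024)  # the caller's positional/keyword args
bound.apply_defaults()
# Parameter names are now mapped to argument values without any eval'ed
# launcher code or parallel arrays indexed by parameter position.
print(dict(bound.arguments))
# {'x_ptr': 32512, 'y_ptr': 32640, 'n_elements': 1024, 'BLOCK': 128}
```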
Justin Lebar
f88b01f558 Apply ruff pre-commit to python/triton/runtime. (#2558)
We're in the process of incrementally converting from autopep8 + flake8
+ isort to ruff, on a directory-by-directory basis.

The motivation to switch away from autopep8 is that I can't get it to
wrap long lines, even with -aaa.  This seems to be a known problem,
https://github.com/hhatto/autopep8/issues/497.

See more details about alternatives tried in
https://github.com/openai/triton/pull/2557.
2023-10-30 11:06:44 -07:00
Someone
cde42e6221 [BUILD] make cuda tools vendoring optional (#2546) 2023-10-26 23:16:41 -07:00
Michael Melesse
1fd9b40f2f Works as StandAlone and Backend and also Perf is Good
This is a combination of 4 commits.

Works as StandAlone and Backend

Works as StandAlone and Backend

This is a combination of 13 commits.

Works StandAlone and as Backend

This is a combination of 7 commits.

backend set default dir with flag

move bitcode to backend dir

copy backend

save

empty test work in backendmode

enable backend mode when copying to upstream

clean up

fix failure

minimize diff

add skip function

fix bug with corrupted dwarf exp

match num_warps

fix multi threaded test issue

move bitcode file out of lib

move backend to python/triton/third_party/hip

move libhsa

backend works again

restart ci

clean upstream location first before copy

match scripts

fix new error

memoize backend stuff

fix bug
2023-10-26 14:27:18 -05:00
Michael Melesse
09ba348f87 [ROCM] Core Functionality for AMD (#1983)
* this PR adds a third-party backend for Triton that works on AMD
* this exposes a lot of the work that has been done in our
[fork](https://github.com/ROCmSoftwarePlatform/triton)
* most unit tests in `test_core.py` pass
* it skips some unit tests for various reasons
* we plan to follow up with more PRs improving functionality and
performance in the future

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-10-26 08:36:49 -05:00
runseny
4c816c2f59 [OPS] enable flash_attention_v2 TMA (#2544) 2023-10-25 23:31:17 -07:00
Alexander Efimov
5a86b46bb1 [MFMA] FP8 and BF8 support (#355)
* [MFMA] FP8 and BF8 support

This PR adds support for fp8 and bf8 in the AccelerateMatmul pass and
introduces generation of float8 MFMA instructions in the TTG-to-LLVM conversion.

* add tests

* fix tests

* review fix: fix variable naming and dot operand promotion.

* review comments fixes

---------

Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
2023-10-25 13:27:10 -05:00
oplavsic
715a589ce3 [FA fwd D=128] Reduce LDS usage in epilogue (#340)
* rebase onto improve_fwd_fa

* Fixed a leftover from rebase

* rebase onto improve_fa_fwd

* Reduce tuning space

* Disable bwd with D=128

* Add test for d=128

* Fix an issue with get_best_config when there is only one config

* Added better configs for d=128

* Fix typos

---------

Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
2023-10-25 12:10:34 -05:00
Justin Lebar
e70e11e834 [BACKEND] Improve printf. (#2532)
[BACKEND] Improve printf.

Previously, we printed all of a GPU thread's values in a single printf()
call, and this, plus the user-specified prefix, was all we printed.

This caused a few problems.

 - nvptx printf can only handle 32 arguments; if you pass more than
   that, it prints garbage.  So if a thread had more than 32 values, you
   couldn't print them (issue #2486).

 - The order of the values within the Triton program (GPU thread block)
   is an implementation detail -- it depends on the layout the compiler
   assigns to a tensor.  So this also prevented you from interpreting
   the printed output.

To address this, we now print the Triton pid and multi-dimensional
Tensor index for each value.  And each value gets its own line to avoid
passing too many args to printf.

Example output:

    ```
    pid (0, 1, 2) idx (36, 127) x: 42
    ```

If you want to observe all the values in a tensor in order, you can grep
and then sort the output.

We also make a UX enhancement to print: The printed label always ends
with ": "; you don't have to add it yourself.

Fixes #2486.
2023-10-25 08:47:55 +00:00
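A hedged example of driving the improved printing from a kernel via `tl.device_print`; the shapes and values are illustrative, and the output format is the one quoted in the commit above:

```python
import triton
import triton.language as tl

@triton.jit
def debug_kernel(x_ptr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Each element is printed on its own line as
    #   pid (<p0>, <p1>, <p2>) idx (<tensor indices>) x: <value>
    # so the 32-argument printf limit no longer applies, and the label is
    # suffixed with ": " automatically.
    tl.device_print("x", x)
```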
Sam Shleifer
12da43084b [TESTING] add diff column, option to return df in benchmark (#2469) 2023-10-24 05:17:00 +00:00
Adnan Akhundov
50add54334 [FRONTEND] Add input dtypes to autotuning key (#2534) 2023-10-24 03:29:30 +00:00
runseny
dc9e3063d7 [HOPPER] Move to tl.make_block_ptr in flash_attention backward scripts (#2395) 2023-10-20 11:06:48 +08:00
Justin Lebar
30186f401e Fix segfault in assertion test. (#2520)

#### Commits in this PR
1. Fix segfault in assertion test.
    
The issue here is that we were not checking the return values of the CUDA
API calls we were making. We call one function and then use the data it
returns as input to another call. Obviously this doesn't work if the first
call returns an error and doesn't actually return meaningful data.

I don't know why this was passing in CI, but it failed consistently for me.

2023-10-19 13:42:38 -07:00