Commit Graph

487 Commits

Author SHA1 Message Date
SJW
491eb9ddfe [MLIR] Added tritongpu-stream-pipeline pass (#305)
* [MLIR] Added tritongpu-stream-pipeline pass
     - Prologue: Hoist the pipelinable load operations and shared memory store
       for the ramp up stage
     - Pipelined Loop: Assemble the loop body minus last iteration
       - Prefetch next tile from global into regs (while computing from previous)
       - Non-load loop body
       - Store next tile into shared mem
     - Epilogue: Peeled non-load loop body for last iteration

* * updated comment
2023-09-07 15:24:59 -05:00
jon-chuang
36859aebff [DOCS] Add MLIR Autogenerated Docs to Sphinx Docs (#2234)
Partially fixes: https://github.com/openai/triton/issues/2226

Here are some example renderings:
![Screenshot from 2023-09-04
18-39-20](https://github.com/openai/triton/assets/9093549/e9c4af04-aeae-4021-a8db-6a4a82b59ae7)
![Screenshot from 2023-09-04
18-39-30](https://github.com/openai/triton/assets/9093549/410391b8-e07e-4bed-909c-8ce5484072d1)
![Screenshot from 2023-09-04
18-39-41](https://github.com/openai/triton/assets/9093549/f1eaef95-66c1-4506-a153-c6069e2b5072)
2023-09-06 08:17:12 +00:00
Zahi Moudallal
acbf716889 [BACKEND] Refactoring NVGPUToLLVMPass (#2158) 2023-09-01 23:40:31 +00:00
Jason Furmanek
df5c263a19 Fix merge conflicts 2023-09-01 04:01:32 +00:00
Jason Furmanek
3eaeb89d18 Merge commit '5df904233c11a65bd131ead7268f84cca7804275' into ifu230810-2
Conflicts:
	include/triton/Dialect/Triton/Transforms/Passes.h
	include/triton/Dialect/TritonGPU/IR/Dialect.h
	include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
	lib/Analysis/Allocation.cpp
	lib/Analysis/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/src/triton.cc
	python/triton/compiler/compiler.py
	python/triton/ops/flash_attention.py
	python/triton/runtime/autotuner.py
	python/triton/runtime/jit.py
	python/triton/tools/aot.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
	test/Target/tritongpu_to_llvmir.mlir
	test/Target/tritongpu_to_llvmir_noinline.mlir
2023-09-01 03:25:33 +00:00
Michael Melesse
c6d33dcebf [ROCM] Core Functionality for AMD (#1983)
* this pr adds a third party backend for triton that works on AMD 
* this expose a lot of the work that has been done in our
[fork](https://github.com/ROCmSoftwarePlatform/triton)
* most unit tests on `test_core.py` pass
* it skips some unit tests for various reasons
* we plan to follow up with more prs improving Functionality and
Performance in the future

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-08-31 14:02:00 -07:00
Thomas
d4644d6cb3 [BACKEND] Refactor RemoveLayoutConversion pass (#2181)
Significant changes to the pass logic. Move away from greedy rewrites
and use more global analysis instead. The pass is now bocken down into 2
main phases. First forward propagation of layout starting from ops that
we don't want to change. Propagate to all the nodes. If there is a
single layout needed for the op then we can rewrite the op, if there are
multiple layout required based on dependency we need a tie break.
The second phase is backward propgation that gets a backward slice of
operations starting from the convert and if all the operations in the
slice can be rematerialized rewrite the slice. This backward phase now
supports going through loop arguments.

This will allow more complex logic in the future to add a cost model to
decide which convert to leave and which to fold
2023-08-28 19:05:16 -07:00
peterbell10
fa03b92109 [OPTIMIZER] Add folder for MakeRangeOp (#2187)
This folds `tl.arange(x, x + 1)` into a constant. This shows up for
example when autotuning and one of the block sizes gets set to 1.

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-08-26 16:44:13 +00:00
jayfurmanek
ff7e707f87 Enable usage of block pointer semantics for AMD gpus (#301)
* Enable usage of block pointer semantics for AMD gpus

This commit enables usage of block pointer semantics by enabling
rewrite_tensor_pointer_pass that rewrites block pointer loads/stores
to legacy loads/stores.

* Update FA fwd in tutorial to use the block pointers

* use 90 compute capability for amd gpus in python/triton/compiler/compiler.py

Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>

---------

Co-authored-by: Ognjen Plavsic <ognjen.plavsic@dxc.com>
Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
Co-authored-by: Aleksandr Efimov <130555951+alefimov-amd@users.noreply.github.com>
Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>
2023-08-24 13:05:12 -05:00
Thomas
3116933ccd [BACKEND] Don't do dead code elimination on volatile load (#2165) 2023-08-23 14:59:18 -07:00
ivanyinwz
ec801ce18e [BACKEND] Optimize performance for f16 epilogue with TMA store (#2135)
1. Optimize the conversion and packing for 2xf32 -> 2xf16.
2. Split TMA store block into multiple slices of size 64x64.
3. Distribute the TMA store to all the warps.
4. Fix some naming issue.
2023-08-21 12:44:11 -07:00
Thomas
54ca7fcb35 [FRONTEND] Use inline asm for global timer and smid functions (#2143)
Simplify the code by using inline asm to implement globaltimer and smid
instead of relying on bc file.
2023-08-20 22:56:37 -07:00
Alexander Zinoviev
d5188fa230 [BACKEND] enable transpose for float16 on sm75 (#2139)
Replace the Turing version for the dot operation from following Volta
version to following Ampere version.

Update code generator to produce two m16.n8.k8 MMAs for Turing instead
of one m16.n8.k16 MMA we have for Ampere.
2023-08-18 22:20:17 -07:00
Thomas
23ef2615d2 [BACKEND] Merge TT_ElementwisePureExtern and TT_ElementwiseImpureExtern (#2137)
Use getEffect instead to tell passes whether the op has side effects or
not. This doesn't change functionality otherwise.
2023-08-18 20:56:10 +00:00
Thomas
bf351b9ba2 [FRONTENT][BACKEND] Add support for elementwise inline assembly (#2136)
Add a new operation to be able to implement packed inline assembly for
elementwise operations. This way inline assembly can be used to control
elementwise operations. It also allows to pack elements to be able to
manually vectorize operations.
2023-08-18 12:57:52 -07:00
Thomas
c736ea8492 [BACKEND] Minor clean up and remove loop fixup as it is not needed anymore (#2116) 2023-08-18 12:12:45 -07:00
Whitney Tsang
100cabd0e4 [FRONTEND] use enum instead of bool to select target (#2118)
Before this PR, the determination of `TritonGPUToLLVMIRPass` to generate
NVVM-compatible LLVM or ROCDL-compatible LLVM is controlled by a boolean
`isROCM`. This method is hard to scale.
This PR changes it to use an enum instead, where new target can be added
easily when needed.

---------

Signed-off-by: Tsang, Whitney <whitney.tsang@intel.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-08-17 18:37:09 -07:00
Lixun Zhang
e924e0ecf6 [NFC] Rename MMAv2kWidth to kWidth (#299) 2023-08-17 07:45:22 -05:00
Lixun Zhang
eb940823c4 [OPTIMIZER][BACKEND] Rename MMAv2kWidth to kWidth (#2119) 2023-08-17 04:00:20 +00:00
Lixun Zhang
87e45cb011 Set vecSize and maxPhase more generically 2023-08-16 08:30:32 -05:00
Lixun Zhang
7156fcb0ef Set vecSize = 4 and maxPhase = BLOCK_K/4 2023-08-16 08:30:32 -05:00
darkbuck
a3df6068b4 [BACKEND] Minor fixes found when building triton with LLVM 17/main branches (#2089)
- These minor fixes are not specific to interface changes from LLVM main
or official llvm-17 branch and can be applied on triton main branch.
- https://github.com/darkbuck/triton/tree/darkbuck/main/llvm-main-branch
has extra changes to build again LLVM main branch build to enable me to
work on other backends on the main branch only. That's the hobby effort
and just FYR.
2023-08-16 01:18:06 +00:00
Whitney Tsang
129e7dfc6f [TritonGPUToLLVM] Correct the usage of option passing (#2104)
For example, when given `--convert-triton-gpu-to-llvm="is-rocm=true"`,
`ConvertTritonGPUToLLVMPass` should generate ROCM-compatible LLVM.
Before this PR, transformation options passed in command line are not
respected.
2023-08-16 00:56:01 +00:00
Zahi Moudallal
a01c116f76 [FRONTEND/BACKEND] Revived Float8E4B15x4 (#2090) 2023-08-11 17:49:52 -07:00
Zahi Moudallal
4d373aa103 [BACKEND] Remove HopperHelpers.c and replace with inline ptx and LLVM codegen (#2047) 2023-08-10 15:52:37 -07:00
Alexander Efimov
a1f4ee6a6b [MFMA] Swizzled operands (#285)
This pr enables generation of swizzled tensors for mfma dot operands
2023-08-09 14:53:15 -05:00
allatit23
6dee55c912 [HOPPER][WS] fix TMA store hang in ws mode (#2056) 2023-08-08 19:53:52 +08:00
Thomas
98523bcc48 [BACKEND] Support MMA V3 with float16 accumulator (#2049)
Also fixes a bug exposed in convertLayout lowering for float16. We
shouldn't be using cvt.pack.sat.u16.s32 to pack 16bits values as this
needs to take a 32bits register. Also this prevented optimization at
llvm ir level.
2023-08-07 15:55:44 -07:00
goostavz
f1512bded1 Initial code merge of Hopper support (#2036)
The initial code merge of Nvidia Hopper features support. Please be
aware that the code merge is not finished yet and the trouble-shooting
is still ongoing. The new hardware features (GMMA, TMA, STMATRIX etc.)
and automatic warp-specialization are experimental for now and turned
off by default. It is recommended for a trial when version 3.0 is
released.

The work is contributed by:
ben-zhang-609, bealwang, donproc, qliu93, jsh20, allatit23, LyricZhao,
ivanyinwz, goostavz & yangjunpro
from Nvidia, in cooperation with:
ptillet, Jokeren, ThomasRaoux & zahimoud
from OpenAI.

Co-authored-by: Goostav Zhu <gzhu@nvidia.com>
2023-08-07 09:53:04 +08:00
Philippe Tillet
52c146f66b [OPTIMIZER][BACKEND] significantly cleaner handling of mixed-precision kernels (#1949)
we currently have a very janky approach to optimizing mixed-precision
matmul workloads, where some layout combinations (e.g., NT matmul) were
explicitly pattern-matched to take a more optimized codepath. Attempt at
unifying all the codepaths to codegen cp.async failed, due to bugs in
SharedToDotOperandMMAv2.cpp.

This PR fixes said bugs, add some assertions for SharedToDotOperandMMAv2
modes that aren't well supported, and greatly simplify our handling of
element-wise operations between load and conversions to DotOperand.
2023-07-28 10:29:42 -07:00
Alexander Efimov
0073bb98f4 [MFMA] Introduce dot operand loading fast path (#269)
* [MFMA] Introduce dot operand loading fast path

This PR introduces fast path for code generation of MFMA dot operand
loading from LDS.

Fast path is used when operand is not swizzled and is not slice of some
bigger LDS object(it is not a slice of a tensor).
This is a case for current FA and GEMM kernels compiled with
num_stages=1, i.e. software pipelining is disabled.

* cleanup swizzle info
2023-07-27 13:46:50 -05:00
Shucai Xiao
356595f3dc revert PR#245 and #241 for gemm performance regression (#265)
* Revert "[MFMA][Dot] Swizzle mfma dot operands  (#245)"

This reverts commit 4d0deef45f.

* revert the swizzle and slice PRs for gemm performance regression

* fix build error after reverting the PR#245(commit 4d0deef45f) Swizzle mamf dot operands
2023-07-24 14:09:46 -05:00
Philippe Tillet
1db3bdc52e [BACKEND] avoid code duplication for fully warp-synchronous reductions (#1978) 2023-07-21 16:06:00 -07:00
Thomas
77ceb4ddb5 [BACKEND] Avoid circular dependency between Triton and TritonGPU (#1938) 2023-07-13 08:40:05 +02:00
oplavsic
d6e51fd221 [FA OPTIMIZATION] Keep results of FA dot operations in registers (#247)
* [WIP][FA OPTIMIZATION] Optimize chain dot

This commit optimizes chain dot operation by keeping
results of the first dot operation in registers.

* [FA OPTIMIZATION] Enable lowering pipeline for keeping result of chain dot in registers

* Move operand swapping in ttgir -> llir lowering phase

* Refactor emitMfmaOffsetForCTA function to be more readable

* Fix accidental change in 06-fused-attention.py

* Address review comments

* Fix rebase errors
2023-07-12 15:25:55 -05:00
Alexander Efimov
4d0deef45f [MFMA][Dot] Swizzle mfma dot operands (#245)
* swizzling when loading dot operands

* [MFMA][Dot] Swizzle mfma dot operands

This PR supports swizzling in MFMA dot operands.

* fix comments

* Update TritonGPUAttrDefs.td

* Update TritonGPUAttrDefs.td 2

---------

Co-authored-by: weihanmines <wei.han3@amd.com>
2023-07-12 13:19:31 -05:00
jayfurmanek
21c757564a Merge branch 'triton-mlir' into ifu230620 2023-07-12 11:29:23 -05:00
Alexander Efimov
ff3751682e [MFMA] Parameterize MFMA layout with non k dimension size (#248)
This PR adds nonKDim attribute to MFMA layout to align with dot-rocMLIR branch and adds description of mfma16 layout.
2023-07-11 14:48:52 +02:00
peterbell10
e3d9478d31 [OPTIMIZER] Add pass to move broadcasts after elementwise operations (#1811)
This adds a pass that tries to reduce the shape of tensor arguments to
element-wise operations by moving splat and broadcast operations later
in the graph. So, for example say we have:

```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset  + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp2 = 0.017453292519943295
    tmp3 = tmp1 * tmp2
    tmp4 = tl.sin(tmp3)
    tl.store(out_ptr0 + (x0), tmp4, None)
```

Today this results in duplicate `sin` calls:
```
    %27 = llvm.fmul %26, %3  : f32
    %28 = llvm.call @__nv_sinf(%27) : (f32) -> f32
    %29 = llvm.call @__nv_sinf(%27) : (f32) -> f32
```

The duplicate `llvm.fmul` calls are eliminated via CSE, but `llvm.call`
doesn't get CSE'd because it might be impure.

After this change, the sin is done on a scalar value in the triton IR
and splatted at the very end, so no duplicate calculation happens within
a thread.

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-10 11:44:38 -07:00
Thomas
bd900e0a6f [BACKEND] Fix reductions when number of unique element is smaller than layout (#1913)
Fix calculation of unique number of threads within a warp. We need to
consider the number of elements per thread in the calculation. Also
change the layout test to integer sum in order to catch bugs with unique
data as max reduction may hide those kind of problems.
2023-07-07 19:48:13 -07:00
Keren Zhou
cc5a7ed52f [FRONTEND][BACKEND] Materialize line info for triton kernels (#1902)
`export TRITON_DISABLE_LINE_INFO=1` to disable the feature.
2023-07-07 16:03:44 -04:00
Philippe Tillet
6d1285e1ae [FRONTEND][BACKEND] improved fp8 specs (#1906)
This un-reverts commit
d4c941177e.
2023-07-06 13:03:53 -07:00
Christian Sigg
3442904a92 [BACKEND] Avoid circular dependencies (#1877)
Recent changes made TritonGPU dialect depend on transform utils
(`isExpensiveCat()`), and Triton ops depend on TritonGPU dialect
(`DotOperandEncodingAttr`). This works fine with CMake but circular
dependencies are not ideal and Bazel builds (which we use internally at
Google) try hard to prevent them.

Would it be acceptable to move the `isExpensiveCat()` function back to
TritonGPU dialect (where it was before), and split the TritonGPU
attributes into a separate header? This would avoid diverging our
internal version or creating over-sized bazel targets to avoid circular
dependencies.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-07-05 16:08:51 -04:00
Jason Furmanek
12005a82f2 Initial commit to resolve merge conflicts 2023-06-30 19:53:53 +00:00
Thomas
2e3182bab7 [BACKEND] Support scan on dimensions other that fastest moving one (#1863)
This relax the restriction in the scan lowering to support layout where
we scan along a dimension which isn't the fastest moving one. This is
done by relaxing how we accesses elements during scanning and allow
elements to be strided.
2023-06-30 12:40:48 -07:00
Philippe Tillet
f77015967d Revert "[FRONTEND][BACKEND] improved fp8 specs (#1841)" (#1865)
This reverts commit d4c941177e.
2023-06-29 21:07:01 -04:00
Jason Furmanek
2b38ab4b6c Merge remote-tracking branch 'oai/main' into ifu230620
Conflicts:
	include/triton/Conversion/TritonToTritonGPU/Passes.td
	include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	python/test/unit/language/assert_helper.py
	python/triton/compiler/compiler.py
	python/triton/runtime/jit.py
	python/triton/tools/aot.py
	test/Conversion/triton_to_tritongpu.mlir
	test/Conversion/tritongpu_to_llvm.mlir
2023-06-29 21:47:27 +00:00
Thomas
3be060849a [FEATURE] Add associative_scan support (#1858)
Implement associative_scan in the front end and implement lowering to
LLVM for blocked layout where the scan happens on the fastest moving
dimension. This will later be generalized to support more layout.
2023-06-29 14:37:51 -07:00
Xinya Zhang
36f7c9872a [ROCM private] Change the default warp size to 64 for IR compatibility for existing MI GPUs.
CAVEAT: This commit is supposed to be a custom patch in ROCM/triton fork.
        Think twice before submitting this commit as PR to upstream.

MI Developers have collected a large set of IR files and use triton
command line tools for development extensively, where warp size == 64 is
assumed for AMD GPUs. Unfortunately the behavior of the compiler after the TTGIR pass
depends on the warp size property, and changing the default value will
make existing IRs unusable for MI GPUs.

This commit aims to preserve the old behavior when warp size is not specified in TTGIR.
For general Triton users it should have zero effects since warp size is
always set explicitly in compiler.py to match the target architecture.

Additionally this commit reverts part of the upstream change to maintain
the unit tests for wave64 architectures.
2023-06-28 22:25:14 -05:00
Xinya Zhang
75b86da598 Add configurable wavefront size support for Navi/MI.
[To squash] Configurable warp size in test_core_amd.py::test_convert2d

Note: test_core_amd.py::test_convert2d unit tests have been changed
because some of the old layouts exceed the shared memory limit (64KiB)
2023-06-28 22:25:14 -05:00