Commit Graph

1164 Commits

Author SHA1 Message Date
Izzy Putterman
9961b5c7aa [TESTING] allow user to adjust warmup and repetition time for autotuning (#1850)
Adds an option to adjust warmup and repetition time for autotuning. It
should default to old values and have no effect on current kernels.
This is useful for bigger kernels where the runtime might be a sizable
fraction of 100ms, leading to less warmup and more variance during
benchmarking.
2023-06-28 11:04:43 -07:00
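The budget-based warmup/repetition idea above can be sketched in plain Python (a hypothetical `bench` helper, not Triton's actual benchmarking code): the longer a single run takes, the fewer warmup and measurement iterations fit into fixed `warmup_ms`/`rep_ms` budgets, which is why larger budgets help bigger kernels.

```python
import time

def bench(fn, warmup_ms=25, rep_ms=100):
    # Estimate the cost of one call to derive iteration counts.
    t0 = time.perf_counter()
    fn()
    estimate_ms = (time.perf_counter() - t0) * 1000
    # Cap iteration counts so tiny functions don't spin forever.
    n_warmup = min(1000, max(1, int(warmup_ms / max(estimate_ms, 1e-6))))
    n_rep = min(1000, max(1, int(rep_ms / max(estimate_ms, 1e-6))))
    for _ in range(n_warmup):  # warmup runs are not timed
        fn()
    times = []
    for _ in range(n_rep):
        t0 = time.perf_counter()
        fn()
        times.append((time.perf_counter() - t0) * 1000)
    return min(times)  # best observed time, in ms
```

With a fixed `rep_ms=100`, a kernel taking tens of milliseconds per run gets only a handful of measurements; raising the budgets restores statistical stability.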
Thomas
e5d7411a69 [BACKEND] Add .wt store cache modifier (#1831) 2023-06-28 17:40:30 +00:00
Goran Flegar
938a6754b4 [BUILD] Export compile commands (#1854)
This can be used by IDEs to figure out how to correctly compile
individual sources and offer semantic code completion.
2023-06-28 14:11:59 +00:00
Keren Zhou
d2de3f37f0 [BACKEND] Reduce code cleanup and bug fix for the fast path (#1816)
https://github.com/openai/triton/issues/1715
2023-06-27 17:27:24 -07:00
Zahi Moudallal
2dcbf4783e [BACKEND] Use getOrder for mma layout warps order instead of the hardcoded col-major order (#1825) 2023-06-27 10:56:09 -07:00
Philippe Tillet
d4c941177e [FRONTEND][BACKEND] improved fp8 specs (#1841)
clearly differentiate between standard fp8e4 (which we'll stop
supporting on SM <= 89 because conversions are too expensive if we want
to handle the single NaN and clipping properly) and a software-optimized
fp8e4b15 format.
2023-06-26 16:19:03 -07:00
Alexander Efimov
d830164699 [Dot] Support slicing in dot operands (#241)
* [Dot] Support slicing in dot operands

This PR enables support of sliced tensors of dot operands.
One of the consequences is that we can re-enable the prefetch and SW pipelining optimizations.

* refactor to comply with next changes with swizzling support

* comment fixes for code review
2023-06-26 12:55:34 -05:00
Izzy Putterman
3c400e7818 [FRONTEND] switch absolute imports to relative v2 (#1833) 2023-06-26 04:13:12 +00:00
Wang Weihan
a3c39d8fbe [TEST] Add device parameter for ut (#1817)
Triton now supports different codegen backends for different devices,
so enabling the unit test cases to support different devices also makes
sense. Otherwise, the third-party backend might have to intrusively
change the Triton test cases.
2023-06-25 15:38:59 +08:00
Thomas
3d1cd89b54 [BACKEND] Add store cache modifiers (#1826)
Plumb through store cache modifiers.
2023-06-23 09:29:10 -07:00
Zahi Moudallal
6ad8cd52e7 [CI] Added IR reference-check github workflow (#1755) 2023-06-22 18:00:40 -07:00
Wang Weihan
4d3a92f1b8 [BUILD] Make sure always build_ext first (#1819)
The third-party backend might install its python package to the
`triton/third_party` python package during the build process. But the
`build_py` could be executed before the `build_ext`, and then `build_py`
would only copy the `packages` defined in the `setup.py` w/o the
third-party related packages as the third-party backend has not been
built, which is triggered by `build_ext`. Therefore, this PR refined the
build order a little bit to ensure `build_ext` always happens before
`build_py`.
2023-06-22 13:32:03 -07:00
Zahi Moudallal
ca4f242c9b [TEST] Added matmul config for testing (#1758) 2023-06-22 13:31:37 -07:00
Goran Flegar
8d566e4196 [FRONTEND] Fix missing attribute access in DependenciesFinder (#1820)
It seems that patch #1773 introduced a bug, since the `lhs` object
doesn't necessarily have a `__name__` attribute.

I'm hitting this if I modify the matmul tutorial
(gflegar/triton@442b00f4d):

```
File "/home/gflegar/triton/python/triton/runtime/jit.py", line 74, in visit_Attribute
  if lhs is None or lhs.__name__ == "triton":
AttributeError: 'Tensor' object has no attribute '__name__'
```

I think the idea of that patch was to remove the need to import triton
by replacing `lhs is triton` with `lhs.__name__ == "triton"`. This patch
should have the same behavior as the original code, but without failing
if `lhs` doesn't have a `__name__` attribute.
2023-06-22 13:30:25 -07:00
Izzy Putterman
5686c51cdb [FRONTEND] allow pre-hook in autotuner configs to access config kwargs (#1814)
This is a very quick change that allows the configs' pre-hooks to see
the values in the config itself. This is useful if we'd like to allocate
an intermediate tensor whose shape depends on the tile size.
2023-06-22 05:40:48 -07:00
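The mechanism can be sketched as follows (illustrative names, not Triton's exact autotuner API): the config's own kwargs are merged into the arguments passed to its pre-hook, so the hook can size intermediate tensors from the tile size.

```python
# Hypothetical sketch of an autotuner config whose pre_hook can see the
# config's own kwargs.
class Config:
    def __init__(self, kwargs, pre_hook=None):
        self.kwargs = kwargs          # e.g. tile sizes
        self.pre_hook = pre_hook

    def run_pre_hook(self, call_args):
        if self.pre_hook is not None:
            # Pass the merged dict so the hook can read tile sizes and
            # allocate intermediates whose shape depends on them.
            self.pre_hook({**call_args, **self.kwargs})

allocated = {}

def hook(args):
    # The scratch buffer's size depends on the BLOCK_M tile size.
    allocated["scratch_rows"] = args["BLOCK_M"]

cfg = Config({"BLOCK_M": 128}, pre_hook=hook)
cfg.run_pre_hook({"M": 4096})
```

Before this change, the hook would only see the kernel call arguments (`M` here), not the config's `BLOCK_M`.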
Philippe Tillet
0d6cd0307a [FRONTEND] add tie_break_left option to arg-reductions (#1813) 2023-06-21 19:35:52 -07:00
Philippe Tillet
4c0e3d907e [TOOLS] improved ahead-of-time compiler (#1805)
This is a revival of @gaxler initial ahead-of-time compiler proposal.
Code was simplified and some constraints were relaxed (i.e., we now
execute the entire file provided vs just the kernel AST) to promote
maintainability. A basic unit test was added, though it does not test
specialization right now.

co-authored by: Gregory Axler, thexler <g.axler@gmail.com>
2023-06-21 01:02:58 -07:00
Christopher Hesse
767dd5b5c1 [DOCS] update debug docs (#1790)
Haven't tested the docs (they don't seem to compile on my machine when I
use the makefile)
2023-06-20 20:45:17 -07:00
Thomas
4be1c94b1f Enable IR dumping for the ttir to ttgir phase. (#1804)
ttir_to_ttgir was missing enable_debug to be able to dump IR.
2023-06-20 17:17:18 +00:00
ZhenLei Xu
326061efd4 Remove Duplicate Definition of 'asm' Variable in triton/compiler/compiler.py (#1803) 2023-06-19 22:28:28 -07:00
Xinya Zhang
a14b1e2cc4 Add Navi3x bitcode files. (#239)
Ideally Triton should be able to locate bitcode files under /opt/rocm/ (or
$ROCM_HOME).

Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
2023-06-19 10:30:47 -05:00
Keren Zhou
1851c8ca99 [FRONTEND] Fix binary compare op on constexprs (#1801)
Example:

```
if static_a == 0 and static_b == 1:
    ...
```

The return value of `static_a == 0` should be `constexpr(True)` rather
than a plain `True`, since the bool object (True/False) doesn't have the
`logical_and` method.
2023-06-18 20:27:56 -07:00
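The issue can be illustrated with a minimal sketch (not Triton's actual `constexpr` class): if `==` on a compile-time constant returns a wrapper instead of a bare bool, the frontend can keep chaining `logical_and` when it lowers `and` between constexprs.

```python
class constexpr:
    """Toy compile-time constant wrapper, for illustration only."""
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        other = other.value if isinstance(other, constexpr) else other
        # Wrap the result instead of returning a bare bool, so methods
        # like logical_and remain available on the comparison result.
        return constexpr(self.value == other)

    def logical_and(self, other):
        return constexpr(bool(self.value) and bool(other.value))

static_a, static_b = constexpr(0), constexpr(1)
# Mirrors how `static_a == 0 and static_b == 1` is lowered:
cond = (static_a == 0).logical_and(static_b == 1)
```

Had `__eq__` returned `True` directly, the `logical_and` call would raise `AttributeError`, which is the bug the commit fixes.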
oplavsic
64d7b521cf [MFMA] Enabled fused attention forward pass. (#226)
* [MFMA] Activated Fused Attention Forward Pass

Patch contains following changes:
1) make_range operator now works with MFMA layout.
2) Reduce operation is forced to run in block layout:
   inputs converted to block layouts, outputs returned to MFMA layout

* Use simple module walk instead of pattern rewriter.

* Remove pattern rewriter header.

* Enable basic reduce algorithm for MFMA layout

* Add TODO comment for fused attention backward pass

* Fix bug in fast codegen algorithm for reduce op

* Fix input type bug

* Increase block size to 128 since the out-of-memory issue is not seen on MI210

* Fix block_size error

* Add mfma support in DecomposeDotOperand pattern.
2023-06-16 15:39:08 -05:00
Christopher Hesse
981e98a213 [FRONTEND] update assert_helper.py (#1789) 2023-06-15 16:24:30 -07:00
Philippe Tillet
9a2580de13 [CI] Added H100 node (#1779) 2023-06-15 14:21:47 -07:00
Philippe Tillet
b24dc19741 [FRONTEND] cleaned up symbol names (#1782) 2023-06-14 18:55:32 -07:00
Izzy Putterman
71e21f5797 [FRONTEND] switch absolute imports to relative imports in Triton (#1773) 2023-06-14 23:59:24 +00:00
Philippe Tillet
754306b5d7 [FRONTEND] TRITON_PTXAS_PATH can now take in options (#1778) 2023-06-14 14:03:03 -07:00
Zahi Moudallal
ac15d00ef4 [TEST] Added f8xf16 tests (#1771) 2023-06-12 16:14:17 -07:00
Wang Weihan
b27a91a113 [FRONTEND] Enable triton to support register thirdparty backend at runtime (#1643)
This PR intends to provide a mechanism to support a third-party backend
at runtime to generate the backend-specific code.

The mechanism provided a common class to abstract the third-party
backend logic and two essential functions to register and get the
third-party backend at runtime.

- `BaseBackend`: A common class to abstract the third-party backend
logic
- `register_backend`: Register a third-party backend with a given device
type
- `get_backend`: Get the third-party backend with a given device type

Generally, a third-party backend must inherit from `BaseBackend` and
implement all the member functions according to the backend
characteristics. As long as the backend implementation is ready, the
third-party backend can invoke `register_backend` to register it under a
given device. During the kernel compilation and execution, the mechanism
will get the registered backend to generate the kernel and launcher code
for a given device.

This PR added a dummy backend to simulate a third-party backend and
demonstrate the usage.

- [test_device_backend.py](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1): To define a third-party backend and register the backend
- [ExtensionBackend](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R123): Inherit from the `BaseBackend` and implement some specific logic like [filter out some compile stages](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R129-R135)
- [Register the `ExtensionBackend` for `CPU`](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R279)
- [extension_backend.c](https://github.com/openai/triton/pull/1643/files#diff-169c1d08b3a0a7b343cfa3258fbc32b47e0f6c46305a112652fa1bdaaec89d29): To provide the utility function to load the kernel binary and get the backend properties.
2023-06-09 09:09:59 -07:00
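The registration mechanism described above can be sketched like this (the names follow the PR's description, but the details are simplified): a registry maps a device type to a backend class, and compilation looks the backend up by device.

```python
class BaseBackend:
    """Common class abstracting third-party backend logic (sketch)."""
    def __init__(self, device_type):
        self.device_type = device_type

    def compile(self, src):
        raise NotImplementedError

# Registry mapping device type -> backend class.
_backends = {}

def register_backend(device_type, backend_cls):
    """Register a third-party backend for a given device type."""
    _backends[device_type] = backend_cls

def get_backend(device_type):
    """Get the backend registered for a given device type, if any."""
    cls = _backends.get(device_type)
    return cls(device_type) if cls is not None else None

class ExtensionBackend(BaseBackend):
    """Dummy backend standing in for a real third-party implementation."""
    def compile(self, src):
        return f"compiled[{src}] for {self.device_type}"

register_backend("cpu", ExtensionBackend)
backend = get_backend("cpu")
```

During kernel compilation, the runtime would call `get_backend(device_type)` and delegate code generation to whatever backend was registered.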
jayfurmanek
29f93b147b Merge pull request #229 from ROCmSoftwarePlatform/ifu230601
IFU 230601
2023-06-09 07:55:32 -05:00
jayfurmanek
49ef240047 get Arch Info using HSA (#233) 2023-06-08 14:30:46 -05:00
Sophia Wisdom
2fb56dcf52 [FRONTEND] improve tl.dot error messages (#1757)
This makes it easier to figure out what's going on when you get these
error messages.
2023-06-07 19:39:16 -07:00
Keren Zhou
4fbadf6f6f [BACKEND] Fix tl.cat when the number of threads > the size of a tensor (#1751)
`tl.cat(tensor<64>, tensor<64>) -> tensor<128>`: because `cat` concatenates
elements within a single thread, if the number of threads is 128 each
thread should own at least 2 elements.
With this PR, we also disable remat of the cat op in some cases.
2023-06-07 15:42:38 -07:00
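The constraint can be stated as a small arithmetic check (a sketch of the rule described above, not the backend's actual layout code): because `cat` packs elements from both inputs into the same thread, the output layout must give each participating thread at least two elements, even when a one-element-per-thread layout would otherwise fit.

```python
def min_elems_per_thread_for_cat(out_size, num_threads):
    # Naive layout: distribute output elements evenly (ceil division).
    naive = -(-out_size // num_threads)
    # cat interleaves pairs of input elements within one thread, so the
    # layout must provide at least 2 output elements per thread.
    return max(2, naive)

# cat(tensor<64>, tensor<64>) -> tensor<128> with 128 threads:
# the naive 1-element-per-thread layout is illegal; 2 are required.
```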
Aleksandr Efimov
0a12031c75 [Triton] Fix MFMA dot operand loading
This PR fixes computation of indexes of MFMA dot operands and gives variables more informative names.
2023-06-07 21:30:52 +02:00
Zahi Moudallal
1db570f45c [FRONTEND] Added arch to the json (#1747) 2023-06-06 17:06:03 -07:00
Edward Z. Yang
f294a18864 [FRONTEND] force quantile tensors to be float; prevents accidents (#1741)
In particular, sometimes this was failing with:

```
RuntimeError: quantile() input tensor must be either float or double dtype
```

Fixes https://github.com/pytorch/pytorch/issues/103054

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

---------

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
2023-06-05 20:55:40 -07:00
Keren Zhou
1eedef3596 [FRONTEND] Fix next_power_of_2 for large integers (#1740)
test case: `next_power_of_2(2**33 + 1)`
2023-06-05 20:29:05 -07:00
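An integer-only formulation that handles arbitrarily large inputs can be sketched as follows (an illustrative re-implementation, not necessarily the commit's exact fix): computing the result via `int.bit_length` avoids the precision loss that float-based approaches (e.g. via `log2`) hit for integers above 2**53.

```python
def next_power_of_2(n: int) -> int:
    # bit_length works on arbitrary-precision ints, so there is no
    # floating-point rounding error even for very large values.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()
```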
Philippe Tillet
c52a91231a [FRONTEND][BACKEND] Add acquire/release semantics for atomics (#1739) 2023-06-05 19:09:13 -07:00
Jason Furmanek
0497f95982 [ROCM] Fix assert helper 2023-06-05 21:42:44 +00:00
Philippe Tillet
9c8d7c18b3 [FRONTEND] simpler/faster argmax code (#1736)
@peterbell10 @Jokeren I believe that the current argmax reduction code
is overly complicated. Am I missing something here?
2023-06-04 21:30:55 -07:00
Philippe Tillet
6c1992cb38 [FRONTEND] min/max now accept return_indices argument (#1731)
Not just syntactic sugar for successive max + argmax but also avoids
computing the max twice
2023-06-02 22:01:02 -07:00
chengjunlu
45ba9af6ed [BACKEND] Add a configurable parameter for the number of threads per warp (#1719)
Add a configurable parameter for the number of threads per warp for
other GPUs, such as Intel GPUs.

It defaults to 32 and does not change the code logic on CUDA/AMD GPUs.

Note: The Intel GPU GenX ISA is explicit SIMD and can support a variable
number of thread lanes per HW execution unit.
2023-06-02 16:55:06 -07:00
Sharad Vikram
acf1ede5bf [DEBUGGER] Use from __future__ import annotations to avoid importing torch from annotations (#1722)
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-06-01 14:57:17 -07:00
jayfurmanek
153ed472b8 Merge branch 'triton-mlir' into ifu230601 2023-06-01 16:18:25 -05:00
Daniil Fukalov
6be1dce41c [ROCM] Fix transposed operands processing in dot operation with MFMA. (#227)
- Applied to `loadA()` the same fix as 2c88ed6aab9ace22ccde1f0e443a1579727ee501.
- Minor cleanup of `mfmaLayout.getWarpsPerCTA()` usage.

Partially fixes ROCmSoftwarePlatform/frameworks-internal#4545
2023-06-01 23:08:58 +02:00
Jason Furmanek
56c55e7451 Initial commit to resolve merge conflicts 2023-06-01 20:58:37 +00:00
Jason Furmanek
28d9754b2a Merge remote-tracking branch 'oai/main' into ifu230601
Conflicts:
	python/test/unit/language/assert_helper.py
	test/Conversion/tritongpu_to_llvm.mlir
2023-06-01 20:53:33 +00:00
Keren Zhou
1e171bf270 [BACKEND] Pipeline pass rewrite part 1: functionality fixes (#1716)
Support the following three cases:
1. Operands of `load` depend on induction variables before `load`s.
2. Mixed use of induction variables and offset to update the `ptr`.
3. Cross iteration (>1) dependency values.
2023-06-01 12:07:43 -07:00
Mehdi Amini
440fd1bf20 [TESTS] Increase the payload of the globaltimer kernel to reduce chances of flakiness (#1726)
If the kernel is too small, on a very fast GPU we may get 0 because the
resolution of the timer seems too coarse.

Fixes #1725
2023-06-01 02:53:07 -07:00