Commit Graph

820 Commits

Author SHA1 Message Date
Philippe Tillet
dadf7a9a50 [TUTORIAL] Faster flash attention; added non-causal (#1917) 2023-07-09 13:38:06 -07:00
Thomas
bd900e0a6f [BACKEND] Fix reductions when number of unique element is smaller than layout (#1913)
Fix calculation of the unique number of threads within a warp. We need to
consider the number of elements per thread in the calculation. Also
change the layout test to an integer sum in order to catch bugs with unique
data, as a max reduction may hide those kinds of problems.
2023-07-07 19:48:13 -07:00
Natalia Gimelshein
778ed64a66 [BACKEND] make sure we always bind to primary context in loadBinary (#1912) 2023-07-07 14:28:03 -07:00
Bert Maher
38d767ea93 [FRONTEND] fix memory leak caused by retaining args to autotuned kernel (#1911) 2023-07-07 20:58:29 +00:00
Keren Zhou
cc5a7ed52f [FRONTEND][BACKEND] Materialize line info for triton kernels (#1902)
`export TRITON_DISABLE_LINE_INFO=1` to disable the feature.
2023-07-07 16:03:44 -04:00
Philippe Tillet
61e17db4f7 [TESTS] smaller problem sizes in matmul tests (#1908) 2023-07-06 14:36:02 -07:00
Philippe Tillet
6d1285e1ae [FRONTEND][BACKEND] improved fp8 specs (#1906)
This un-reverts commit
d4c941177e.
2023-07-06 13:03:53 -07:00
Thomas
787cdff0cd [TESTS] Enable parallel pytest in CI for CUDA (#1905)
Run most of the pytest suite in parallel; this speeds up CI from
36 min to 10 min on A100 and from 22 min to 6 min on H100. Some tests,
such as the runtime tests, still need to run serially.
2023-07-06 11:40:33 -07:00
Thomas
ae0ee5248f [FRONTEND] Add cumprod scan op (#1894)
Add and test cumprod. This also allows testing a case of accumulation
where the identity element is not 0.
Also add documentation for scan functions.
2023-07-05 10:09:06 -07:00
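As a plain-Python sketch (not the Triton kernel itself), the inclusive cumulative-product semantics being added, whose identity element is 1 rather than sum's 0:

```python
import operator
from itertools import accumulate

def cumprod_ref(xs):
    # Reference semantics for an inclusive cumulative product:
    # out[i] = xs[0] * xs[1] * ... * xs[i].
    return list(accumulate(xs, operator.mul))

# The identity element for product is 1, not 0, so this exercises
# an accumulation whose identity differs from sum's.
print(cumprod_ref([2, 3, 4]))  # -> [2, 6, 24]
```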
Philippe Tillet
d57dcd9994 [FRONTEND] pattern-match ExpandDims + Mul + Reduce into DotOp (#1889)
transforms e.g.:
```
x1 = tl.expand_dims(x0, axis=2)
y1 = tl.expand_dims(y0, axis=0)
z = tl.sum(x1 * y1, axis=1)
```

into
```
z = tl.dot(x0, y0) 
```

uses allowTF32 = True by default.
2023-07-05 00:31:21 -07:00
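A plain-Python sketch of why the rewrite is sound: broadcasting the expanded operands to shape (M, K, N), multiplying, and summing over the shared K axis computes exactly what `tl.dot` computes (helper names here are illustrative, not Triton API):

```python
def dot_ref(x, y):
    # z[i][j] = sum_k x[i][k] * y[k][j], what tl.dot computes.
    M, K, N = len(x), len(y), len(y[0])
    return [[sum(x[i][k] * y[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]

def expand_mul_reduce(x, y):
    # The matched pattern: expand x to (M, K, 1) and y to (1, K, N),
    # broadcast-multiply to a (M, K, N) tensor, then sum over axis 1.
    M, K, N = len(x), len(y), len(y[0])
    prod = [[[x[i][k] * y[k][j] for j in range(N)] for k in range(K)]
            for i in range(M)]
    return [[sum(prod[i][k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]

x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
assert expand_mul_reduce(x, y) == dot_ref(x, y)  # == [[19, 22], [43, 50]]
```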
Yongjik Kim
387cdc8fe9 [FRONTEND] improve error message. (#1887) 2023-07-05 00:19:28 -07:00
Christopher Hesse
cc93356c4b [DOCS] update print docs (#1884) 2023-07-05 00:19:12 -07:00
Keren Zhou
16de0d6f55 [BACKEND] Recover the correct axis in the parent layout for reduce ops (#1885)
https://github.com/openai/triton/issues/1883
2023-07-05 00:17:00 -07:00
Keren Zhou
aa55d3b5bf [FRONTEND] Remove unnecessary replace in while op's after block (#1886)
We've already updated the mapping between names and tensors before
visiting each compound statement in the while op. As a result, any
overwritten name holds an up-to-date value inside the while loop, and
any unchanged live-in names keep their original tensors.
2023-07-05 00:16:38 -07:00
Ying Hang Eng
55eb32dff9 Fix tl.device_assert compilation error (#1875)
As mentioned in #1769, we set the file name and function name to
'unknown' and lineno to 0 if the frame is None.
2023-06-30 22:19:25 +00:00
Thomas
2e3182bab7 [BACKEND] Support scan on dimensions other that fastest moving one (#1863)
This relaxes the restriction in the scan lowering to support layouts
where we scan along a dimension that isn't the fastest-moving one. This
is done by relaxing how we access elements during scanning, allowing
elements to be strided.
2023-06-30 12:40:48 -07:00
Oleg Shyshkov
66ed53d19d [FRONTEND] Support mixed-precision inputs in triton.ops.matmul. (#1754)
Only combinations of float32 with float16 or bfloat16 are supported for
now. This shouldn't change anything when the input types match.

That's a follow-up to the comment in my other PR:
https://github.com/openai/triton/pull/1746#issuecomment-1579630016.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-06-30 09:22:27 -07:00
Thomas
7a8a2da8ef [BACKEND] Enable lowering of f16 constant matmul (#1870)
Since the type expected for the mma encoding is i32, when lowering an
f16 splat we need to pack f16 constants into an i32 value. This allows
re-enabling the constant matmul unit test.
2023-06-30 07:00:25 -04:00
Philippe Tillet
f77015967d Revert "[FRONTEND][BACKEND] improved fp8 specs (#1841)" (#1865)
This reverts commit d4c941177e.
2023-06-29 21:07:01 -04:00
Thomas
3be060849a [FEATURE] Add associative_scan support (#1858)
Implement associative_scan in the front end and implement lowering to
LLVM for the blocked layout where the scan happens on the fastest-moving
dimension. This will later be generalized to support more layouts.
2023-06-29 14:37:51 -07:00
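A plain-Python reference for the semantics of an inclusive associative scan (a sketch of the operation's meaning, not the LLVM lowering): the combine function must be associative for any parallel implementation to be valid.

```python
def associative_scan_ref(xs, combine):
    # Inclusive scan: out[0] = xs[0]; out[i] = combine(out[i-1], xs[i]).
    # `combine` must be associative so a parallel lowering can regroup
    # the applications without changing the result.
    out = []
    acc = None
    for i, x in enumerate(xs):
        acc = x if i == 0 else combine(acc, x)
        out.append(acc)
    return out

print(associative_scan_ref([1, 2, 3, 4], lambda a, b: a + b))  # -> [1, 3, 6, 10]
print(associative_scan_ref([1, 2, 3, 4], lambda a, b: a * b))  # -> [1, 2, 6, 24]
```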
Izzy Putterman
9961b5c7aa [TESTING] allow user to adjust warmup and repetition time for autotuning (#1850)
Adds an option to adjust warmup and repetition time for autotuning. It
defaults to the old values and has no effect on current kernels.
This is useful for bigger kernels whose runtime might be a sizable
fraction of 100 ms, leading to less warmup and more variance during
benchmarking.
2023-06-28 11:04:43 -07:00
Thomas
e5d7411a69 [BACKEND] Add .wt store cache modifier (#1831) 2023-06-28 17:40:30 +00:00
Goran Flegar
938a6754b4 [BUILD] Export compile commands (#1854)
This can be used by IDEs to figure out how to correctly compile
individual sources and offer semantic code completion.
2023-06-28 14:11:59 +00:00
Keren Zhou
d2de3f37f0 [BACKEND] Reduce code cleanup and bug fix for the fast path (#1816)
https://github.com/openai/triton/issues/1715
2023-06-27 17:27:24 -07:00
Zahi Moudallal
2dcbf4783e [BACKEND] Use getOrder for mma layout warps order instead of the hardcoded col-major order (#1825) 2023-06-27 10:56:09 -07:00
Philippe Tillet
d4c941177e [FRONTEND][BACKEND] improved fp8 specs (#1841)
clearly differentiate between standard fp8e4 (which we'll stop
supporting on SM <= 89 because conversions are too expensive if we want
to handle the single NaN and clipping properly) and a software-optimized
fp8e4b15 format.
2023-06-26 16:19:03 -07:00
Izzy Putterman
3c400e7818 [FRONTEND] switch absolute imports to relative v2 (#1833) 2023-06-26 04:13:12 +00:00
Wang Weihan
a3c39d8fbe [TEST] Add device parameter for ut (#1817)
Triton supports different codegen backends for different devices, so the
unit test cases should also support different devices. Otherwise, a
third-party backend might have to intrusively change the Triton test
cases.
2023-06-25 15:38:59 +08:00
Thomas
3d1cd89b54 [BACKEND] Add store cache modifiers (#1826)
Plumb through store cache modifiers.
2023-06-23 09:29:10 -07:00
Zahi Moudallal
6ad8cd52e7 [CI] Added IR reference-check github workflow (#1755) 2023-06-22 18:00:40 -07:00
Wang Weihan
4d3a92f1b8 [BUILD] Make sure always build_ext first (#1819)
A third-party backend might install its Python package into the
`triton/third_party` package during the build process. But `build_py`
could be executed before `build_ext`, in which case `build_py` would
only copy the `packages` defined in `setup.py`, without the third-party
packages, since the third-party backend (built by `build_ext`) would not
exist yet. Therefore, this PR refines the build order slightly to ensure
`build_ext` always runs before `build_py`.
2023-06-22 13:32:03 -07:00
Zahi Moudallal
ca4f242c9b [TEST] Added matmul config for testing (#1758) 2023-06-22 13:31:37 -07:00
Goran Flegar
8d566e4196 [FRONTEND] Fix missing attribute access in DependenciesFinder (#1820)
It seems that patch #1773 introduced a bug, since the `lhs` object
doesn't necessarily have a `__name__` attribute.

I'm hitting this if I modify the matmul tutorial
(gflegar/triton@442b00f4d):

```
File "/home/gflegar/triton/python/triton/runtime/jit.py", line 74, in visit_Attribute
  if lhs is None or lhs.__name__ == "triton":
AttributeError: 'Tensor' object has no attribute '__name__'
```

I think the idea of that patch was to remove the need to import triton
by replacing `lhs is triton` with `lhs.__name__ == "triton"`. This patch
should have the same behavior as the original code, but without failing
if `lhs` doesn't have a `__name__` attribute.
2023-06-22 13:30:25 -07:00
Izzy Putterman
5686c51cdb [FRONTEND] allow pre-hook in autotuner configs to access config kwargs (#1814)
This is a very quick change that allows a config's pre-hook to see the
values in the config itself. This is useful if we'd like to allocate an
intermediate tensor whose shape depends on the tile size.
2023-06-22 05:40:48 -07:00
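A minimal plain-Python sketch of what this enables (the `Config` class, `run_pre_hook`, and the `BLOCK_M`/`M` names are illustrative stand-ins, not the real autotuner API): a pre-hook that sizes a scratch buffer from the config's own tile size.

```python
class Config:
    # Minimal stand-in for an autotuner config: kernel meta-parameters
    # plus an optional pre_hook called before benchmarking/launch.
    def __init__(self, kwargs, pre_hook=None):
        self.kwargs = kwargs
        self.pre_hook = pre_hook

    def run_pre_hook(self, args):
        if self.pre_hook is not None:
            # The change in question: the hook sees the config's own
            # kwargs merged in, not just the kernel arguments.
            self.pre_hook({**args, **self.kwargs})

scratch_shapes = []

def alloc_scratch(meta):
    # The intermediate tensor's shape depends on the tile size.
    scratch_shapes.append((meta["M"] // meta["BLOCK_M"], meta["BLOCK_M"]))

cfg = Config({"BLOCK_M": 64}, pre_hook=alloc_scratch)
cfg.run_pre_hook({"M": 1024})
print(scratch_shapes)  # -> [(16, 64)]
```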
Philippe Tillet
0d6cd0307a [FRONTEND] add tie_break_left option to arg-reductions (#1813) 2023-06-21 19:35:52 -07:00
Philippe Tillet
4c0e3d907e [TOOLS] improved ahead-of-time compiler (#1805)
This is a revival of @gaxler's initial ahead-of-time compiler proposal.
Code was simplified and some constraints were relaxed (i.e., we now
execute the entire file provided vs just the kernel AST) to promote
maintainability. A basic unit test was added, though it does not test
specialization right now.

co-authored by: Gregory Axler, thexler <g.axler@gmail.com>
2023-06-21 01:02:58 -07:00
Christopher Hesse
767dd5b5c1 [DOCS] update debug docs (#1790)
Haven't tested the docs (they don't seem to compile on my machine when I
use the makefile)
2023-06-20 20:45:17 -07:00
Thomas
4be1c94b1f Enable IR dumping for the ttir-to-ttgir phase. (#1804)
ttir_to_ttgir was missing enable_debug, which is needed to dump the IR.
2023-06-20 17:17:18 +00:00
ZhenLei Xu
326061efd4 Remove Duplicate Definition of 'asm' Variable in triton/compiler/compiler.py (#1803) 2023-06-19 22:28:28 -07:00
Keren Zhou
1851c8ca99 [FRONTEND] Fix binary compare op on constexprs (#1801)
Example:

```
if static_a == 0 and static_b == 1:
    ...
```

The return value of `static_a == 0` should be `constexpr(True)` rather
than `True`; otherwise, the bool object (True/False) doesn't have the
`logical_and` method.
2023-06-18 20:27:56 -07:00
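A toy plain-Python sketch of the fix (this `constexpr` class is a simplified stand-in, not Triton's actual implementation): having `__eq__` return another wrapper keeps `logical_and` available on the comparison result, where a bare `bool` would raise `AttributeError`.

```python
class constexpr:
    # Minimal stand-in for a compile-time value wrapper: comparisons
    # return another constexpr so methods like logical_and exist on
    # the result, unlike on a bare Python bool.
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        other = other.value if isinstance(other, constexpr) else other
        return constexpr(self.value == other)  # not a bare bool

    def logical_and(self, other):
        return constexpr(bool(self.value) and bool(other.value))

static_a, static_b = constexpr(0), constexpr(1)
cond = (static_a == 0).logical_and(static_b == 1)
print(cond.value)  # -> True
```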
Christopher Hesse
981e98a213 [FRONTEND] update assert_helper.py (#1789) 2023-06-15 16:24:30 -07:00
Philippe Tillet
9a2580de13 [CI] Added H100 node (#1779) 2023-06-15 14:21:47 -07:00
Philippe Tillet
b24dc19741 [FRONTEND] cleaned up symbol names (#1782) 2023-06-14 18:55:32 -07:00
Izzy Putterman
71e21f5797 [FRONTEND] switch absolute imports to relative imports in Triton (#1773) 2023-06-14 23:59:24 +00:00
Philippe Tillet
754306b5d7 [FRONTEND] TRITON_PTXAS_PATH can now take in options (#1778) 2023-06-14 14:03:03 -07:00
Zahi Moudallal
ac15d00ef4 [TEST] Added f8xf16 tests (#1771) 2023-06-12 16:14:17 -07:00
Wang Weihan
b27a91a113 [FRONTEND] Enable triton to support register thirdparty backend at runtime (#1643)
This PR intends to provide a mechanism to support a third-party backend
at runtime to generate the backend-specific code.

The mechanism provides a common class to abstract the third-party
backend logic and two essential functions to register and retrieve the
third-party backend at runtime.

- `BaseBackend`: A common class to abstract the third-party backend
logic
- `register_backend`: Register a third-party backend with a given device
type
- `get_backend`: Get the third-party backend with a given device type

Generally, a third-party backend must inherit from `BaseBackend` and
implement all the member functions according to the backend
characteristics. As long as the backend implementation is ready, the
third-party backend can invoke `register_backend` to register it under a
given device. During the kernel compilation and execution, the mechanism
will get the registered backend to generate the kernel and launcher code
for a given device.

This PR added a dummy backend to simulate a third-party backend and
demonstrate the usage.

-
[test_device_backend.py](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1):
To define a third-party backend and register the backend
-
[ExtensionBackend](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R123):
Inherit from the `BaseBackend` and implement some specific logic like
[filter out some compile
stages](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R129-R135)
- [Register the `ExtensionBackend` for
`CPU`](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R279)
  
-
[extension_backend.c](https://github.com/openai/triton/pull/1643/files#diff-169c1d08b3a0a7b343cfa3258fbc32b47e0f6c46305a112652fa1bdaaec89d29):
To provide the utility functions to load the kernel binary and get the
backend properties.
2023-06-09 09:09:59 -07:00
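A minimal plain-Python sketch of the registration pattern described above, using the names from the PR (`BaseBackend`, `register_backend`, `get_backend`, `ExtensionBackend`) but with simplified, illustrative signatures rather than the real member functions:

```python
class BaseBackend:
    # Common interface a third-party backend must implement.
    def __init__(self, device_type):
        self.device_type = device_type

    def compile(self, src):
        # Backend-specific codegen; simplified placeholder signature.
        raise NotImplementedError

_backends = {}

def register_backend(device_type, backend_cls):
    # Associate a backend class with a device type at runtime.
    _backends[device_type] = backend_cls

def get_backend(device_type):
    # Fetch and instantiate the backend registered for a device type.
    cls = _backends.get(device_type)
    return cls(device_type) if cls is not None else None

class ExtensionBackend(BaseBackend):
    def compile(self, src):
        return f"compiled[{src}]"

register_backend("cpu", ExtensionBackend)
backend = get_backend("cpu")
print(backend.compile("kernel"))  # -> compiled[kernel]
```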
Sophia Wisdom
2fb56dcf52 [FRONTEND] improve tl.dot error messages (#1757)
This makes it easier to figure out what's going on when you get these
error messages.
2023-06-07 19:39:16 -07:00
Keren Zhou
4fbadf6f6f [BACKEND] Fix tl.cat when the number of threads > the size of a tensor (#1751)
`tl.cat(tensor<64>, tensor<64>) -> tensor<128>`: because it concatenates
elements within a single thread, if the number of threads is 128, each
thread should own at least 2 elements.
With this PR, we also disable remat of the cat op in some cases.
2023-06-07 15:42:38 -07:00
Zahi Moudallal
1db570f45c [FRONTEND] Added arch to the json (#1747) 2023-06-06 17:06:03 -07:00