Commit Graph

1390 Commits

Author SHA1 Message Date
Thomas Raoux
376acb610b [BUILD] Fix macos x86 build (#2505)
There was a mismatch in the llvm link name
2023-10-17 09:49:09 -07:00
Philippe Tillet
05dc28be0e [CI] refactor workflows (#2504)
no longer run third-party tests on every PR
2023-10-17 00:27:17 -07:00
Mehdi Amini
721897fcc4 upgrade llvm to b1115f8c (NFC) (#2403)
Co-authored-by: Thomas Raoux <thomas.raoux@openai.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Phil Tillet <phil@openai.com>
2023-10-16 16:38:49 -07:00
Maksim Levental
87a223d76f bump triton_shared (#2501)
Bump to tip.

cc @manbearian @nhat-nguyen
2023-10-16 13:49:20 -07:00
Zahi Moudallal
726bdb984f [FRONTEND][BACKEND] Fix constexpr assignment ; revert #2430 (#2496)
Without this change, a constexpr assignment (ie. `A = B & C`, where `B`
and `C` are both constexpr) is getting assigned to a triton tensor,
which becomes an issue when `A` is used as the condition of an If
statement.
Note: I had to add `not isinstance(node.value, ast.Constant)` to the
condition because if we are assigning `x = 0` then the assigned value is
also a constexpr, but in this case we do want to assign a triton tensor
to `x` so that we can do `x.to(tl.int64)` for example, which cannot be
done on a constexpr.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-10-16 12:35:19 -07:00
Stewart Hall
29828fe491 [FRONTEND] add option to disable fp mul/add fusion (#2495)
By default, ptxas will enable fusion of mul/add to fma instructions. The
backend was also being configured unconditionally to enable this on
conversion from LLVM IR to PTX. This commit adds an option which can be
used to disable the FP fusion behavior in both locations.
2023-10-14 12:23:30 -07:00
Philippe Tillet
3b6ec763d5 Revert "[BACKEND] Disable BreakPhiStruct pass (#2458)" (#2498)
This reverts commit b1bc9b20a0.
2023-10-14 10:40:49 -07:00
Philippe Tillet
8db4fac3b0 Revert "[OPTIMIZER] Tweak warpsPerCTA based on the shape of MMA output (#2485)" (#2497)
Reverts openai/triton#2485
2023-10-13 23:32:59 -07:00
Weixing Zhang
76858bd917 [OPTIMIZER] Tweak warpsPerCTA based on the shape of MMA output (#2485)
In current implementation, warpsPerCTA is always set to [numWarps, 1]
for 2 tt.dot fusion scenario. But, it is not optimal for cases such that
tt.dot doesn't have enough parallelism on row dimension but on column
dimension.
2023-10-12 22:25:42 -07:00
Thomas Raoux
a777e1d8db [OPTIMIZER] Propagate mma layout when the transitive use has dot_operand encoding (#2482) 2023-10-12 23:57:40 +00:00
lanzhiguanhuang
03af50b040 [BUILD] Makes the link_libraries(stdc++fs) global for all targets in the CMake project. (#2481)
### Problem
The previous change still didn't link libstdc++fs into libtriton.so,
which caused the runtime error: undefined symbol
_ZNKSt10filesystem7__cxx114path11parent_pathEv
`link_libraries(stdc++fs)` should be placed before `add_library`.

### What this PR does
This PR Makes the link_libraries(stdc++fs) global for all targets in the
CMake project. By doing so, we ensure that the stdc++fs library is
consistently linked to all targets, addressing potential build issues on
old linux OS like debian10 which uses GCC8.3.0 as the default C/C++
compiler.
2023-10-12 23:46:43 +00:00
Keren Zhou
f81d9d876f [FRONTEND] Fix math for constant values (#2472)
https://github.com/openai/triton/issues/2470
2023-10-12 12:11:42 -07:00
Thomas Raoux
cda298fae7 [Pipeliner] Allocate less shared memory when possible (#2466)
The pipeliner was overallocating shared memory for the inputs
for current schedule. This reduces the shared memory usage to only
what is needed.
Note that improving membar analysis could allow taking advantage of
allocating extra buffers to remove barriers.
2023-10-12 12:10:06 -07:00
Thomas Raoux
6f46c93b9e [BACKEND] Add back dot.wait when generating async_dot (#2478)
Based on discussion this is needed to make sure there is no race
condition when reading shared memory.
2023-10-10 21:45:28 -07:00
Zahi Moudallal
4749072fbd [BACKEND] Allow reduce with sliced 3D layout as input (#2480) 2023-10-10 15:19:11 -07:00
Beal Wang
5812d970a8 [HOPPER][OPTIMIZER] remove divOp and remOp from gemm math loop (#2402)
This is just for Warp Specialization kernels on Hopper. Replace DivOp
and RemOp with SelectOp and AndOp/XorOp.
2023-10-09 14:42:06 +08:00
Philippe Tillet
424e67e727 [FRONTEND] improved while loop error messages (#2463) 2023-10-06 18:37:52 -07:00
Tori Baker
ab4549310b [OPTIMIZER] erase ops after use in iterator (#2455)
This seems to have worked fine in opt mode (although it may be producing
undefined behavior), but in debug mode on a newer version of llvm, it
segfaults without this PR as the iterators get invalidated.

This is also consistent with other places it is done in this file.
2023-10-06 18:02:56 -07:00
Thomas Raoux
b1bc9b20a0 [BACKEND] Disable BreakPhiStruct pass (#2458)
This is causing functional failures in pytorch workload. Disabling it
until I figure out the problem.
2023-10-06 17:59:53 -07:00
Thomas Raoux
a7061e19b2 [BACKEND] Fix multiple bugs in WGMMA (#2457)
Fix dependencies in wgmma_wait op to prevent the scheduler from moving
it past the uses of wgmma accumulator. We need to explicitly represent
the dependency between the wait and the accumulator uses otherwise LLVM
is free to re-order those.
This allows us to remove a workaround to prevent the re-ordering. We can
also remove the wait op added in the loop during pipelining.

Also fix the descritpor calcuation for wgmma, we should calculate the
same descriptor for the whole warpgroup.
Added a workaround for a bug that was exposed by different timing due to
those changes. We shouldn't insert operations between the loop and
async_wait or we may have race conditions.
2023-10-06 17:59:28 -07:00
Lixun Zhang
ded79e87ee [TUTORIALS] Enable causal=False in FA fwd kernel (#2459) 2023-10-06 17:54:45 -07:00
Sam Shleifer
fb3c2f3b2b [TUTORIALS] attention: support torch 2.1 (#2461) 2023-10-06 17:50:11 -07:00
Zahi Moudallal
be19cf3103 [BACKEND] Enable reduce with 3D tensors and added tests (#2460) 2023-10-06 15:08:22 -07:00
Keren Zhou
a42d517021 [FRONTEND] Better warning on nested jit functions (#2453) 2023-10-06 14:22:51 -07:00
Hongtao Yu
eed4559df2 [TOOLS] Enable per-pass IR printing in triton-translate (#2449)
Enabling per-pass IR printing such as `--mlir-print-ir-after-all`
2023-10-05 13:23:46 -07:00
Thomas Raoux
38f184b7cf [BACKEND] Use native fp8 convert ops when possible (#2448)
On Hopper we can use native fp8 conversion ops that are significantly
more efficient.

Improves epilogue in matmul. 8192x8192x512xf8 goes from 567 TFlops to
630 TFlops (the kernel is highly latency bound but this is a good proxy
for epilogue performance)
2023-10-05 18:28:58 +00:00
Philippe Tillet
7bc6b99132 [DOCS] Fix FP8E4M3B15 docs (#2451) 2023-10-05 10:59:45 -07:00
Philippe Tillet
e9d9ddd86d [OPTIMIZER] More V100 conversion removal tweaks (#2446) 2023-10-04 16:31:20 -07:00
Philippe Tillet
a7ff4eddae [OPTIMIZER] hasConvertToMMATransisitiveUse=False for MMAv1 (#2445) 2023-10-04 14:56:12 -07:00
Thomas Raoux
24560b8152 Better tuning for H100 flash attention. (#2444)
Improves performance of fwd pass from 420 to 440 TF
2023-10-04 14:43:41 -07:00
Thomas Raoux
5a0170a27c [BACKEND] Minor removing of unnecessary code and cleanup (#2443) 2023-10-04 12:14:08 -07:00
Justin Lebar
71a8544ce7 Improve docs for atomic and load/store operations. (#2437)
- Move atomic_cas and atomic_xchg to "atomic ops" section of
   documentation.
 - Don't talk about the `cmp` operand for operations which don't have
   it.
 - Document the `sem` operand.
 - :code:`foo` and ``foo`` don't work inside a :type: annotation,
   apparently.  (They are rendered literally, instead of being treated
   as a formatting command.)  Get rid of them.
- Format the bulleted lists in the load/store operations as intended.
2023-10-04 04:17:42 +00:00
Christian Sigg
5458014282 [BACKEND] Lower to PTX with trap-unreachable (#2429)
We've seen cases where the entire kernel is poisoned due to
division-by-zero, resulting in a single `unreachable` instruction at the
LLIR level. Emit this instruction as `trap` (instead of dropping it) so
that the kernel doesn't run successfully without writing any outputs.
2023-10-03 21:05:10 -07:00
Thomas Raoux
c656a139d3 [BACKEND] Fix for FP8 QK inputs in flash attention forward pass (#2435) 2023-10-03 21:02:13 -07:00
Zahi Moudallal
0d84a7d70c [BACKEND] Adding support for slice layout in InsertSliceAsyncOp (#2438) 2023-10-03 20:59:53 -07:00
Bin Fan
6b860e7a74 [Backend] fix a bug in lowering ExtractSliceOp from TritonGPU to LLVM (#2436) 2023-10-03 21:52:07 -04:00
Thomas Raoux
020f43d5a3 [NFC] Minor clean ups found during LLVM upgrade (#2433)
Pull some of the changes required for LLVM upgrade to make the upgrade
simpler.
2023-10-03 08:22:46 -07:00
apgoucher
cd38642ec5 Fix denormal handling in fp8e5 --> bf16 conversion PTX (#2430) 2023-10-02 17:26:30 +01:00
Keren Zhou
ac9fa68d18 [BACKEND] Fine-tune SharedMemoryObject definition and fix related problems (#2428) 2023-10-01 21:43:05 -07:00
Philippe Tillet
a0025cfc44 [FRONTEND] add missing implicit constexpr conversion in dot (#2427) 2023-10-01 16:07:50 -07:00
Tori Baker
97e35b677b [BACKEND] fix division by 0 pathway (#2412)
It was possible for multiDimWarpId[1] to be 0 which then gets translated
into a `urem 0, 0` and results in an unreachable when going through
llvm, an empty kernel, and nans. This PR uses ceiling to clamp the
result to be >=1.

chsigg is working on a fix to lower the unreachable in llvm to a trap
(https://github.com/llvm/llvm-project/pull/67478).
2023-09-30 10:53:43 -07:00
Philippe Tillet
98039658d4 [CI] disable pypy wheel (continued) (#2424)
there's a typo in the previous commit
2023-09-30 00:38:06 -07:00
Philippe Tillet
c4f3afc020 [CI] disable pypy wheel (#2423)
emitting warnings from C++ code requires "#include pybind11/exec.h"
which is not compatible with pypy. I think using the python interpreter
form C++ is a bad idea in general... but we probably don't care much
about pypy wheels anyway
2023-09-29 23:48:08 -07:00
Philippe Tillet
533efd0cac [FRONTEND][BACKEND] changed float8e4b15 clipping semantics from +-1.875 to +-1.75 (#2422)
clipping float8e4b15 to +-1.875 is a bad idea because these are
represented as 0x7f and 0xff, which are +- nan on H100 for float8e4nv.
We lose two values but this will make compatibility with float8e4nv way
less painful. (it will just be a matter of adjusting the bias)
2023-09-29 23:33:28 -07:00
Ying Zhang
ee013d8978 Fix PTX issues in bf16 / fp8_e4m3 conversion (#2421)
Fix bugs in https://github.com/openai/triton/pull/2415. cc @htyu 

Previously corresponding tests failed on H100 with latest torch version.
It passed CI because CI doesn't use latest torch, so the tests were
skipped.
2023-09-29 19:36:00 -07:00
Hongtao Yu
e0edb70f78 [BACKEND] support of Fp8E4M3Nv to Bf16 conversion (#2415) 2023-09-29 17:29:41 -07:00
Keren Zhou
e284112818 Revert "[TUTORIALS] Remove unneeded quantiles parameter (#2408)" (#2419)
This reverts commit 99af23f6f4.

`quantiles` shouldn't be the problem. The documentation workflow failed
because of other issues.
2023-09-29 14:24:50 -07:00
Keren Zhou
f2f5f1d457 [TUTORIALS] Add missing docstrings (#2420)
Depend on https://github.com/openai/triton/pull/2419 to fix the
documentation workflow
2023-09-29 14:24:30 -07:00
Thomas Raoux
90bef57acf [BACKEND] turn on MMA V3 by default on Hopper (#2414) 2023-09-28 22:45:28 -07:00
Thomas Raoux
d4fae90169 [BACKEND][NFC] Simplify conversion to TritonGPU (#2416)
Remove ad hoc patterns. This will help LLVM transition.
2023-09-28 13:59:15 -07:00