Commit Graph

855 Commits

Author SHA1 Message Date
Stonepia
a13ddf08e2 [FRONTEND] Fix bug when the _SYSPATH is set. (#1156) 2023-02-06 18:02:42 +00:00
fdrocha
972b761390 [FRONTEND] For __rshift__ operator, use arithmetic right shift if dtype is a signed int. (#1153) 2023-02-06 10:26:17 +00:00
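A plain-Python sketch (not Triton code) of the distinction this commit addresses: arithmetic right shift copies the sign bit, while logical right shift fills with zeros.

```python
# Sketch only: illustrates arithmetic vs. logical right shift semantics.

def arithmetic_rshift(x: int, n: int) -> int:
    # Python's >> on ints is already arithmetic for negative values.
    return x >> n

def logical_rshift(x: int, n: int, bits: int = 32) -> int:
    # Emulate a logical shift by reinterpreting x as an unsigned `bits`-wide int.
    return (x & ((1 << bits) - 1)) >> n

print(arithmetic_rshift(-8, 1))  # -4: sign preserved, as for signed dtypes
print(logical_rshift(-8, 1))     # 2147483644: zero-filled, as for unsigned dtypes
```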
Keren Zhou
ce47f94e59 [FRONTEND] Check that *A* and *B* in the dot op have the same data type (#1155) 2023-02-06 01:58:07 -08:00
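A minimal sketch of the kind of check this commit adds; the function and parameter names here are hypothetical, not Triton's actual frontend API.

```python
# Hypothetical names: mirrors the intent of the dot-operand check, not its code.
def check_dot_operands(a_dtype: str, b_dtype: str) -> None:
    if a_dtype != b_dtype:
        raise ValueError(
            f"dot operands must have the same dtype, got {a_dtype} and {b_dtype}"
        )

check_dot_operands("fp16", "fp16")    # OK
# check_dot_operands("fp16", "fp32")  # would raise ValueError
```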
Keren Zhou
546f2377ae [BACKEND] Get the right operand and result types in forward rematerialization passes (#1152) 2023-02-04 16:34:35 -08:00
Emil Masoumi
dff43abbb9 [Build] Prevent excessive hyphens from causing build errors. (#1151)
Prevents excessive hyphens from causing build errors on non-Windows
machines.
2023-02-04 00:22:57 -08:00
Yu Guo
474ed978b9 [BUILD] Fixed typo in CMake type tablegen (#1124) 2023-02-03 18:46:11 -08:00
Mehdi Amini
ce6d74e0b6 [BACKEND] Fix crash in test/TritonGPU/coalesce.mlir (#1148)
The call to `coalesceOp` is deleting the op it is processing and
replacing it with a new one. We can't `dyn_cast` the `curr` pointer
because it is dangling at this point.

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-02-04 02:40:27 +00:00
Philippe Tillet
8a4ca2c61a [CI][TEST][FRONTEND] Various small fixes (#1150)
- cancels CI runs in progress when a PR is updated
- atomics tests now use small int values that can be represented exactly
- replaced some old-style formatting with f-strings
2023-02-03 18:12:34 -08:00
Keren Zhou
bde52f9db2 [BACKEND] Fix alignment calculation (#1149)
`getDivisibility` indicates whether the address in bytes is divisible by a
certain number, so we should convert `#aligned bytes` to `#aligned
elements`.
2023-02-03 17:20:23 -08:00
Philippe Tillet
43798ab27e [BUILD] Restored wheels workflow (#1146)
- Dependent CUDA files (ptxas, cuda.h, libdevice.10.bc) are now packaged in
`triton/third_party/cuda`. `ptxas` is downloaded from the conda repo at
install time.
- Can now be built with an old glibc (such as the one used by manylinux2014)
2023-02-03 16:22:10 -08:00
Nishant Sikarwar
f9e26deb05 [FRONTEND] using literal syntax to create the data structure (#1119)
Literal syntax can give minor performance bumps compared to function
calls when creating a dict, list, or tuple: the name `dict` must be looked
up in the global scope in case it has been rebound, and the same goes for
`list()` and `tuple()`.

Signed-off-by: nishantsikarwar <nsikarwar@ch.iitr.ac.in>
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-02-03 13:59:13 -08:00
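The rationale above can be seen directly in CPython's bytecode; a small illustrative comparison (not repository code):

```python
import dis

# Both functions build the same values, but the literal form compiles to a
# single opcode each, while dict()/list()/tuple() each need a global name
# lookup plus a call (and the name may have been rebound).
def with_calls():
    return dict(), list(), tuple()

def with_literals():
    return {}, [], ()

assert with_calls() == with_literals()
dis.dis(with_literals)  # BUILD_MAP / BUILD_LIST / ... instead of LOAD_GLOBAL + call
```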
Chenggang Zhao
f86843f815 Change libdevice.bc Path in Core Tests (#1141)
Only test `libdevice.bc` shipped with triton
2023-02-02 20:01:12 -08:00
George Karpenkov
a9d1935e79 [FRONTEND] Fix error message when atomics are not supported for a given dtype (#1134)
Otherwise, the construction of the exception crashes during string
concatenation.
2023-02-02 02:49:34 +00:00
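A sketch of the failure mode (the `Dtype` class here is a stand-in, not Triton's): concatenating a non-string into an error message raises `TypeError` before the intended exception is ever constructed, whereas an f-string calls `str()` implicitly.

```python
# Stand-in class for illustration; not Triton's dtype type.
class Dtype:
    def __init__(self, name):
        self.name = name

dt = Dtype("float64")

try:
    raise ValueError("atomics not supported for " + dt)  # TypeError, not ValueError
except TypeError:
    pass  # the intended ValueError never surfaced

# With an f-string, the intended error is raised:
try:
    raise ValueError(f"atomics not supported for {dt.name}")
except ValueError as e:
    assert "float64" in str(e)
```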
Keren Zhou
82befe32ad [BACKEND] Improve torch inductor performance (#1108)
- Rewrite the AxisInfo analysis to handle each op case by case.
- Add bit shift, min max, div/rem, and select ops to AxisInfo.
- Rematerialize across load/store ops in the following two cases:
  - A size-1 tensor is considered inexpensive, since all threads will
load the same element.
  - The `targetEncoding` may expose more vectorization opportunities (more
elements per thread on the first dim).

**_res2next_** benchmark GPU kernel time comparison on A100:
- Average kernel sum: Triton 16838630 ns vs. Triton-MLIR 17105166 ns
(**1.016x slowdown**).
- Total kernel sum: Triton 6511735460 ns vs. Triton-MLIR 6512370620 ns.
2023-02-01 18:21:15 -08:00
Philippe Tillet
ccd17d6bf9 [TESTS] Added test for flash-attention (#1138) 2023-02-01 11:26:29 -08:00
Keren Zhou
71c6d56b9b [BACKEND] Hotfix to remove SameOperandsAndResultEncoding for the trans operation (#1136)
The order of the input encoding is permuted to form the output encoding.

For example:

```
#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
```
2023-02-01 10:03:38 -08:00
George Karpenkov
9c3f55cbee [BUILD] Allow multi-threading during compilation (#1133)
Currently, multi-threading is only allowed during PTX->cubin
compilation, but not for LLVM->PTX or TTIR->LLVM conversion.
2023-02-01 09:40:25 -08:00
Keren Zhou
1ec39fdf99 [BACKEND] Refactored the MoveConvertOutOfIf conversion to handle scf.if correctly (#1114)
Also removed duplicate code for `simulateBackwardRematerialization`.
2023-02-01 08:49:19 -08:00
Keren Zhou
5dd8ce3745 [BACKEND] Fix topological sort and add new test cases (#1132)
The previous PR https://github.com/openai/triton/pull/1113 forgot to consider
that a node may have multiple parents; visiting an instruction before all
of its parents violates the semantics of topological sort.

The fixed implementation exhaustively adds all operations to a
candidate subgraph and moves an operation to the "ready" queue once all
of its operands have been visited.
2023-01-31 23:41:20 -08:00
xndcn
fc846e5e1e [BUILD] Fix build when LLVM_LIBRARY_DIR is not set (#1128)
1. Try to search for `llvm-config` up to `llvm-config-17`.
2. `LLVM_LIBRARY_DIR` is expected by `MLIR_DIR`, so set it to
`LLVM_LIBRARY_DIRS`, which is output by FindLLVM.

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-01-31 20:55:10 -08:00
Philippe Tillet
8fea1fb478 [FRONTEND] Adding static range (#1130)
Included: Revert "[BACKEND] Replace `mlir::topologicalSort` with a
custom implementation (#1113)"
2023-01-31 18:04:19 -08:00
rsanthanam-amd
be3da96919 [FRONTEND] Fix restoration of llir IR from cache to give a string. (#1127)
Since the llir IR is a string when it is first generated, it should also
be a string when we fetch it from the cache.
2023-01-31 18:35:10 +00:00
Philippe Tillet
c4b9d699d2 [FRONTEND][BACKEND] Fixed many bugs (#1122)
- **temporarily commenting assertion in `MemBar.cpp`. We need to fix
this! but for now the following patches will unblock a number of
users.**
- Fixed frontend codegen issue for If / For / While. Emit an error when
replaced values' types mismatch.
- Added "top level" codepath for if statements, which allows users to
write patterns to exit early from kernels (e.g., `if cond1: if cond2:
return else: ...`). Added associated codegen in TritonToTritonGPUPass
- Added basic control flow tests
- Pipeline pass is no longer activated when memory accesses can't be
vectorized
- Added missing magic methods to `constexpr`
- Fixed issue in random.py: bitcast some values to uint when they need
to be.
- Added support for `Not`
- Fixed nondeterministic compilation issue
2023-01-30 23:22:36 -08:00
goostavz
3e8d83b7cc Minor fix to support sm_90 (#1125)
This fix enables support on sm_90 (otherwise it will crash).

Logs like
> 'sm_90' is not a recognized processor for this target (ignoring
processor)

can be ignored and should go away once the LLVM NVPTX
backend is updated.
2023-01-31 14:08:02 +08:00
Yan Chunwei
94b419c327 [FRONTEND] some tiny fix (#1120) 2023-01-30 19:39:38 -08:00
Nishant Sikarwar
653c8dc124 [FRONTEND] Replaced range with enumerate calls (#1110)
Using range(len(...)) is not pythonic.
Python does not have index-based loops; instead, it uses collection
iterators. The built-in enumerate adds a counter to an iterable, giving
access to both the counter and the value at the same time. It is
therefore recommended to replace range(len(...)) with enumerate(...).

For example:

5bcf60a5c0/python/triton/language/extern.py (L68)

f62d556fff/python/triton/language/extern.py (L68)

Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-01-30 15:22:11 -08:00
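A minimal illustration of the replacement described above (sample data, not repository code):

```python
names = ["ld", "st", "dot"]

# Index-based loop (discouraged):
out_a = []
for i in range(len(names)):
    out_a.append((i, names[i]))

# Pythonic equivalent with enumerate:
out_b = list(enumerate(names))

assert out_a == out_b  # same pairs, no manual indexing
```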
Nishant Sikarwar
e5dbe35cc1 [FRONTEND] removed unnecessary comprehension (#1085) 2023-01-30 19:42:14 +00:00
Keren Zhou
bc8a26d56f [BACKEND] Replace mlir::topologicalSort with a custom implementation (#1113)
`multiRootTopologicalSort` is faster than `mlir::topologicalSort`
because it prunes nodes that have been visited before.
2023-01-29 18:57:21 -08:00
Keren Zhou
5bcf60a5c0 [BACKEND] Refactored the code to no longer include static functions in header files. (#1109) 2023-01-28 14:58:28 -08:00
Da Yan
82f5e988be [OPTIMIZER] Improve bf16 and i8 matmul performance (#1107)
Use i32 as the storage type for <2xi16> and <4xi8>, as NVPTX inserts
extra integer instructions for vector int types.

Performance before this PR (8192x8192x8192-TN input):
bf16: 222 TFLOPS
i8: 339 TOPS

After this PR:
bf16: 272 TFLOPS
i8: 548 TOPS
2023-01-27 22:13:14 +00:00
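The storage trick can be illustrated in plain Python (a sketch of the idea, not the compiler's code): two 16-bit lanes packed into a single 32-bit word, so NVPTX sees scalar i32 rather than a vector int type.

```python
# Sketch: pack/unpack two 16-bit lanes in one 32-bit word.
def pack2xi16(lo: int, hi: int) -> int:
    return (lo & 0xFFFF) | ((hi & 0xFFFF) << 16)

def unpack2xi16(word: int) -> tuple:
    return word & 0xFFFF, (word >> 16) & 0xFFFF

assert unpack2xi16(pack2xi16(0x1234, 0xABCD)) == (0x1234, 0xABCD)
```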
Nikita Shulga
e9446c7ce3 [BUILD] Add ability to bundle CUDA dependencies (#1100) 2023-01-27 09:55:49 -08:00
Da Yan
394f2e6991 [OPTIMIZER] improved prefetch width (#1106)
Before this PR: 16
After this PR: 16 (fp16/bf16), 32 (int8/fp8), 8 (tf32)
The new prefetch widths work better with i8/f8/tf32 tensor cores.
2023-01-27 17:41:49 +00:00
Nikita Shulga
d3e753b5c0 [RUNTIME] Raise runtime error if C compiler is not found (#982)
Makes error reported in https://github.com/pytorch/pytorch/issues/90377
a bit easier to understand
2023-01-26 00:08:25 +00:00
Edward Z. Yang
cf0ae2ed76 [BUILD] Still build even if lit is not installed on user's system (#1095)
Otherwise it fails with

```
File "setup.py", line 147, in build_extension
    "-DLLVM_EXTERNAL_LIT=" + lit_dir,`
TypeError: can only concatenate str (not "NoneType") to str
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
2023-01-25 12:55:59 -08:00
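A minimal sketch of the guard, with hypothetical variable names and assuming `lit` is located via `shutil.which`: only pass the flag when `lit` is actually installed, so `None` never reaches the string concatenation that crashed `setup.py`.

```python
import shutil

# Hypothetical variable names; only the guard pattern is the point.
lit_dir = shutil.which("lit")
cmake_args = []
if lit_dir is not None:
    cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
print(cmake_args)  # [] when lit is absent; no TypeError either way
```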
Keren Zhou
34ac01e597 [BACKEND] Change DenseMap to std::map in Allocation (#1092)
Since we need to take the address/reference of each entry, DenseMap is
not appropriate here: rehashing a DenseMap reallocates its buckets and
changes the addresses of entries. Unlike DenseMap, STL associative
containers do not invalidate references to existing entries when other
entries are inserted or deleted.


https://stackoverflow.com/questions/61771352/will-it-change-the-address-of-a-existed-keys-value-when-inserting-new-keys
2023-01-25 05:18:20 -08:00
Keren Zhou
b5d32896b1 [BACKEND] Verify the same operand and result element type for convert_layout (#1081)
And a hotfix for incorrect convert_layout construction in the GPU
combine pass.
2023-01-22 16:59:24 +00:00
Keren Zhou
c59fb4acca [FRONTEND] Fix libdevice elementwise compute for constexpr (#1082) 2023-01-22 07:11:44 +00:00
Keren Zhou
6531cd9f13 [CI] Change integration test branch from master to main (#1083) 2023-01-21 22:40:00 -08:00
Yan Chunwei
88498d104a [BACKEND] DotOp enable ld.v4 in MMAv1 (#1020)
The existing distributed-to-distributed layout conversion logic processes
each MMA block, which requires every MMA block to share
exactly the same fixed pattern (such as the one described in the [NV PTX
doc](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-16816-float)).

For MMAv1, however, things are different: the MMA block has varying
patterns for different shapes and data layouts, as shown below

<img width="200" alt="image"
src="https://user-images.githubusercontent.com/328693/213354941-731d7856-ad24-4f48-be0e-3cf41532cfa4.png">

This requires all the cell coordinates in DotOp output to be computed.
2023-01-19 09:42:33 -08:00
Philippe Tillet
408d1d7e87 [OPTIMIZER] Improved flash attention forward pass performance (#1075)
- Fixed typo in instruction reordering pass
- Minor additional optimizations for shared memory allocator
- Optimized flash attention tutorial forward pass kernel
2023-01-19 06:46:01 +00:00
Void Main
b2c522a451 [BACKEND] Remove duplicate def for create_get_program_id (#1013)
The same function is redefined at lines
[645-650](https://github.com/openai/triton/blob/master/python/src/triton.cc#L645-L650)
and
[1174-1179](https://github.com/openai/triton/blob/master/python/src/triton.cc#L1174-L1179).
Comparing these two definitions, it looks like we should remove the code
at lines 645-650.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-01-18 16:25:10 -05:00
Philippe Tillet
660f2e8cce [OPTIMIZER] pipeline and prefetch pass now use a more ptxas-friendly schedule (#1065) 2023-01-17 15:21:19 -08:00
Nishant Sikarwar
7687f85ca4 [FRONTEND] decorating static methods with @staticmethod (#1069) 2023-01-17 14:35:06 -08:00
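A small illustration of the decorator's effect (example class, not repository code): a method that uses no instance state is marked `@staticmethod`, making the intent explicit and allowing calls on the class itself.

```python
class Builder:
    @staticmethod
    def make_name(prefix: str, idx: int) -> str:
        # No self/cls needed; @staticmethod makes that explicit and lets the
        # method be called without an instance.
        return f"{prefix}_{idx}"

assert Builder.make_name("kernel", 0) == "kernel_0"
assert Builder().make_name("kernel", 1) == "kernel_1"  # still works on instances
```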
Keren Zhou
3f47e9aa0e [BACKEND] Fix unrealized conversion for fp32 dot (#1051) 2023-01-17 21:55:44 +00:00
Goran Flegar
e2923afc71 [BUILD] Add dependency to TritonTableGen in TritonAnalysis (#1067)
The `TritonAnalysis` target depends on `TritonTableGen` through
including `Triton/IR/Dialect.h`, which itself includes
`Triton/IR/Dialect.h.inc` generated by `TritonTableGen`.

Without it, the build might fail (this seems to happen inconsistently,
due to multithreaded builds)
2023-01-17 13:21:16 -08:00
Goran Flegar
4dcd455d2a [BUILD] Fix build config when not building python module (#1066)
In the case when `TRITON_BUILD_PYTHON_MODULE` is set to `OFF`, the
`triton` target ends up having no sources, which triggers a CMake
configuration error.

Fix this by only generating the target when building with python module
support.

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-01-17 13:19:48 -08:00
Nishant Sikarwar
fbd93d3f10 [FRONTEND] replaced unsafe exit with sys.exit (#1060) 2023-01-17 09:04:03 -08:00
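The distinction in a nutshell (a sketch with a hypothetical helper): the interactive `exit()` is injected by the `site` module and may be absent under `python -S` or in embedded interpreters, while `sys.exit()` always exists and raises `SystemExit` cleanly.

```python
import sys

# Hypothetical helper illustrating the safe pattern.
def fail(msg: str) -> None:
    print(msg, file=sys.stderr)
    sys.exit(1)  # always available, unlike the site-injected exit()
```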
Goran Flegar
afd02626ea [BUILD] Fix build issues of triton-translate tool (#1068) 2023-01-17 09:03:29 -08:00
Nishant Sikarwar
4a74d6eae9 [FRONTEND] replaced chains comparison operator with in (#1059) 2023-01-15 20:14:35 +00:00
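An illustrative before/after of that replacement (sample values, not repository code): a chain of `== ... or ==` comparisons repeats the variable and is easy to typo, while membership with `in` states the intent directly.

```python
c = "-"

legacy = (c == "+" or c == "-")   # repeated comparisons
pythonic = c in ("+", "-")        # single membership test

assert legacy == pythonic
```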
Yan Chunwei
86003c83dd [Optimizer] Add UpdateMmaForVolta Pass (#1048)
This PR adds the UpdateMmaForVolta pass to help update the MMA encoding for
Volta.
Some context is given in https://github.com/openai/triton/pull/1014

# Changes

1. Moving the related MMAv1 patterns from the GPUCombine pass to the
UpdateMmaForVolta pass.
2. Updating both the versionMinor and warpsPerCTA fields for Volta MMA
encodings, since they can only be determined after the GPUCombine pass.
3. Moving the FixupLoop pattern from Combine.cpp to new
Utility.h/.cpp files.
4. Adding an ID field (5 bits storing an integer) to versionMinor
to help assign a unique ID (on Volta) to each MMA encoding. The
reason is as follows:
- Currently, there is a cyclic dependency between {DotOperand, Slice}
and MMA layouts; we use a map to cluster all the DotOperand,
Slice, and MMA layout instances into the same group for further updating
in bulk.
- When there are multiple DotOps in a module with the same (literally
equivalent) MMA, it is possible to get the wrong groups.
- An ID field helps identify the MMA from different DotOps,
thus getting all the MMA, DotOperand, and Slice layout instances into the
right groups.
2023-01-14 11:54:19 +08:00