Commit Graph

871 Commits

Author SHA1 Message Date
peterbell10
0d76c4ca95 [FRONTEND] Rename tl.reduction -> tl.reduce and improve testing (#1521)
`tl.reduction` is currently tested indirectly through the existing
reduction operators, but it's good to have a direct test for the
function itself.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-04-14 14:35:31 -07:00
Bert Maher
bfd1f65ac7 [FRONTEND] cache path to ptxas (#1526)
When running Python 3.8, I've found that process creation gets slower
over time (e.g., after creating a CUDA context, it can take 50-300 ms per
`subprocess.run`), and we make one of these calls to `ptxas --version` for
every kernel, so a model with thousands of kernels can end up spending
substantial time just calling ptxas redundantly.

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-04-14 17:01:42 +00:00
Chenggang Zhao
c624778e73 [TUTORIALS] Add tutorial for block pointers (#1519)
This PR contains:
- Several fixes for the matrix multiplication (the M and N dimensions may
have out-of-bounds accesses)
- A type check for block-based stores
- The tutorial for block pointers
- Some formatting fixes
2023-04-14 00:40:41 -07:00
Phil Tillet
92d07d1b8e [DOCS] Fixed up workflow 2023-04-13 16:05:29 -07:00
Phil Tillet
0e3290963e [DOCS] re-enabled flash attention tutorial 2023-04-13 15:49:32 -07:00
Philippe Tillet
2aad2336e9 [DOCS] Documentation job now uses A100 GPUs (#1522) 2023-04-13 15:35:16 -07:00
Keren Zhou
fdf1c1f2a1 [DOCS] Fix documentation workflow (#1520)
Co-authored-by: Phil Tillet <phil@openai.com>
2023-04-13 13:49:36 -07:00
peterbell10
6550c528b7 [FRONTEND] don't call tl.view in arg{min,max} (#1518)
A small oversight in #1305: since `view` can rearrange elements, it
should be avoided here. Instead, I use indexing with `None` to create new
dimensions.
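The difference can be illustrated with NumPy, whose `None` indexing behaves analogously (this is an analogy, not Triton code):

```python
import numpy as np

x = np.arange(6)      # shape (6,)
col = x[:, None]      # shape (6, 1): a new size-1 axis, element order untouched
row = x[None, :]      # shape (1, 6)

assert col.shape == (6, 1) and row.shape == (1, 6)
# Unlike a general view/reshape, None-indexing can never permute elements:
assert (col.ravel() == x).all() and (row.ravel() == x).all()
```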

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-04-13 07:32:23 +00:00
Philippe Tillet
c0d86d3b04 [RUNTIME] refactor driver (#1515)
Improved separation between different backends
2023-04-12 23:50:44 -07:00
peterbell10
e152183570 [FRONTEND][BACKEND] ReduceOp to support arbitrary reduce operations (#1305)
Fixes #1285

This changes `tt.reduce` to replace `redOp` by a region containing
arbitrary code. For example, `tl.sum` is now lowered as:
```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```
Support for index reductions at the MLIR level is also dropped in favor
of simultaneous reductions over multiple tensors, which generalizes the
code without loss of performance. For example, `argmin` is now lowered
as:
```mlir
  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
  %8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
  %9:2 = "tt.reduce"(%6, %8) ({
  ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
    %14 = arith.cmpf olt, %arg4, %arg6 : f32
    %15 = arith.cmpf ogt, %arg4, %arg6 : f32
    %16 = arith.cmpi slt, %arg5, %arg7 : i32
    %17 = arith.select %16, %arg5, %arg7 : i32
    %18 = arith.select %15, %arg7, %17 : i32
    %19 = arith.select %14, %arg5, %18 : i32
    %20 = arith.cmpf olt, %arg4, %arg6 : f32
    %21 = arith.select %20, %arg4, %arg6 : f32
    tt.reduce.return %21, %19 : f32, i32
  }) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
2023-04-13 01:37:39 +00:00
Philippe Tillet
5b9119117b [CI] No longer install triton in editable mode to run tests (#1476) 2023-04-12 17:55:44 -07:00
Phil Tillet
9530d93504 [TESTING] change do_bench defaults 2023-04-11 22:03:52 -07:00
Phil Tillet
d7d62ddae9 Revert "[BUILD] Fixed typo in setup.py"
This reverts commit 2931bb8195.
2023-04-11 20:12:22 -07:00
Phil Tillet
2931bb8195 [BUILD] Fixed typo in setup.py 2023-04-11 20:09:09 -07:00
Philippe Tillet
02e3c18f04 [TESTING] clean up testing.do_bench (#1513) 2023-04-11 20:05:58 -07:00
zahimoud
081f64046c [DOCUMENTATION] Added documentation to the code that lowers reduce to llvm (#1514) 2023-04-11 19:15:36 -07:00
zahimoud
fd34b20fba [BACKEND] Fixed bug in reduce; add tests 2023-04-11 18:09:18 -07:00
Phil Tillet
3e22e18295 [TESTING] do_bench now returns min time by default.
This is likely to be more stable in general for benchmarks whose L2
hit rate is comparable to what is encountered in practice.
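A minimal sketch of the idea (not Triton's `do_bench`, which times GPU events): taking the minimum over repetitions discards one-off scheduling noise, while a mean folds it in:

```python
import time

def bench_min(fn, reps=5):
    # Time `fn` several times and report the best run; outliers caused by
    # interference can only raise the mean, never lower the min.
    times = []
    for _ in range(reps):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return min(times)

t = bench_min(lambda: sum(range(10_000)))
assert t >= 0.0
```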
2023-04-11 17:18:01 -07:00
Keren Zhou
272f23457a [DOCS] Restore the documentation workflow (#1503)
Not sure if it works at this moment, but at least we can restore the
workflow first.
2023-04-11 13:36:15 -07:00
Philippe Tillet
0fedf6b79a [TESTS] disable launch latency test (#1510) 2023-04-11 10:31:16 -07:00
Philippe Tillet
e0d6f5f4f5 [BUILD] updated LLVM binaries (#1504)
Co-authored-by: Christian Sigg <csigg@google.com>
2023-04-11 00:14:00 -07:00
Keren Zhou
6d0ed41307 [BACKEND] Replace Func Dialect with custom triton ops (func, call, return) (#1502)
MLIR currently only supports a custom inlining interface per dialect, so
we cannot change the inlining decision of `func.func`.


https://discourse.llvm.org/t/avoid-inlining-some-functions-using-the-func-dialect/69830/3

We could revert this once a better inliner interface has been designed.

Inlining attributes will be implemented in the next PR since this PR is
already huge.
2023-04-10 21:08:40 -07:00
Philippe Tillet
640f3c3921 [OPTIMIZER] Tweaked layout removal conversion heuristics (#1501)
Loads are now considered cheap to rematerialize when there are more
threads than elements in the tensor.
2023-04-10 15:19:08 -07:00
peterbell10
2c06f875e4 [TESTS] Add triton version of mlir-reduce (#1498)
[`mlir-reduce`](https://mlir.llvm.org/docs/Tools/mlir-reduce/) is a tool
to reduce the complexity of bug reproducers written in mlir. Similar to
`triton-opt`, `triton` needs to have its own version with the dialects
registered properly for it to work.
2023-04-10 13:31:11 -07:00
petterreinholdtsen
8c55276c90 [DOCS] include link to github project in README (#1494)
This makes it easier for those working from tarball releases or clones to
know where the upstream project is located.
2023-04-10 09:29:59 -07:00
Keren Zhou
032509384a [ANALYSIS] Fine-tune comments for shared memory allocation (#1492)
Also adds a new test to check multiple-color cases, which have never been
tested before.
2023-04-10 09:00:36 -07:00
Zilin Zhu
19e424eb98 [ops/blocksparse] Fix grid shape for large lm (#1491)
When the language model grows really large, axis 1 of the original
grid shape (`c.shape[1]`, corresponding to the number of nonzero elements
in the layout) can exceed 65535, surpassing the CUDA limit and
resulting in `[CUDA]: invalid argument`.

This PR moves axis 1 of the original grid to axis 0, as the limit
for axis 0 is 2^31 - 1.
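A hedged sketch of the constraint (the helper name is invented; the limits are CUDA's documented grid-dimension maxima):

```python
# gridDim.x may be as large as 2**31 - 1, but gridDim.y and gridDim.z
# are capped at 65535, so the potentially huge nonzero-block count
# must ride on axis 0.
MAX_GRID_X = 2**31 - 1
MAX_GRID_YZ = 65535

def blocksparse_grid(num_nonzero_blocks, batch):
    # Before the fix: (batch, num_nonzero_blocks) failed once the block
    # count passed 65535. Swapping the axes lifts the ceiling to 2**31 - 1.
    if num_nonzero_blocks > MAX_GRID_X or batch > MAX_GRID_YZ:
        raise ValueError("grid dimension exceeds CUDA limits")
    return (num_nonzero_blocks, batch)

grid = blocksparse_grid(100_000, 16)
```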

Thank you for your time on reviewing this PR :)
2023-04-10 09:00:12 -07:00
Philippe Tillet
adc760dac1 [OPTIMIZER] enable loop pipelining using pointer increments from vector look-up tables (#1490) 2023-04-10 08:59:42 -07:00
who who who
fd0516fb90 [DOCS] Fixed typo (#1489) 2023-04-09 16:06:34 -07:00
Philippe Tillet
b86425a28e [TEST] made lut_bmm pipeline test more concise and specific (#1488) 2023-04-08 19:17:35 -07:00
long.chen
f7ad8ae022 [Refine] remove const ref of mlir::Attribute (#1486)
https://mlir.llvm.org/docs/DefiningDialects/AttributesAndTypes/

https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#f16-for-in-parameters-pass-cheaply-copied-types-by-value-and-others-by-reference-to-const
```
The C++ Attribute and Type classes in MLIR (like Ops, and many other things) are value-typed. 
This means that instances of Attribute or Type are passed around by-value, 
as opposed to by-pointer or by-reference. 
The Attribute and Type classes act as wrappers around internal storage objects that are uniqued within an instance of an MLIRContext.
```
2023-04-08 10:38:59 -07:00
mcskatkat
82ec1a89ea [FRONTEND] code_generator.py TODOs fixed & removed (#1484)
Handled TODOs that were waiting for the circular import issue to be
resolved
2023-04-07 22:05:46 -07:00
Ian O'Connell
bc0b007e4b [FRONTEND] Allow cache manager to be overridden, and tweak apis to easier work with remote caches (#1478)
The changes here come with a few separate bits:

- Allow replacing the cache manager with an ENV variable to make it
pluggable
- Make the `make_path` API private, since it leaks some internal bits
of the cache and allows file access. Use a get operation instead.
- The `compile` operation produces several small files as part of a single
compile pipeline, which can perform poorly with remote caches. Also,
some operations like `_triton.get_shared_memory_size` only work when
everything is cached or none of it is (they segfault if only some key
files are present), so grouping these files as a single entity avoids
that.
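The pluggable-cache-manager pattern might look like the following sketch; the environment-variable name, class names, and method set are assumptions for illustration, not Triton's actual API:

```python
import importlib
import os

class FileCacheManager:
    """Default manager: a stand-in for the on-disk cache."""
    def get(self, key):
        return None  # nothing cached in this sketch

def load_cache_manager():
    # An env var such as "pkg.module.ClassName" selects a replacement
    # implementation (e.g. one backed by a remote cache).
    spec = os.environ.get("TRITON_CACHE_MANAGER")  # assumed variable name
    if not spec:
        return FileCacheManager()
    module_name, _, class_name = spec.rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls()

manager = load_cache_manager()
```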
2023-04-07 13:38:28 -07:00
Keren Zhou
6743e42eb5 [FRONTEND] Data type specification for math functions (#1485) 2023-04-07 10:26:19 -07:00
Keren Zhou
7f3f58f332 [FRONTEND] Fix broadcast semantics (#1480)
https://github.com/openai/triton/pull/1183

---------

Co-authored-by: Yen-Chen Lin <yenchenlin1994@gmail.com>
2023-04-06 10:40:40 -07:00
Philippe Tillet
8cbf9b40a4 [TESTING] Minor fixes (#1479) 2023-04-06 00:48:33 -07:00
Phil Tillet
4c1d001ae4 [TESTING] Now using numpy instead of pytorch in triton.assert_close
More memory-efficient than pytorch
2023-04-04 23:57:12 -07:00
Eta
577cafff0a [BUILD] Add missing subpackages to build (#1475)
The `triton/compiler`, `triton/runtime/driver`, and `triton/third_party`
subpackages were missing from the distribution built with the old
`setup.py` after #1464, causing an immediate error upon importing Triton
with a non-editable installation. This change adds the missing Python
subpackages and moves `triton/third_party` inclusion to `MANIFEST.in`,
where it will automatically be included in wheels due to the existing
`include_package_data` setup flag.
2023-04-04 22:41:08 -07:00
Phil Tillet
0e11f1e167 [TESTING] Added triton.allclose wrapper around torch.testing.allclose.
This adds a convenience layer to test linear algebra kernels and their perf.
2023-04-04 21:53:36 -07:00
Philippe Tillet
47e73aadda [BACKEND] Revert inline PTX for conversions supported by LLVM (#1474)
No longer needed now that we initialize all registers. The motivation for
reverting this workaround, now that we can, is that it introduced
performance regressions.
2023-04-04 17:52:26 -07:00
Keren Zhou
00a9143bb4 [FRONTEND] Expose Autotuner to users (#1473)
The Autotuner is a handy utility. By allowing external access to it,
users can override some functions (e.g., `run`) to load/store the best
configurations, initialize tensors based on configuration values, and
change the benchmarking standard (e.g., based on bytes instead of time).
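The kind of override the PR enables might look like this sketch; the class and method names here are illustrative assumptions, not Triton's actual `Autotuner` interface:

```python
class Autotuner:
    def __init__(self, configs):
        self.configs = configs

    def run(self, key, bench):
        # Default behaviour: benchmark every config and keep the fastest.
        return min(self.configs, key=bench)

class CachingAutotuner(Autotuner):
    """Overrides run() to load/store the best configuration per key."""
    def __init__(self, configs):
        super().__init__(configs)
        self._best = {}

    def run(self, key, bench):
        if key not in self._best:           # store the winner on first use...
            self._best[key] = super().run(key, bench)
        return self._best[key]              # ...and load it afterwards

tuner = CachingAutotuner(configs=[1, 2, 4, 8])
best = tuner.run("matmul-128", bench=lambda c: abs(c - 4))
```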
2023-04-04 17:12:00 -07:00
Christian Sigg
01a93185a1 [BACKEND][OPTIMIZER] Switch from llvm::Optional to std::optional. (#1416) 2023-04-04 09:06:28 -07:00
Philippe Tillet
053af4e9f8 [FRONTEND] Refactor file hierarchy (#1464)
The purpose of this PR is to remove some circular dependencies and
better separate concerns in the frontend. It's still not perfect --
`triton.compile` still includes a few runtime architecture-specific
components, but it's at least much better than before.

This PR still assumes that AMD only supports empty kernels right now.
Other PRs will follow to make the frontend support multiple devices in
a more modular way.
2023-04-02 12:07:08 -07:00
Keren Zhou
0855cacdd8 [BACKEND] Fix small matmul dot (#1463)
https://github.com/openai/triton/issues/1449

In theory, we might be able to support even an 8x8 dot if we also wrap
around `cOff`.
2023-04-02 02:05:05 +00:00
Keren Zhou
801bb9d3b5 [ANALYSIS] Fix divisibility calculation for addptr (#1453) 2023-03-31 17:57:31 -07:00
Keren Zhou
859952a0aa [FRONTEND] Include the debug field when computing the kernel hash (#1458)
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-04-01 00:52:51 +00:00
Da Yan
bf158bf01f [FRONTEND] kwargs as autotune key (#1457) 2023-03-31 17:09:14 -07:00
Kern Handa
2c0417da96 [DOCS] fixed typo triton.testing.allclose -> torch.allclose in MatMul tutorial (#1460) 2023-03-31 17:06:46 -07:00
Keren Zhou
cc4aa1ebbc [FRONTEND] Fix if-exp parsing for size-1 tensors (#1455) 2023-03-31 15:05:52 -07:00
Phil Tillet
966e5d955b [TEST] Increase `LATENCY_THRESHOLD_US` 2023-03-31 11:38:18 -07:00