Commit Graph

528 Commits

Author SHA1 Message Date
Nishant Sikarwar
7687f85ca4 [FRONTEND] decorating static methods with @staticmethod (#1069) 2023-01-17 14:35:06 -08:00
Keren Zhou
3f47e9aa0e [BACKEND] Fix unrealized conversion for fp32 dot (#1051) 2023-01-17 21:55:44 +00:00
Nishant Sikarwar
fbd93d3f10 [FRONTEND] replaced unsafe exit with sys.exit (#1060) 2023-01-17 09:04:03 -08:00
Nishant Sikarwar
4a74d6eae9 [FRONTEND] replaced chains comparison operator with in (#1059) 2023-01-15 20:14:35 +00:00
Yan Chunwei
86003c83dd [Optimizer] Add UpdateMmaForVolta Pass (#1048)
This PR adds UpdateMmaForVolta pass to help update the MMA encoding for
Volta.
Some context is told in https://github.com/openai/triton/pull/1014

# Changes

1. Moving the related MMAv1 patterns from GPUCombine pass to
UpdateMmaForVolta pass,
2. Updating both the versionMinor and warpsPerCTA fields for Volta MMA
encodings since they could only be determined after the GPUCombine Pass,
3. Moving the FixupLoop pattern from the Combine.cpp to new
Utility.h/.cpp files
4. Adding an ID field(takes 5 bits to store an integer) to versionMinor
to help assigning a unique ID(on Volta) for each MMA encodings, the
reason is as below
- Currently, there is a cyclic dependency between {DotOperand, Slice}
with MMA layouts, we use a map to help cluster all the DotOperand,
Slice, and MMA layout instances into the same group for further updating
in bulk
- When there are multiple DotOps in a module with the same MMA(literally
equivalent), it is possible to get the wrong groups
- an ID field is used to help to identify the MMA from different DotOps,
thus getting all the MMA, DotOperand, and Slice layout instances in the
right groups
2023-01-14 11:54:19 +08:00
Da Yan
eb8af19eb4 [FRONTEND] Raise an expection when pytorch tensor is not on a cuda device (#1052)
Always raise an exception when JITFunction's argument is not on cuda.
2023-01-14 01:57:52 +00:00
Philippe Tillet
259f4c5f7d [OPTIMIZER] Added new optimization passes (#1055)
This PR adds a couple of optimization passes that should substantially
improve the performance of Triton on fused attention kernels:
- DecomposeConversionsPass: This decomposes some instructions of the
form `convert_layout` into
- ReorderInstructions: this reorders instructions in a way that is more
amenable to good code generation from `ptxas`.
2023-01-13 13:15:53 -08:00
Keren Zhou
4167b6281b [BACKEND] Support int64 constant (#1050)
Code snippet from a case in torchbench.

```Python
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 2
    rnumel = 77
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.int64) + -9223372036854775808
```
2023-01-11 17:43:28 -08:00
Philippe Tillet
83df07aac6 [OPERATORS] blocksparse softmax now backpropagates nans (#1046) 2023-01-10 18:00:49 -08:00
Philippe Tillet
dc7ecf4535 [FRONTEND] Fix output datatype of reduce (#1045) 2023-01-10 15:04:54 -08:00
Da Yan
0f5c6e619c [BUILD] Add the missing triton/impl to setup.py (#1042) 2023-01-09 19:03:45 +00:00
Connor Baker
c20215dad1 [FRONTEND] Update PTX/SM support for LLVM14 (PR #1038 redux) (#1039)
=
2023-01-09 10:31:55 -08:00
Keren Zhou
733301ff31 [Backend] Rewrite code for linking external library to expose more inlining opportunities (#1037)
- Also make it cleaner. 
- And mark out the code needs to be fixed in `semantic.py`.
2023-01-08 13:44:29 -08:00
Keren Zhou
4023149ee3 [Frontend] Convert constexpr to value for store and load ops (#1030)
Fixing problem 2 in https://github.com/openai/triton/issues/1017

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-01-05 14:40:16 -05:00
Gregory Axler
2193bee94e [Example] Fix the compile function in copy_strided.py (#1029) 2023-01-05 10:37:41 -08:00
Sophia Wisdom
411bacb2a8 [FRONTEND] Add logical operations on constexprs (#1033) 2023-01-04 18:06:32 -08:00
Sharad Vikram
bc73bbb12c [FRONTEND] Fix argmin/max output type (#1012)
Currently Triton returns tensors with the input types rather than i32
when doing reduce argmax/argmin.
2023-01-03 23:12:16 -08:00
Keren Zhou
8460ea3df1 [Frontend] Fix import for libdevice (#1028)
This is a hotfix for issue 1 in
https://github.com/openai/triton/issues/1017
2023-01-03 15:48:05 -08:00
fdrocha
194ba103b1 [BUILD] Fixed error when compiling in systems with multiple versions of python installed (#1019) 2022-12-29 15:10:34 -08:00
Keren Zhou
fd2da4aff6 [BACKEND] Support splat constant on the DotOperandLayout (#1008) 2022-12-22 00:48:46 -08:00
Sharad Vikram
925d3d7f98 [FRONTEND] Export broadcast and broadcast_to in triton.language (#1007) 2022-12-22 01:57:33 +00:00
Keren Zhou
b5aafb0dab [FRONTEND] Fix 3d indexing (#1006) 2022-12-21 12:52:32 -08:00
Philippe Tillet
20100a7254 Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-21 01:30:50 -08:00
Yang Hau
8650b4d1cb [DRIVER] Fix typos (#939) 2022-12-02 11:13:46 -08:00
Crutcher Dunnavant
44f577984d Fix format double substitution bug: {i} => {{i}} (#886)
The previous `{i}` was silently expanding to the `i` from the
enumeration loop on `regular_args` (when it wasn't empty).
2022-11-20 11:44:42 -08:00
Crutcher Dunnavant
0e4691e6dd [FRONTEND] Fix ExternLibrary(format=) bug; type annotate build_extern.py (#883)
Ran mypy over `build_extern.py`, cleaned up type annotations.

Found a fixed a bug where `ExternLibrary(format=)` was being ignored.
2022-11-17 18:45:30 +01:00
Natalia Gimelshein
0d7e753227 [TESTING] use torch.int for autotuning cache (#840)
For stupid reasons, ops on int8 are 3 times slower than on int, and for
another set of stupid reasons we are not using cudaMemset for `zero_`,
so using `int8` buffer in `do_bench` makes it slow.

Co-authored-by: Philippe Tillet <phil@openai.com>
2022-11-04 18:05:16 -07:00
Shintaro Iwasaki
77bc5187b5 Better NVIDIA Pascal GPU Support (#827)
This PR clarifies which features are supported on P100 via its tests,
though Pascal is not officially and fully supported by Triton.

## What this PR does

- Skip unsupported tests on P100.
  - Atomic RMW
- `tl.dot()` (perhaps not all patterns, but basically most `tl.dot()`
tests do not work on P100).
- Add an explicit error if shared memory size >= 64K on P100.
- Otherwise it causes `Invalid CUDA argument` error at
`cuLaunchKernel()`, but this error is not very straightforward to
understand. Instead of this generic CUDA argument error, this PR makes
Triton show an error during codegen when `sm < 70`. This check happens
in C/C++ so won't add an overhead in Triton's Python runtime.
- 3 tests (see below) are currently failing, but these are not marked as
skipped because any codegen update in the future can change the kernel
size of the other tests.
- This change won't affect Triton-MLIR. Hopefully Triton-MLIR's generic
`tl.dot()` implementation would support P100.

Importantly, Triton passed all the other tests on P100. Though this
support is not official, it is great for, for example, PyTorch's
TorchDynamo/Inductor, which can use Triton (without `tl.dot()`) for its
backend (https://github.com/pytorch/torchdynamo/issues/1591).

### Results on P100 (Google Cloud)

```sh
$ pytest test/unit
...
================================================================================== short test summary info ==================================================================================
FAILED test/unit/language/test_core.py::test_reduce2d[argmin-float32-shape99-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_reduce2d[argmax-float32-shape113-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_permute[float32-shape5-perm5] - RuntimeError: Device does not support shared memory of 67584bytes
================================================================== 3 failed, 3824 passed, 952 skipped in 470.90s (0:07:50) ==================================================================
```

<details><summary> <b>Environment Details (collapsed)</b></summary>
<p>

### VM details (Google Cloud)
https://cloud.google.com/
```
# You need a paid account (free trial does not cover GPUs)
Google Cloud -> New Project -> Compute-Engine -> VM Instance
Machine:
GPU: NVIDIA Tesla P100 x 1
CPU: 2 vCPUs, 7.5GB memory
Boot disk:
  OS: Ubuntu 18.04 LTS
  Disk: 40GB (cannot build Triton on the default 10GB disk)
- When I tried, about $1.2 per hour.
- US instances were full when I tried.  I used Asia or Australia.
- Needed a paid account (GPU is not covered by free trial)
- Needed quota request for any GPU instance (by default, no GPU instance is allowed).  Needed to wait an hour for approval
```

### Reproducer
```sh
## 1. Install CUDA and a driver
# Update the apt key (https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/)
sudo apt-key del 7fa2af80
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb
# Download CUDA as instructed
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin
sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600
sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/ /"
sudo apt-get update
sudo apt-get -y install cuda
# Are you using P100?
nvidia-smi | grep "Tesla P100"

## 2. Setup the build environment
sudo apt update
sudo apt install -y build-essential wget git libz-dev
wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh
bash Anaconda3-2022.05-Linux-x86_64.sh -b -p $(pwd)/anaconda3
eval "$($(pwd)/anaconda3/bin/conda shell.bash hook)"
conda create -y --name triton_base
conda activate triton_base
conda install -y cmake setuptools

## 3. Build Triton
git clone https://github.com/openai/triton.git
cd triton/python
pip3 install -e '.[tests]'

## 4. Test
pytest test/unit
```

### Environment
```sh
$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    25W / 250W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
```

</p></details>
2022-11-03 00:11:52 -07:00
Chenggang Zhao
f16138d447 [Frontend] Interface fixes for libdevice (#830)
- Unifying several interfaces with different types to a single one, e.g.
`fsub_ru` and `dsub_ru` -> `sub_ru`;
- Minor bug fix: `fast_pow` is incorrectly classified into the `pow`
interface, of which arguments are the same as `powf`;
- Explicit interfaces for casting functions, e.g. decoupling
`ll2float_ru` to `ll2float_ru` and `ull2float_ru`;
- Removing interfaces that are not in NVIDIA's official documents, e.g.
`fmaf_ieee_rn`, which is confusing together with `fmaf_rn`.

Note that this PR for the master branch is different from #829, which is
for the MLIR branch.
2022-11-01 10:51:58 -07:00
Keren Zhou
3ca667dfa8 [Frontend] Return a scalar if all input args are scalar (#816) 2022-10-28 23:27:06 -07:00
Yanbo Liang
5ca1ed0101 Add bf16/fp16/fp64 support for ty_to_cpp (#800)
In ```torch._inductor```, we [convert 0d CPU tensor to scalar during
triton codegen](https://github.com/pytorch/pytorch/pull/87329), so need
add missing triton support for bf16/fp16/fp64.
2022-10-24 19:41:25 -07:00
Keren Zhou
db3aa1d1fb [FRONTEND] Fix libdevice (#776)
Fix two problems in libdevice and external dispatch:

1. Use static triton types (e.g., tl.int32) instead of creating new
types. Otherwise, `tl.int32` and `tl.dtype('int32')` are not the same
thing.

2. The name of an extern inst should be empty but not the symbol name of
the inst. TTIR generator will assign names automatically. Otherwise, we
have the same variable name when there are multiple same extern insts.

Before the PR:

```bash
  __nv_exp = extern_elementwise f64<1024> %11;
  __nv_exp = extern_elementwise f64<1024> %11;
```

After the PR:

```bash
  %12 = extern_elementwise f64<1024> %11;
  %13 = extern_elementwise f64<1024> %11;
```
2022-10-13 17:18:16 -07:00
Keren Zhou
bc98aead33 [Backend] Fix for mov.u8 (#766)
Init a potential fix for mov.u8 which is not supported by ptx for now.
Use mov.u16 instead and cast it to u8.
2022-10-12 14:32:27 -07:00
Yu Guo
71b46acc42 [IR] Added special-purpose dequantize instruction (#759)
It is currently necessary for optimal performance in quantized workloads to add a special-purpose instruction in the IR. Backward compatibility with this instruction is *NOT* guaranteed.
2022-10-12 14:14:45 -07:00
Philippe Tillet
af76c989eb [RUNTIME] Make entry point cache key depend on triton version hash (#765) 2022-10-11 13:24:30 -07:00
Bin Bao
09cc2d454b [FRONTEND] Fix a bool tensor storing problem (#746) 2022-10-10 12:11:50 -07:00
Felipe Petroski Such
5d4b26d380 [RUNTIME] support multiple devices in the same process (#757) 2022-10-09 20:30:04 -07:00
Chris
9a11a567ce [DOCS] Fixed typos in 01-vector-add.py (#751) 2022-10-09 18:12:46 -07:00
Keren Zhou
11345e9b74 [RUNTIME] Add callback functions for external tools (#738) 2022-10-05 14:46:55 -07:00
Philippe Tillet
bdfdb9a1d2 [RUNTIME] Fixed JIT bug that leg some constexpr values to be overriden by specialization parameters (#742) 2022-10-05 11:00:32 -07:00
shenggan
77c752dc78 [RUNTIME] remove fixed cu_include_dir (#739)
Use environment variable `CUDA_HOME` with default value`/usr/local/cuda` for `cu_include_dir` #731
2022-10-04 19:49:57 -07:00
Natalia Gimelshein
d3c925db8a [FRONTEND] properly broadcast scalar where condition (#736) 2022-10-04 12:44:03 -07:00
fdrocha
2b0f877fad [RUNTIME] Support environments with multiple cudalibs (#733) 2022-10-03 18:36:24 +00:00
Keren Zhou
4a2d3b7d79 [RUNTIME] Dump llvm, ttir, and sass to help debugging (#732) 2022-10-03 00:39:52 +00:00
Natalia Gimelshein
f55960e773 [FRONTEND] fix broadcasting for where (#729)
Fixes #532, all 3 inputs to where have to be broadcast together.
2022-10-01 13:18:47 -07:00
Phil Tillet
b244db06da [TUTORIALS] Attention tutorial fixup 2022-09-30 19:31:43 -07:00
Shintaro Iwasaki
ae59f51c2d [CODEGEN] Fix an inliner to call a function with a phi-node (#727) 2022-09-29 21:36:40 -07:00
albanD
f45e31ba7c [FRONTEND] Make sure to hold the gil when creating python objects (#726)
Without this patch, a debug version of python complains that:
```
Fatal Python error: Python memory allocator called without holding the GIL
Python runtime state: initialized
```
2022-09-29 18:06:22 -07:00
Philippe Tillet
dad97528b2 [TESTING] allclose fixup (#724) 2022-09-28 22:49:05 +00:00
Jason Ansel
998fd5f9af [FRONTEND] Make triton.compile work without a cuda context (#708)
This allows compiling in a subprocess. I'm not seeing a ton of speedup from this, but figure it is a good change anyway.
2022-09-24 13:41:47 -07:00