Commit Graph

110 Commits

Author SHA1 Message Date
Jason Furmanek
4c4e42e524 Merge remote-tracking branch 'openai/main' into IFU-230517
Conflicts:
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/test/unit/language/assert_helper.py
	python/triton/third_party/cuda/bin/ptxas
	test/Conversion/tritongpu_to_llvm.mlir

2023-05-17 15:03:42 +00:00
Keren Zhou
3baab48eaf [FRONTEND] Differentiate between bool and int in the frontend (#1678)
`bool` is a subclass of `int`, so `isinstance(bool_var, int) == True`,
and a `bool` constant will be converted to an `int` constant.

In Triton specifically, if a bool variable is treated as an integer, it
prevents us from using the `logical_and` operator, which requires both
operands to have the same bit length:

> Cannot bitcast data-type of size 32 to data-type of size 1

By differentiating int and bool, we can make the syntax closer to native
Python: we can now use `if bool_var and condition` to check truthiness,
and `if bool_var is True` to check identity.
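A minimal Python sketch of the pitfall and the fix (the `to_triton_dtype` helper is hypothetical, for illustration only):

```python
# bool is a subclass of int in Python, so a naive isinstance(x, int)
# check silently treats True/False as integers -- the pitfall above.
assert isinstance(True, int)

def to_triton_dtype(value):
    """Map a Python constant to a notional frontend type name (illustrative)."""
    if isinstance(value, bool):  # must be checked BEFORE int
        return "int1"
    if isinstance(value, int):
        return "int32"
    raise TypeError(f"unsupported constant: {value!r}")

print(to_triton_dtype(True))  # int1
print(to_triton_dtype(3))     # int32
```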
2023-05-16 18:24:16 +00:00
Natalia Gimelshein
0daee68d71 [FRONTEND] Don't call set_device in tl.dot (#1646)
This breaks multiprocess compilation
2023-05-10 20:39:27 -04:00
Mario Lezcano Casado
6b1af5fe37 [FRONTEND] Add support for scalar conditions in device_assert (#1641)
This sometimes happens in TorchInductor. See
https://github.com/pytorch/pytorch/pull/100880.
More generally, it's useful to be able to write `tl.device_assert(False,
msg)`.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-05-09 23:05:00 -07:00
Michaël Benesty
d196302cb0 [FRONTEND] make torch optional (#1604)
make torch optional to fix circular dependency issue
2023-05-02 21:56:25 -07:00
Michael Melesse
2784b804d9 Merge remote-tracking branch 'upstream/main' into ifu_4_26_2023 2023-04-26 12:04:21 -05:00
Himanshu Pathak
6d226431b1 [FRONTEND] do not run AccelerateMatmul on pre-Volta GPUs (#1505)
Related to #1271. I am currently working on adding support for
pre-Volta GPUs in Triton.

---------

Co-authored-by: Himanshu Pathak <himanshu@mtatva.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-04-24 15:53:02 -07:00
Sharad Vikram
cf26e05a8f [FRONTEND] remove debug print (#1538) 2023-04-17 15:17:19 -07:00
Chenggang Zhao
c624778e73 [TUTORIALS] Add tutorial for block pointers (#1519)
This PR contains:
- Several fixes for the matrix multiplication (M and N dimensions may
have out-of-bound access)
- A type check for block-based store
- The tutorial for block pointers
- Fix some formats
2023-04-14 00:40:41 -07:00
peterbell10
e152183570 [FRONTEND][BACKEND] ReduceOp to support arbitrary reduce operations (#1305)
Fixes #1285

This changes `tt.reduce` to replace `redOp` by a region containing
arbitrary code. For example, `tl.sum` is now lowered as:
```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```
Support for index reductions at the MLIR level is also dropped in favor
of simultaneous reductions over multiple tensors, which generalizes the
code without loss of performance. So for example `argmin` gets lowered
as:
```mlir
  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
  %8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
  %9:2 = "tt.reduce"(%6, %8) ({
  ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
    %14 = arith.cmpf olt, %arg4, %arg6 : f32
    %15 = arith.cmpf ogt, %arg4, %arg6 : f32
    %16 = arith.cmpi slt, %arg5, %arg7 : i32
    %17 = arith.select %16, %arg5, %arg7 : i32
    %18 = arith.select %15, %arg7, %17 : i32
    %19 = arith.select %14, %arg5, %18 : i32
    %20 = arith.cmpf olt, %arg4, %arg6 : f32
    %21 = arith.select %20, %arg4, %arg6 : f32
    tt.reduce.return %21, %19 : f32, i32
  }) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
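The multi-tensor reduction can be mimicked in plain Python by carrying a (value, index) pair through a pairwise combiner, which is roughly what the generated region computes (an illustrative sketch, not the actual lowering):

```python
from functools import reduce

def combine(a, b):
    """Argmin combiner: keep the smaller value; on ties, the earlier index."""
    (va, ia), (vb, ib) = a, b
    return (va, ia) if va <= vb else (vb, ib)

xs = [3.0, 1.0, 2.0]
val, idx = reduce(combine, zip(xs, range(len(xs))))
print(val, idx)  # 1.0 1
```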
2023-04-13 01:37:39 +00:00
Keren Zhou
6743e42eb5 [FRONTEND] Data type specification for math functions (#1485) 2023-04-07 10:26:19 -07:00
Rahul Batra
a27b388df5 Merge remote-tracking branch 'upstream/main' into IFU_04-06-2023 2023-04-06 16:18:31 -05:00
Keren Zhou
7f3f58f332 [FRONTEND] Fix broadcast semantics (#1480)
https://github.com/openai/triton/pull/1183

---------

Co-authored-by: Yen-Chen Lin <yenchenlin1994@gmail.com>
2023-04-06 10:40:40 -07:00
Keren Zhou
ee593fca0b [BACKEND] Fix int8 dot (#1435) 2023-03-28 20:18:17 -07:00
Chenggang Zhao
72b071253e [FRONTEND] Support block pointer semantics (#1392)
This PR introduces a new semantics: **block pointer**, which makes it
easier and faster for users to load a block from a parent tensor.

Below is a detailed API change by an example:
```
# Make a block pointer, which points to a block in the parent shape
# `base`: the parent tensor
# `shape`: the shape of the parent tensor
# `strides`: the strides of the parent tensor
# `offsets`: the offsets of the block in the parent tensor
# `order`: the order of the data arrangement in memory
# Below is an example loading a 2D column-major matrix 
block_ptr = tl.make_block_ptr(base=ptr, shape=(M, N), strides=(stride_m, stride_n), offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))

# Advance the offsets; note that the striding information is already saved in `block_ptr`
# `base`: the block pointer to be advanced
# `offsets`: the offsets for each dimension
block_ptr = tl.advance(base=block_ptr, offsets=(BLOCK_M, -BLOCK_N))
block_ptr = tl.advance(base=block_ptr, offsets=(-BLOCK_M, BLOCK_N))

# Load from a block pointer, the output type is the dereferenced type of `block_ptr`, e.g. ptr<tensor<32x32xf32>> -> tensor<32x32xf32>
# `ptr`: the block pointer to be loaded
# `boundary_check`: a tuple of dimensions to check the boundary
# `padding`: padding strategy for elements out of bound
val = tl.load(ptr=block_ptr, boundary_check=(0, 1), padding="zero")

# Store by a block pointer, in which the pointer and the value tensor should have the same shape
# `ptr`: the block pointer to be stored
# `boundary_check`: a tuple of dimensions to check the boundary (no-write if out of bound)
tl.store(ptr=block_ptr, value=val, boundary_check=(0, 1))
```
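A NumPy sketch of what a boundary-checked, zero-padded block load does (a hypothetical `load_block` helper for illustration; not Triton's implementation):

```python
import numpy as np

def load_block(parent, offsets, block_shape):
    """Load a 2D block from `parent`, zero-padding out-of-bound elements."""
    out = np.zeros(block_shape, dtype=parent.dtype)
    r0, c0 = offsets
    # Clamp the copy region to what actually lies inside the parent tensor.
    rows = min(max(parent.shape[0] - r0, 0), block_shape[0])
    cols = min(max(parent.shape[1] - c0, 0), block_shape[1])
    out[:rows, :cols] = parent[r0:r0 + rows, c0:c0 + cols]
    return out

a = np.arange(12, dtype=np.float32).reshape(3, 4)
blk = load_block(a, offsets=(2, 2), block_shape=(2, 2))
print(blk)  # one in-bounds row [10, 11]; the out-of-bounds row is zero-padded
```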

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-03-27 16:46:49 -07:00
Philippe Tillet
46672772b4 [FORMAT] autopep8 now uses max-line-length=88 (#1410) 2023-03-25 15:46:50 -07:00
Xuehai Pan
5b36cb48ad [CI][TEST] update pre-commit hooks and use pre-commit for style tests in CI (#1409)
Ref issue:

- #1408

Changes:

- Add `.editorconfig`
- Add `pre-commit-hooks`:

    ```yaml
    - repo: https://github.com/pre-commit/pre-commit-hooks
      rev: v4.4.0
      hooks:
        - id: check-symlinks
        - id: destroyed-symlinks
        - id: trailing-whitespace
        - id: end-of-file-fixer
        - id: check-yaml
        - id: check-toml
        - id: check-ast
        - id: check-added-large-files
        - id: check-merge-conflict
        - id: check-executables-have-shebangs
        - id: check-shebang-scripts-are-executable
        - id: detect-private-key
        - id: debug-statements
    ```
- Add `flake8` to `pre-commit` config and add `.flake8` file
- Use `pre-commit` for style tests in CI
- Run `pre-commit` and fix existing violations:
    - fix trailing spaces
    - fix end-of-files
    - fix mod file mode with `chmod -x`
    - run `autopep8` on existing code
    - fix `flake8` violations
2023-03-25 14:52:16 -07:00
peterbell10
6063fccd0b [FRONTEND][BACKEND] Lower tl.abs to math::Abs{I,F}Op (#1401)
This generates identical PTX for floating point, but for integer types
the resulting PTX is much better. For example `tl.abs` for int16
currently generates

```ptx
  cvt.s32.s16 %r1, %rs2;
  neg.s16     %rs4, %rs2;
  setp.lt.s32 %p4, %r1, 0;
  selp.b16    %rs3, %rs4, %rs2, %p4;
```

After, it becomes a single `abs.s16` instruction.

This also improves LLVM's ability to optimize floats. e.g. `abs(t) *
abs(t)` is optimized to `t * t` now which didn't happen before.
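The float identity behind that optimization, checked in plain Python:

```python
# For real floats, |t| * |t| == t * t, which is why the abs calls
# can be dropped from abs(t) * abs(t).
for t in (-3.5, 0.0, 2.25):
    assert abs(t) * abs(t) == t * t
print("identity holds")
```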

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-03-24 21:58:24 -07:00
Fei Hu
6366c5a254 [FRONTEND][BACKEND] Add support for FP16 output for tl.dot (#1258)
---------

Co-authored-by: Fei Hu <fhu@microsoft.com>
2023-03-19 19:52:14 -07:00
Rohit Santhanam
a84b4883e6 Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03192023 2023-03-19 13:46:50 +00:00
Philippe Tillet
39139258c8 [FRONTEND][BACKEND] tl.mathlib -> tl.math; internally reverted to mathlib -> libdevice (#1368) 2023-03-19 02:14:57 -07:00
rsanthanam-amd
c575911a01 [FRONTEND] Change libdevice to mathlib and fix abs (#1361)
Co-authored-by: Phil Tillet <phil@openai.com>
2023-03-19 01:34:16 -07:00
Rohit Santhanam
6ff54b495c Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03122023 2023-03-13 18:09:12 +00:00
Philippe Tillet
3fe3adbcde [FRONTEND][BACKEND] Add support for float8e5m2 type (#1314) 2023-03-10 19:14:47 -08:00
Keren Zhou
78b311f6e2 [FRONTEND] Fix cast when both src_ty and dst_ty are of block_type (#1301)
Commonly used in atomic_rmw ops
2023-03-08 09:25:00 -08:00
JiCheng
849a40baad [FRONTEND] Add check for the axis of reduction op (#1268) 2023-03-06 22:11:43 -08:00
Keren Zhou
d376020f90 [FRONTEND][BACKEND] Implement tl.device_assert and rename tl.printf to tl.device_print (#1143)
Note that `tl.device_print` and `print` accept different arguments than
the normal Python `print`: the first argument must be a string, followed
by the variables to print.

Device side:

- `tl.device_print`
- `tl.device_assert`
- `print`
- `assert`

Compilation time:

- `tl.static_assert`
- `tl.static_print`

Usage example:

1.
```Python
tl.device_assert(x == 0, "x != 0")
```

Output:

```text
...
python/test/unit/language/assert_helper.py:18: kernel: block: [0,0,0], thread: [33,0,0] Assertion `x != 0` failed.
...
```

2.
```Python
tl.device_print("hello ", x)
```

Output:

```text
...
hello 1
...
```

The environment variable `TRITON_DEBUG` sets the default debugging flag; if it's true, `tl.device_assert` or `assert` will be skipped.
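A host-side Python analogue of the calling convention (illustrative only, not the device implementation):

```python
def device_print(prefix, *values):
    """Host-side stand-in: the first argument must be a string, then values."""
    if not isinstance(prefix, str):
        raise TypeError("first argument to device_print must be a string")
    print(prefix, *values)

device_print("hello ", 1)  # hello  1
```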
2023-03-04 08:08:29 -08:00
Keren Zhou
65e5a3bc24 [FRONTEND] Improve tl.full to accept both static and dynamic values (#1269) 2023-03-02 12:19:54 -08:00
Rohit Santhanam
cd9ae1cd36 Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-02232023 2023-02-23 21:41:54 +00:00
rsanthanam-amd
e7f84448bf Merge pull request #127 from dfukalov/dfukalov/work-3
[ROCM] Enable float16 and int8 types for FMA based `dot` implementation.
2023-02-22 16:39:04 -06:00
Daniil Fukalov
2d678efb89 [ROCM] Enable float16 and int8 types for FMA based dot implementation.
By default Triton generates MLIR with an f32 result for the tt.dot operation on
f16-typed operands, so we have "tt.dot(f16,f16,f32)->f32" types in .ttgir. But
the LLVM FMA instruction requires the same type for all three operands, so the
first two operands are implicitly cast f16->f32 as
"unrealized_conversion_cast struct{f16,f16,...}->struct{f32,f32}".
This change fixes incorrect implicit cast generation.
For int8-typed operands, the result operand is also cast after performing the dot.

As a next step to improve the FMA-based dot operation, target-specific FMA
intrinsics for f16 and int8 (e.g. fma(f16,f16,f16)->f16) could be used, perhaps
as an option.
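A NumPy sketch of the typing described here: f16 operands with an f32 accumulator, so each operand is cast to f32 before the multiply-add:

```python
import numpy as np

# tt.dot(f16, f16, f32) -> f32: operands are f16, but the accumulator and
# result are f32, so each operand is cast f16 -> f32 before the FMA.
a = np.array([1.0, 2.0], dtype=np.float16)
b = np.array([3.0, 4.0], dtype=np.float16)
acc = np.float32(0.0)
for x, y in zip(a, b):
    acc = np.float32(x) * np.float32(y) + acc  # fma(f32, f32, f32)
print(acc)  # 11.0
```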
2023-02-22 22:36:20 +01:00
Eric Wang
320ae18093 [FRONTEND] Add error messages for arange (#1218)
Fix issue https://github.com/openai/triton/issues/244

Check that `end` is greater than `start`.
Check that the range fits in `int32`.
Check that the number of elements is less than or equal to
`TRITON_MAX_TENSOR_NUMEL = 131072`.
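The checks can be sketched as follows (`check_arange` is a hypothetical helper; the constant matches the message above):

```python
TRITON_MAX_TENSOR_NUMEL = 131072
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def check_arange(start, end):
    """Validate arange bounds per the checks listed above (illustrative)."""
    if end <= start:
        raise ValueError("arange's end must be greater than start")
    if start < INT32_MIN or end > INT32_MAX:
        raise ValueError("arange must fit within int32")
    if end - start > TRITON_MAX_TENSOR_NUMEL:
        raise ValueError("arange has too many elements")

check_arange(0, 128)  # passes silently
```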

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-02-22 00:37:28 +00:00
Michaël Benesty
940f394a35 [Frontend] fix crash on cast when dest is constexpr (#1222)
This pull request addresses a crash that occurs when casting to a
tl.constexpr type in the frontend.

More info and repro code available in:
https://github.com/openai/triton/issues/1221
2023-02-20 10:50:33 -08:00
Rohit Santhanam
841784d1e3 Merge remote-tracking branch 'upstream/main' into upgrade_triton_mlir_rocm_to_llvm_head 2023-02-18 09:25:20 +00:00
Philippe Tillet
4d067f5120 [FRONTEND] Now emit an error for tl.reshape, instead of silently calling tl.view (#1212) 2023-02-17 20:21:20 -08:00
Eric Wang
30db959dae [FRONTEND] Add error messages for load/store (#1179)
Fix issue https://github.com/openai/triton/issues/633
2023-02-13 10:52:50 -05:00
Rohit Santhanam
a2416e0901 Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-02112023 2023-02-11 14:48:19 +00:00
Philippe Tillet
2aba985daa [OPTIMIZER] Improved layout simplifications heuristics (#1168) 2023-02-09 20:17:25 -08:00
fdrocha
972b761390 [FRONTEND] For __rshift__ operator, use arithmetic right shift if dtype is a signed int. (#1153) 2023-02-06 10:26:17 +00:00
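The signed/unsigned distinction in plain Python (Python's `>>` on ints is already arithmetic, i.e. sign-propagating; a logical shift must first mask to a fixed width):

```python
def arithmetic_rshift(x, n):
    # Python's >> on ints propagates the sign bit, matching signed semantics.
    return x >> n

def logical_rshift(x, n, bits=32):
    # Reinterpret x as an unsigned `bits`-wide value, then shift zeros in.
    return (x & ((1 << bits) - 1)) >> n

print(arithmetic_rshift(-8, 1))  # -4
print(logical_rshift(-8, 1))     # 2147483644
```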
Keren Zhou
ce47f94e59 [FRONTEND] Check if the data types of *A* and *B* in the dot op have the same data type (#1155) 2023-02-06 01:58:07 -08:00
Rohit Santhanam
8cb6ab5b1a Merge remote-tracking branch 'upstream/main' into triton_mlir_IFU_02022023 2023-02-02 22:54:53 +00:00
George Karpenkov
a9d1935e79 [FRONTEND] Fix error message when atomics are not supported for a given dtype (#1134)
Otherwise, the construction of the exception crashes during string
concatenation.
2023-02-02 02:49:34 +00:00
Philippe Tillet
c4b9d699d2 [FRONTEND][BACKEND] Fixed many bugs (#1122)
- **Temporarily commented out an assertion in `MemBar.cpp`. We need to fix
this! But for now the following patches will unblock a number of
users.**
- Fixed frontend codegen issue for If / For / While. Emit an error when
replaced values' type mismatch.
- Added "top level" codepath for if statements, which allows users to
write patterns to exit early from kernels (e.g., `if cond1: if cond2:
return else: ...`). Added associated codegen in TritonToTritonGPUPass
- Added basic control flow tests
- Pipeline pass is no longer activated when memory accesses can't be
vectorized
- Added missing magic methods to `constexpr`
- Fixed issue in random.py: bitcast some values to uint when they need
to be.
- Added support for `Not`
- Fixed nondeterministic compilation issue
2023-01-30 23:22:36 -08:00
Yan Chunwei
94b419c327 [FRONTEND] some tiny fix (#1120) 2023-01-30 19:39:38 -08:00
Nishant Sikarwar
653c8dc124 [FRONTEND] Replaced range with enumerate calls (#1110)
Using `range(len(...))` is not Pythonic.
Python does not have index-based loops; instead, it uses collection
iterators. Python has a built-in function `enumerate` which adds a counter
to an iterable, letting you access the counter and the value
from the iterable at the same time. It is therefore recommended to
replace `range(len(...))` with `enumerate(...)`.

For example:

5bcf60a5c0/python/triton/language/extern.py (L68)

f62d556fff/python/triton/language/extern.py (L68)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-01-30 15:22:11 -08:00
Michael Melesse
a9f955f862 Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-2023-30-1 2023-01-30 14:04:01 -06:00
Nishant Sikarwar
e5dbe35cc1 [FRONTEND] removed unnecessary comprehension (#1085) 2023-01-30 19:42:14 +00:00
Rohit Santhanam
2d0ee0fa0f Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-01232023 2023-01-24 03:59:17 +00:00
Daniil Fukalov
e6983feb91 [ROCM] Implement next part of atomics.
- fixed the scalar atomic_rmw implementation for fmin/fmax on f32
- fixed tensor atomic_rmw
- added an atomic_cas implementation

TODO: fix atomic_rmw for f16; implement fmin/fmax for f32 with
native instructions (asm inline in the case of LLVM 14) instead of
the workaround used for NV.
2023-01-23 14:01:08 +01:00
Keren Zhou
3f47e9aa0e [BACKEND] Fix unrealized conversion for fp32 dot (#1051) 2023-01-17 21:55:44 +00:00