Commit Graph

566 Commits

Author SHA1 Message Date
Francisco Massa
c1b057eee9 [FRONTEND] Add option to specify number of compilation threads during Triton compilation (#1450)
On some machines, the amount of available RAM might not be enough to
compile Triton with `2 * num_cpus` parallelism. For example, CircleCI's
`large` instance can't handle Triton compilation as is due to
insufficient memory.

Instead, I propose taking PyTorch's approach of defining a
[`MAX_JOBS` env
var](0e4ddc2b40/tools/setup_helpers/cmake.py (L366-L368))
that lets the user reduce (or increase) the
parallelism during compilation.
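
For illustration, a minimal sketch of the scheme (this is the idea, not the exact `setup.py` code):

```python
import os

# Honor a MAX_JOBS override; otherwise fall back to 2 * num_cpus.
max_jobs = os.environ.get("MAX_JOBS")
num_jobs = int(max_jobs) if max_jobs else 2 * (os.cpu_count() or 1)
build_args = ["-j" + str(num_jobs)]
```

So e.g. a memory-constrained CI machine can build with `MAX_JOBS=4`.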

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-03-31 11:34:18 -07:00
Philippe Tillet
123afdf423 [DOCS] fixed typo assert_almost_equal -> assert_allclose in tutorials (#1456) 2023-03-31 11:27:18 -07:00
mcskatkat
109bfca5c0 [FRONTEND] CodeGenerator.statically_implemented_functions: fixed incorrect hacky initialization (#1444)
This fixes the problem indicated in #1443
2023-03-30 00:26:00 -07:00
Chenggang Zhao
1bead327fd [TUTORIALS] Add the missing tutorial: libdevice functions (#1430)
While merging `triton-mlir`, it seems the libdevice tutorial was
missed. This PR adds it back and updates it to the current `tl.math`
interface.

Also found a bug in `test_core.py`: the `extern_libs` argument should still
pass `libdevice`, or it will fail on my added test. The legacy code didn't
fail because `lib_path` was `None` and thus ignored.
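
For context, a minimal sketch of the `tl.math` interface the restored tutorial exercises (the kernel here is illustrative, not copied from the tutorial):

```python
import triton
import triton.language as tl

@triton.jit
def asin_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # libdevice-backed math function, exposed under tl.math
    tl.store(y_ptr + offs, tl.math.asin(x))
```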

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-03-29 19:00:17 -07:00
zahimoud
3fe2901bfc [FRONTEND] Typehint improvement (#1442)
Fixed a bug with typehint checking, refactored the typehint code for
specializations, and added typehint checking for `sig_keys`.
2023-03-29 18:12:40 -07:00
Keren Zhou
43eed392df [BACKEND] Fix tl.exp for fp16 (#1440)
https://github.com/openai/triton/issues/1438
https://github.com/openai/triton/issues/1360
2023-03-29 16:34:23 -07:00
Sophia Wisdom
f53bb6a1bc [FRONTEND] More friendly error message when non-Triton function is called from Triton function (#1429) 2023-03-28 22:38:03 -07:00
zahimoud
73b124155b [FRONTEND] Added typehints support to speedup triton kernel launch (#1431)
This addresses one source of kernel launch overhead: by adding typehints
to the kernel definition, we avoid having to run `hasattr` and `isinstance`
on each argument. Also added a regression unit test to make sure the launch
overhead stays within an expected range.
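
The pattern, sketched in plain Python (an illustration of the idea, not Triton's actual launcher code): resolve each argument's handling once from its annotation at definition time, so the hot launch path skips per-argument introspection.

```python
import inspect

def make_arg_handlers(fn):
    # Built once per kernel: one handler per parameter, chosen by annotation.
    handlers = []
    for param in inspect.signature(fn).parameters.values():
        if param.annotation is int:
            handlers.append(int)          # known type: direct conversion
        else:
            handlers.append(lambda a: a)  # unannotated: generic slow path
    return handlers

def launch(handlers, args):
    # Hot path: no hasattr/isinstance per argument.
    return [h(a) for h, a in zip(handlers, args)]
```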
2023-03-28 22:37:34 -07:00
Keren Zhou
ee593fca0b [BACKEND] Fix int8 dot (#1435) 2023-03-28 20:18:17 -07:00
Philippe Tillet
4bfbb8718a [FRONTEND] Added NoneType as an accepted condition type (#1436) 2023-03-28 18:35:12 -07:00
Philippe Tillet
888cbad0e5 [FRONTEND] parse_mlir_module now properly initializes LLVMDialect (#1433) 2023-03-28 15:25:31 -07:00
Keren Zhou
adc4d25276 [BACKEND] A general interface for initializing destination operands in load/store operations (#1427) 2023-03-27 22:13:01 -07:00
Philippe Tillet
fe76b12354 [BUILD] Back to cmake >= 3.18 (#1428) 2023-03-27 16:47:34 -07:00
Chenggang Zhao
72b071253e [FRONTEND] Support block pointer semantics (#1392)
This PR introduces a new semantics: **block pointer**, which makes it
easier and faster for users to load a block from a parent tensor.

Below, the API changes are illustrated with an example:
```python
# Make a block pointer, which points to a block in the parent shape
# `base`: the parent tensor
# `shape`: the shape of the parent tensor
# `strides`: the strides of the parent tensor
# `offsets`: the offsets of the block in the parent tensor
# `order`: the order of the data arrangement in memory
# Below is an example loading a 2D column-major matrix 
block_ptr = tl.make_block_ptr(base=ptr, shape=(M, N), strides=(stride_m, stride_n), offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))

# Advance the offsets; note that the striding information is already saved in `block_ptr`
# `base`: the block pointer to be advanced
# `offsets`: the offsets for each dimension
block_ptr = tl.advance(base=block_ptr, offsets=(BLOCK_M, -BLOCK_N))
block_ptr = tl.advance(base=block_ptr, offsets=(-BLOCK_M, BLOCK_N))

# Load from a block pointer, the output type is the dereferenced type of `block_ptr`, e.g. ptr<tensor<32x32xf32>> -> tensor<32x32xf32>
# `ptr`: the block pointer to be loaded
# `boundary_check`: a tuple of dimensions to check the boundary
# `padding`: padding strategy for elements out of bound
val = tl.load(ptr=block_ptr, boundary_check=(0, 1), padding="zero")

# Store by a block pointer, in which the pointer and the value tensor should have the same shape
# `ptr`: the block pointer to be stored
# `boundary_check`: a tuple of dimensions to check the boundary (no-write if out of bound)
tl.store(ptr=block_ptr, value=val, boundary_check=(0, 1))
```

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-03-27 16:46:49 -07:00
Xuehai Pan
c52219b5c3 [SETUP] avoid using deprecated distutils (#1400)
Module [`distutils`](https://docs.python.org/3/library/distutils.html)
is deprecated and will be removed in Python 3.12.

Ref:

- `distutils` documentation:

  > [distutils](https://docs.python.org/3/library/distutils.html#module-distutils) — Building and installing Python modules
  >
  > `distutils` is deprecated with removal planned for Python 3.12.

- [PEP 632 – Deprecate distutils module](https://peps.python.org/pep-0632)
------

This PR removes references to `distutils` and replaces them with
[`packaging`](https://pypi.org/project/packaging) and
[`sysconfig`](https://docs.python.org/3/library/sysconfig.html),
alleviating potential breakage in the modern Python packaging system.

Changes:

- Remove references to `distutils` and replace them with
[`packaging`](https://pypi.org/project/packaging) and
[`sysconfig`](https://docs.python.org/3/library/sysconfig.html)
(see the sketch below)
- Add `cmake` and `packaging` to `build-system.requires` to install the
necessary build dependencies prior to calling `setup.py`
- Minor changes: `multiprocessing.cpu_count()` -> `os.cpu_count()`, and add
PyPI classifiers
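
A minimal sketch of the replacement pattern (the exact call sites are assumed, not quoted from the diff):

```python
from packaging.version import Version  # replaces distutils.version.LooseVersion
import sysconfig                       # replaces distutils.sysconfig

# Version comparison, e.g. checking a minimum CMake version:
assert Version("3.20.0") >= Version("3.18")

# Query the C-extension suffix without distutils:
ext_suffix = sysconfig.get_config_var("EXT_SUFFIX")
```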

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-03-27 10:37:47 -07:00
Philippe Tillet
46672772b4 [FORMAT] autopep8 now uses max-line-length=88 (#1410) 2023-03-25 15:46:50 -07:00
Xuehai Pan
5b36cb48ad [CI][TEST] update pre-commit hooks and use pre-commit for style tests in CI (#1409)
Ref issue:

- #1408

Changes:

- Add `.editorconfig`
- Add `pre-commit-hooks`:

    ```yaml
    - repo: https://github.com/pre-commit/pre-commit-hooks
      rev: v4.4.0
      hooks:
        - id: check-symlinks
        - id: destroyed-symlinks
        - id: trailing-whitespace
        - id: end-of-file-fixer
        - id: check-yaml
        - id: check-toml
        - id: check-ast
        - id: check-added-large-files
        - id: check-merge-conflict
        - id: check-executables-have-shebangs
        - id: check-shebang-scripts-are-executable
        - id: detect-private-key
        - id: debug-statements
    ```
- Add `flake8` to `pre-commit` config and add `.flake8` file
- Use `pre-commit` for style tests in CI
- Run `pre-commit` and fix existing violations:
    - fix trailing spaces
    - fix end-of-files
    - fix file modes with `chmod -x`
    - run `autopep8` on existing code
    - fix `flake8` violations
2023-03-25 14:52:16 -07:00
peterbell10
6063fccd0b [FRONTEND][BACKEND] Lower tl.abs to math::Abs{I,F}Op (#1401)
This generates identical PTX for floating point, but for integer types
the resulting PTX is much better. For example `tl.abs` for int16
currently generates

```ptx
  cvt.s32.s16 %r1, %rs2;
  neg.s16     %rs4, %rs2;
  setp.lt.s32 %p4, %r1, 0;
  selp.b16    %rs3, %rs4, %rs2, %p4;
```

After, it becomes a single `abs.s16` instruction.

This also improves LLVM's ability to optimize floating-point code, e.g.
`abs(t) * abs(t)` is now optimized to `t * t`, which didn't happen before.
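
For reference, a minimal kernel exercising this path (a sketch; buffer setup is assumed):

```python
import triton
import triton.language as tl

@triton.jit
def abs_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)          # e.g. from an int16 tensor
    tl.store(y_ptr + offs, tl.abs(x))  # now lowers via math::AbsIOp to abs.s16
```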

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-03-24 21:58:24 -07:00
Michael Melesse
a9c87245b4 [ROCM] Enable ROCM Backend #1: Empty Kernel (#1312)
This PR is the first in a series of PRs importing the changes we have
made to enable ROCm on [our
fork](https://github.com/ROCmSoftwarePlatform/triton) of Triton.

It contains the major changes to the Python frontend and enough
changes to the C++ backend to allow compiling and running the empty
kernel. We use the ROCm CI added a few weeks ago to verify the changes.

---------

Co-authored-by: Ronan Keryell <ronan@keryell.fr>
2023-03-24 17:18:27 -07:00
wujingyue
89d8fe6502 [DOCS] Fixed a pytype annotation. (#1406)
`llir_to_ptx` returns a `str`, not a `Tuple`.
2023-03-24 23:46:48 +00:00
Philippe Tillet
fc7c0b0e43 [FRONTEND] Removed torch dependency and cleaned up testing (#1394)
`assert triton.testing.allclose` -> `torch.testing.assert_allclose`
`triton.testing.assert_almost_equal` -> `torch.testing.assert_allclose`
2023-03-23 22:37:21 -07:00
xndcn
ff1d0377e0 [BACKEND] Fix wrong conversion from float8e5m2 <> bfloat16 (#1391)
The exponent compensation should be 0x3800 (112) instead of 0x3000 (96).
Also adds a mantissa bit for the float16 conversion, so it rounds to the
nearest float8e5m2.
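
The 112 is the exponent-bias difference: float8e5m2 uses bias 15 and bfloat16 uses bias 127, so 127 - 15 = 112 must be added to the exponent field, i.e. 112 << 7 = 0x3800 in bfloat16's bit layout. A rough sketch of the widening direction (normal values only; not Triton's actual codegen):

```python
def fp8e5m2_to_bf16_bits(x: int) -> int:
    # float8e5m2: 1 sign, 5 exponent (bias 15), 2 mantissa bits
    # bfloat16:   1 sign, 8 exponent (bias 127), 7 mantissa bits
    sign = (x & 0x80) << 8
    exp_man = (x & 0x7F) << 5         # align the 2 mantissa bits to 7
    return sign | (exp_man + 0x3800)  # re-bias: add (127 - 15) << 7

assert fp8e5m2_to_bf16_bits(0x3C) == 0x3F80  # 1.0 maps to 1.0
```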

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-03-24 04:42:08 +00:00
Natalia Gimelshein
3239c93a93 [TEST] add a test for inductor normalization pattern (#1390) 2023-03-23 00:29:28 +00:00
Philippe Tillet
7c7b769e37 [SETUP] Fixed dependencies (#1389) 2023-03-22 16:15:35 -07:00
xndcn
65d8d802d5 [BACKEND] Fix wrong conversion from float8e4m3 <> bfloat16 (#1384)
The exponent compensation should be 0x3c00 (120) instead of 0x3800 (112).
2023-03-21 18:58:13 -07:00
mcskatkat
9ae78d21f1 [FRONTEND] CompilationError._format_message issue + tidying (#1362)
- Fixed `CompilationError._format_message` failing when `error_message` is
a `constexpr`
- Factored out `_is_constexpr()` checks and `_unwrap_if_constexpr()`
idioms
- Added an `UnsupportedLanguageConstruct` exception, replacing some Python
builtin exceptions raised in such cases
- Some hardening in `.visit_If()`
- Cleaner exception handling in `build_triton_ir()`
2023-03-21 19:52:18 +00:00
Keren Zhou
c1dd6df9ce [FRONTEND] Fix negative induction variable (#1382) 2023-03-21 08:38:16 -07:00
xndcn
84ffefc368 [BACKEND] Fix wrong conversion from float8e4m3 <> float16 (#1375)
After offset shifting, the exponent compensation should not be forgotten.
Also adds back some comments from `legacy_backend`.
2023-03-20 21:45:25 -07:00
Phil Tillet
e650d3708b [FRONTEND] dot now uses tl.float32 by default for out_dtype. 2023-03-19 21:58:46 -07:00
Philippe Tillet
b4decbe155 [BACKEND] Now using call_once to initialize LLVM target (#1373) 2023-03-19 21:23:39 -07:00
Fei Hu
6366c5a254 [FRONTEND][BACKEND] Add support for FP16 output for tl.dot (#1258)
---------

Co-authored-by: Fei Hu <fhu@microsoft.com>
2023-03-19 19:52:14 -07:00
Philippe Tillet
e4b2d1bc3d [FRONTEND][BACKEND] no longer using indices for loops (#1370) 2023-03-19 14:57:50 -07:00
Philippe Tillet
39139258c8 [FRONTEND][BACKEND] tl.mathlib -> tl.math; internally reverted to mathlib -> libdevice (#1368) 2023-03-19 02:14:57 -07:00
rsanthanam-amd
c575911a01 [FRONTEND] Change libdevice to mathlib and fix abs (#1361)
Co-authored-by: Phil Tillet <phil@openai.com>
2023-03-19 01:34:16 -07:00
Philippe Tillet
2f035c0611 [FRONTEND] Fix contains_return_op when analyzing functions in another module (#1365) 2023-03-18 15:02:45 -07:00
Edward Z. Yang
6d61a5ca23 [FRONTEND] Don't use HOME envvar to get HOME (#1364)
Fixes https://github.com/pytorch/pytorch/issues/97076
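
One way to get the home directory without reading the env var directly (a sketch of the idea, not necessarily the exact fix): `os.path.expanduser`, which falls back to the password database on POSIX when `HOME` is unset.

```python
import os

# No KeyError when HOME is missing from the environment.
cache_dir = os.path.join(os.path.expanduser("~"), ".triton", "cache")
```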
2023-03-18 10:39:58 -07:00
Horace He
1d2871d0d1 [RUNTIME] Fix memory leak in (#1358)
Fixes a bug that causes Triton to leak 32 bytes on every kernel
invocation.

Also solves https://github.com/pytorch/pytorch/issues/96937
2023-03-16 17:52:06 -07:00
mcskatkat
611a2dc9bf [FRONTEND] CodeGenerator: enhanced (#1355)
Contents of this change to `CodeGenerator`:
- addressed a mutable default value in the constructor (GitHub #1353)
- structured and faster name lookup (replaces `.get_value`)
- added informative error messages in some places
- a tidy mechanism for "static" (compile-time) functions replaces the inline
`if ... elif ...` chain in `.visit_Call`
- more robust `static_assert` and `static_print`
- more informative `CompilationError` display (saves scrolling up
through long tracebacks)
- a dedicated `CompileTimeAssertionFailure` exception for `static_assert`,
which can be specially treated upstream by `Autotuner` to skip configurations
that violate constraints (as for `OutOfResources`); see the sketch below
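
A hypothetical usage of that last point (config values and names are illustrative; this assumes the `Autotuner` skips configurations whose compilation raises `CompileTimeAssertionFailure`):

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({"BLOCK": 64}), triton.Config({"BLOCK": 2048})],
    key=["n"],
)
@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # Configs violating this constraint fail at compile time and get skipped.
    tl.static_assert(BLOCK <= 1024, "BLOCK too large for this kernel")
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)
```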

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-03-16 17:00:43 -07:00
Berke Kocaoğlu
ba91f39dbf [DOC] Fix syntax errors, typos, formatting; increase consistency (#1357)
This PR:
- Fixes syntax errors like `.type values: dict[str,
Callable[[list[Any]], Any]]` to `:type values: dict[str,
Callable[[list[Any]], Any]]`,
- Fixes typos,
- Fixes formatting like `k ++` to `k++`,
- Increases consistency (e.g. by transforming the minority `cd dir/` to
the majority `cd dir`).
2023-03-16 15:32:02 -07:00
mcskatkat
53e8e04d6e [FRONTEND] fix constexpr by annotation (#1352)
Fixed an unjustified `TypeError` raised when an arg is (strangely)
annotated with a non-type.
2023-03-16 11:10:19 -07:00
mcskatkat
f5d22d5995 [FRONTEND] support f-strings in compiler with constexpr conversion (#1349)
This addition allows explanatory messages upon assertion failures:

```python
@triton.jit
def my_single_block_kernel(
    matrix_extent: tl.constexpr,
    block_size: tl.constexpr,      # must be >= extent (single block)
    matrix: Tensor,
    ...
):
    tl.static_assert(matrix_extent <= block_size, 
                     f"`matrix_extent` should not be more than the block size ({block_size}), but is {matrix_extent}")
```

Yielding, when called incorrectly:
```
AssertionError: `matrix_extent` should not be more than the block size (32), but is 57
```
2023-03-16 08:02:10 +00:00
Shintaro Iwasaki
4b774ee4d0 [OPS/BLOCKSPARSE] remove unnecessary mask (#1351)
This PR applies a minor patch that removes unnecessary masks in
`_dsd_kernel()`.

### Details

`offs_bn` is defined as follows and is not updated after that.
```py
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
```

Because `offs_bn = offs_bn % DS0`, this mask is always `True`.
```py
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
```
This PR removes this mask (as well as explicit `mask=True`).
2023-03-15 19:06:38 -07:00
mcskatkat
c175473bbf [FRONTEND] In JITFunction: infer constexpr arg only if annotated as such (#1345)
Fixed `JITFunction.__init__` to mark args as constexpr only when the
annotation is actually `tl.constexpr`, rather than treating any
annotated arg as constexpr.
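
For illustration (a sketch with hypothetical names), after this fix only the `tl.constexpr` annotation is special:

```python
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr,
                 scale: int,             # annotated, but not tl.constexpr:
                                         # correctly kept a runtime argument
                 BLOCK: tl.constexpr):   # only this arg is a constexpr
    offs = tl.arange(0, BLOCK)
    tl.store(x_ptr + offs, tl.load(x_ptr + offs) * scale)
```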
2023-03-15 16:39:45 -07:00
Stonepia
109b5e2729 [BUILD] Fix the build bug when user use system package of llvm by setting LLVM_SYSPATH (#1336)
When the user sets `LLVM_SYSPATH` to use a custom LLVM build, the build
throws an error because there is no version.txt under the custom build.

This PR skips the version check if `LLVM_SYSPATH` is set.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-03-15 13:28:19 -07:00
Philippe Tillet
56b23f433d [TEST] Temporarily disable test_dot mode that fails because of ptxas/nvptx (#1344) 2023-03-15 01:17:48 -07:00
peterbell10
01b177afe7 [FRONTEND] Mangle signed and unsigned integer types differently (#1340)
This is cherry-picked from #1305.

If you call a `JITFunction` twice in the same kernel, first with `int32`
and then with `uint32`, the second call will treat the unsigned value as
signed. This passes through MLIR without error because MLIR uses the
same types for both, but different operations are generated for signed
and unsigned values, so you may silently get the wrong result.
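
A sketch of the hazard (hypothetical kernels, assuming `u_ptr` points to uint32 data; a shift is one operation where signedness changes the generated code):

```python
import triton
import triton.language as tl

@triton.jit
def halve(v):
    return v >> 1  # arithmetic shift if signed, logical shift if unsigned

@triton.jit
def kernel(i_ptr, u_ptr, o_ptr):
    a = halve(tl.load(i_ptr))  # first call site: int32
    b = halve(tl.load(u_ptr))  # second call site: uint32 -- previously
                               # mangled identically, so it silently reused
                               # the int32 (arithmetic-shift) specialization
    tl.store(o_ptr + 0, a)
    tl.store(o_ptr + 1, b.to(tl.int32))
```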
2023-03-14 22:29:18 -07:00
Philippe Tillet
ad81447ad0 [FRONTEND] Marking int1 (bool) type as unsigned (#1343) 2023-03-14 22:05:13 -07:00
Philippe Tillet
6a8634e2a7 [BACKEND] No longer running LLVM-IR optimizations after codegen. (#1338)
This triggered some outrageous bugs. See #1337.
2023-03-13 22:50:15 -07:00
Philippe Tillet
dde34904d0 [TESTING] triton.testing.allclose now uses torch.allclose (#1333) 2023-03-13 17:48:32 -07:00
Nikita Shulga
663074460d [VERSION] Update triton/__init__.py (#1327)
Follow-up to
c7581c9a91
2023-03-13 10:38:38 -07:00