Commit Graph

863 Commits

Author SHA1 Message Date
Keren Zhou
307b5caa49 [BACKEND] Fix scan issues on repetitive warps and improve perf when there's a single warp on the axis (#2330)
1. On the axis, using `getAxisNumWarpsWithUniqueData` instead of getting
the raw number of warps to avoid communication among warps that handle
the same piece of data.
2. When there's a single warp on the axis, using warp Intrinsics for
communication and skip shared memory.

Need a follow up PR for code clean up.
2023-09-18 17:45:05 -04:00
Philippe Tillet
894fa9e943 [RUNTIME][INTERPRETER] now also override __str__ method for tensors (#2325) 2023-09-17 16:49:30 -07:00
Philippe Tillet
e686b4d6d4 [FRONTEND] interpreter rewrite (#2321)
This is a new interpreter mode that shares semantic analysis with the
JIT'ed codepath and that the Triton core team is committed to maintain
2023-09-17 14:58:50 -07:00
Myeonghwan Ahn
2b066000aa [FRONTEND] fix matmul int8 overflow issue (#2297)
Previously on matmul, if inputs are int8, output was also int8.
This commit fixes the overflow problem with int32 output.
#2296
2023-09-17 16:41:02 +00:00
Stonepia
68e1bd162c [FRONTEND] fix xpu stages logic (#2305) 2023-09-17 09:19:14 -07:00
jon-chuang
4f2d995fad [FRONTEND] Explicitly forbid dot(.., out_dtype=bfloat16) (#2308)
Fixes: https://github.com/openai/triton/issues/2302
2023-09-17 09:15:06 +00:00
Justin Lebar
073aa16379 [BUILD] use ninja (#2318) 2023-09-17 02:08:04 -07:00
Thomas Raoux
31b0c52142 [FRONTEND][BACKEND] Add flag to control accumulation for fp8 (#2300)
Change the dot to allow taking an initial accumulator and add a flag
that will allow the compiler to accumulate in a lower precision than the
output type.
On Hopper this flag is on by default which allows accumualting with
lower precision.
This only affect Hopper fp8 dot.
2023-09-15 18:42:54 -07:00
Zahi Moudallal
db5c793f82 [FRONTEND] Add sass to asm dict with lazy evaluation (#2309) 2023-09-15 15:31:43 -07:00
Keren Zhou
08c1658957 [FRONTEND] Accommodate new triton IR format (#2294)
- Support memory space for pointers (e.g., `!tt.ptr<f32, 1>`).
- Support parsing function attribute, though not used yet.
2023-09-14 09:03:23 -07:00
Zahi Moudallal
36087a108f [FRONTEND] Added SASS to asm dict (#2280) 2023-09-13 21:21:01 +00:00
Khushi Agrawal
c61d772eee [DOCS] add missing docs (#2154) 2023-09-13 19:30:40 +00:00
Thomas Raoux
b63e8f87fc [FRONTEND] Override prototype (#2214)
Low tech but very useful way to override kernels on the fly. This can be
use for debugging functionality or performance problems this lets user
dump modify and feed back IR into the jit compiler.
2023-09-13 10:05:47 -07:00
Zahi Moudallal
e95e1f12eb [BACKEND] Convert layout illegal mem access fix (#2287) 2023-09-13 10:02:25 -07:00
Thomas Raoux
994f7e4460 [BACKEND] Remove dependency between NVGPU and TritonNvidiaGPU (#2282) 2023-09-12 11:02:20 -07:00
Ying Zhang
37f12497b0 [FRONTEND] Add PyTorch fp8 dtypes to Triton (#2279)
Add PyTorch fp8 dtypes
(8025b193a9/torchgen/api/types/types.py (L50-L51))
to Triton.
2023-09-12 08:57:01 -07:00
Zahi Moudallal
a47f1f5c28 [BACKEND] Unify slow/fast reduce codegen (#2220) 2023-09-12 08:46:19 -07:00
jsh-20
fc5d7e6e7c [FRONTEND] Improve grid calculation for persistent kernels to hoist pe… (#2283)
…rf on problems that need few blocks.

constrain the number of launched blocks to what it exactely needs for
persistent warp specialized kernel. It's useful when problems need very
few blocks.
e.g. MxNxK=800x800x60000, f16_f16_f32, block size=128x128x64,
non-split-k. Experiments show it can achieve ~16% speedup.
2023-09-12 09:14:47 +00:00
peterbell10
ab9da3b2b8 [FRONTEND] Fix expand_dims and tl.full to handle scalar tensors (#2275)
This fixes a few bugs related to scalar tensors:
- `tl.full([], fill_value, dtype)` fails with `TypeError('0d block_type
is forbidden')`
- `scalar[None]` fails with `TypeError("'constexpr' object is not
iterable")`
- `scalar[None, None]` fails with `AttributeError("'dtype' object has no
attribute 'shape'")`
- `scalar.shape` returns `[1]` instead of 0-dim `[]`
- Also related, `tl.zeros_like(scalar)` returns a 1d tensor instead of
another scalar
2023-09-11 20:59:13 -07:00
Philippe Tillet
bf4f9375a7 [FRONTEND] allow mixed precision FP8 matmul on pre-H100 hardware (#2281) 2023-09-11 20:54:29 -07:00
Shintaro Iwasaki
8da27c1c95 [Build] Fix very minor compilation problems (#2277)
This PR fixes a few very minor compilation issues found in internal
deployment at Meta. It looks like nit-picking, but it'd be really
appreciated if it could be addressed in OSS Triton (to reduce
differences from OSS), and we believe these changes are not bad in
general. Neither performance nor functionality is affected by this PR.

1. Type cast in `python/triton/runtime/backends/cuda.c`. Implicit `void
*` -> `cuuint{32,64}_t *` cast is not allowed by many compilers (with
certain flags). It'd be nice to add an explicit cast (like
`backends/hip.c`).

2. Inconsistent include path specification in
`lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp`. Unlike other
`DotOpToLLVM/*.cpp`, include paths used in `WGMMA.cpp` are not relative.
This is problematic in some compilation settings since a compiler
somehow needs to find headers in a parent directory. It'd be great to
use a relative path, like other source files in Triton.

cc: @yuguo68
2023-09-11 19:28:31 -07:00
Thomas Raoux
a9db6b94b9 Remove wrong dependency between TritonGPU and NVGPU dialect (#2276) 2023-09-11 16:30:13 -07:00
danny.jang
ec4a968d44 [TESTS] Enhance benchmark flexibility (#2239)
User can pass custom arguments to benchmarks. For example, user can pass
`dtype` which will be used to create tensors in a benchmark.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-09-11 15:31:30 -04:00
jon-chuang
5231d57c71 [TESTS] replace deprecated torch.testing.assert_allclose (#2250)
Prior to this PR, matmul on sm_89 (RTX 4070)
(`test/unit/operators/test_matmul.py::test_op`) would result in test
failure due to too strict atol/rtol.

To avoid having to choose strictness ourselves, and to have better
defaults based on dtype, use the non-deprecated torch testing util.

See: https://github.com/pytorch/pytorch/issues/61844

Replace: https://github.com/openai/triton/pull/2242
2023-09-11 15:31:17 -04:00
Lixun Zhang
28d4c3bdb4 [BACKEND] Make sure getAxisBlockStride does not return 0 (#2273)
This can happen when the CTA shape is larger than the tensor shape along
the non-axis dim during scanOp lowering.
2023-09-11 11:02:56 -07:00
Keren Zhou
10f59d8ce0 [RUNTIME] Get the correct end idx for regular arguments of GPU kernels (#2262)
Previously, if there were any specializations of "1" or "constexpr"
mixed with unspecialized arguments in arbitrary order, we might have
encountered errors due to passing incorrect arguments. This was because
the length of the signature did not indicate the maximum index of
regular arguments.

https://github.com/openai/triton/issues/2229

@shunting314 @amjames 

More specifically for cases like:

```
kernel(
b: tl.tensor,
a: tl.constexpr,
c: tl.int = 1,
d,
e: tl.constexpr,
...
)
```
2023-09-07 23:31:07 -07:00
Izzy Putterman
7d01c1852a Revert unintentional change (#2257)
This change seems to have been unintentionally reverted in the hopper
PR:
38d767ea93

Adding it back.
2023-09-07 10:48:12 -07:00
Zahi Moudallal
f21b36c8c5 [CLEANUP] Delete binaries that went in by mistake (#2256) 2023-09-06 20:42:42 +00:00
jon-chuang
36859aebff [DOCS] Add MLIR Autogenerated Docs to Sphinx Docs (#2234)
Partially fixes: https://github.com/openai/triton/issues/2226

Here are some example renderings:
![Screenshot from 2023-09-04
18-39-20](https://github.com/openai/triton/assets/9093549/e9c4af04-aeae-4021-a8db-6a4a82b59ae7)
![Screenshot from 2023-09-04
18-39-30](https://github.com/openai/triton/assets/9093549/410391b8-e07e-4bed-909c-8ce5484072d1)
![Screenshot from 2023-09-04
18-39-41](https://github.com/openai/triton/assets/9093549/f1eaef95-66c1-4506-a153-c6069e2b5072)
2023-09-06 08:17:12 +00:00
Wang Weihan
e721911705 [FRONTEND] clean build directly when executing python setup.py clean (#2238)
Current setup.py could not clean the build directly because the default
build directly has been changed in `CMakeBuild`. This PR is to clean
build directly in this regard.
2023-09-04 21:31:38 -07:00
jon-chuang
99f8f912aa [OPS] Remove unnecessary perf bug workaround (#2240)
This bug previously existed and I verified it in previously nightly
release of triton (20230714).

However, according to new benchmarks, this bug no longer exists on
Triton main. See:
https://github.com/google/jax/pull/17328#issuecomment-1705010065
2023-09-04 21:30:54 -07:00
Keren Zhou
9e9fbe01f0 [FRONTEND] Fix specialization on triton integer types (#2236)
https://github.com/openai/triton/issues/2231
2023-09-03 23:57:08 -07:00
Shantanu
a4df60e20a [FRONTEND] Fix GIL handling in error conditions (#2225)
The use of the opaque GIL state APIs should mean that the
PyErr_SetString is now safe, regardless of whether the caller has the
GIL or not.
2023-09-01 13:30:42 -07:00
Ethan Pronovost
1367f3a6d2 [FRONTEND/OPS] wap stride_vn and stride_vk in flash attention (#2208)
I'm not sure if this was a typo or if I'm missing something. To me code
like
```
(offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
```
seems off.
In case this is a typo I made this PR to correct it. This PR should have
no functional changes.
If this is not a typo would you mind explaining the reasoning behind
these variable names?
2023-08-31 23:19:40 -07:00
Michael Melesse
c6d33dcebf [ROCM] Core Functionality for AMD (#1983)
* this pr adds a third party backend for triton that works on AMD 
* this expose a lot of the work that has been done in our
[fork](https://github.com/ROCmSoftwarePlatform/triton)
* most unit tests on `test_core.py` pass
* it skips some unit tests for various reasons
* we plan to follow up with more prs improving Functionality and
Performance in the future

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-08-31 14:02:00 -07:00
Philippe Tillet
ec51552fff [BACKEND] Lift restriction for float8e4b15 to only support row-col layout (#2212) 2023-08-30 14:06:31 -07:00
jon-chuang
9af76e7d5a [RUNTIME] Fix cache dir (#2196)
---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-08-29 21:07:16 -04:00
goostavz
1465b573e8 [TESTS][HOPPER] Prune hopper tests to speedup CI (#2193)
Co-authored-by: Goostav Zhu <gzhu@nvidia.com>
2023-08-27 20:45:23 -07:00
Philippe Tillet
5f448b2f08 [FRONTEND] remove dead libhopper_helpers.bc file (#2190) 2023-08-26 12:17:17 -07:00
Greg Brockman
a9b8c8c37d [FRONTEND] drop GIL for launch, and set value=false upon pointer error (#2185) 2023-08-26 17:07:57 +00:00
Keren Zhou
6e4932cda8 [BACKEND] Fix fma mixed-precision (#2184)
and expose the allow_tf32 argument to the matmul op

@shunting314
2023-08-26 09:49:58 -07:00
Greg Brockman
ab3e8b0dad [FRONTEND] fix handling of do_not_specialize with interior constantexprs (#2188) 2023-08-26 09:19:34 -07:00
Mohammed Anany
ebfe0ffb29 [FRONTEND] fix for undefined dtypes in jit during loading defaults (#2114)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2023-08-25 10:28:23 -07:00
Ethan Pronovost
56fee37a0d [FRONTEND] Fix benchmark plotting (#2177) 2023-08-24 20:34:04 -07:00
Greg Brockman
64d8df4c69 [FRONTEND] handle errors from launch_enter_hook (#2178) 2023-08-24 20:32:01 -07:00
Shantanu
7083dae4f2 [FRONTEND] drop the GIL around more CUDA ops (#2173) 2023-08-24 20:31:38 -07:00
Zahi Moudallal
120ce0a5bf [DOCS] Fixing docs (#2175) 2023-08-24 15:58:59 -07:00
chengjunlu
6cb67185f8 [FRONTEND]To use proper default num_warps and num_stages based on the device backend in JITFucntion (#2130)
The default values used by JITFunction for num_warps and num_stages are
coupled with Nvidia GPU architecture. We should use the proper default
values based on the device backend for the kernel to be compiled to.
1. Add two functions to return the default num_warps and num_stages for
the specific device backend.
2. JITFunction uses the proper default num_warps and num_stages based on
the specific device backend.

Co-authored-by: Wang Weihan <eikan.wang@intel.com>
2023-08-24 21:58:18 +08:00
Bin Fan
dad83f9dcb [TOOLS] Add support for autotuning AOT kernel (#2123)
This PR makes the following change to AOT kernel

- Allow the client to generate AOT kernels with different sets of
constexprs and meta-parameters. Each combination of constexpr set and
meta-parameters is referred to an "algo". Within an algo client can
still give different hints about integer arguments.
- Add a API int ${kernle_name}_get_num_algos() that returns the total
number of algos.
- Add a algo_id to allow client to the generated kernel to select the
algo
- Remove gX, gY and gZ from the kernel parameter list. This is because
the launch grid is usually different with different algos, and the
client should not need to care about how to compute the launch grid for
each algo. Instead, we ask the client to pass the expression of
computing gX, gY and gZ for compile.py (when AOT kernels are generated).
The expression can only use kernel parameter or const values.
- We also change the testing flow. Now we first build the kernels into a
shared library libkernel.so, then the client test.c code is built and
link with libkernel.so. This is closer to a typical AOT kernel usage
flow.
2023-08-23 09:38:29 -07:00
Zahi Moudallal
5282ed890d [CI] Add back pre-commit to nvidia CI job (#2159) 2023-08-23 01:11:03 +00:00