The default values used by JITFunction for num_warps and num_stages are
coupled to the Nvidia GPU architecture. We should instead use default
values appropriate to the device backend the kernel is compiled for.
1. Add two functions that return the default num_warps and num_stages for
a specific device backend (see the sketch after this list).
2. Make JITFunction use the proper default num_warps and num_stages for
the specific device backend.
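A minimal sketch of what such backend-specific defaults could look like. The function names and the non-CUDA values are assumptions for illustration, not the PR's exact API; only the historical CUDA defaults (4 warps, 3 stages) are established:
```python
# Hypothetical helpers; names and non-CUDA values are illustrative only.
def get_default_num_warps(backend: str) -> int:
    # 4 warps is the historical CUDA default hard-coded in JITFunction.
    return {"cuda": 4, "hip": 4, "xpu": 8}.get(backend, 4)

def get_default_num_stages(backend: str) -> int:
    # 3 software-pipelining stages is the historical CUDA default.
    return {"cuda": 3, "hip": 2, "xpu": 2}.get(backend, 3)
```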
Co-authored-by: Wang Weihan <eikan.wang@intel.com>
This PR makes the following changes to the AOT kernel flow:
- Allow the client to generate AOT kernels with different sets of
constexprs and meta-parameters. Each combination of a constexpr set and
meta-parameters is referred to as an "algo". Within an algo, the client
can still give different hints about integer arguments.
- Add an API `int ${kernel_name}_get_num_algos()` that returns the total
number of algos.
- Add an algo_id parameter to the generated kernel to allow the client to
select an algo (see the usage sketch after this list).
- Remove gX, gY, and gZ from the kernel parameter list. The launch grid
usually differs between algos, and the client should not need to care
about how to compute the launch grid for each algo. Instead, we ask the
client to pass the expressions that compute gX, gY, and gZ to compile.py
when the AOT kernels are generated. These expressions may only use kernel
parameters or constant values.
- We also change the testing flow. We first build the kernels into a
shared library, libkernel.so; the client test.c code is then built and
linked against libkernel.so. This is closer to a typical AOT kernel usage
flow.
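A hedged usage sketch of the new flow: the kernel name `matmul`, its argument list, and the ctypes types are hypothetical stand-ins, but the `_get_num_algos` entry point and the algo_id parameter follow the description above.
```python
import ctypes

lib = ctypes.CDLL("./libkernel.so")

# Total number of constexpr/meta-parameter combinations ("algos") compiled in.
lib.matmul_get_num_algos.restype = ctypes.c_int
num_algos = lib.matmul_get_num_algos()

# Hypothetical launcher: kernel arguments plus algo_id, but no gX/gY/gZ.
# The grid expressions were handed to compile.py at generation time and are
# evaluated inside the generated launcher from the kernel parameters.
def launch_matmul(stream, c_ptr, a_ptr, b_ptr, M, N, K, algo_id):
    return lib.matmul(
        ctypes.c_void_p(stream),
        ctypes.c_void_p(c_ptr),
        ctypes.c_void_p(a_ptr),
        ctypes.c_void_p(b_ptr),
        ctypes.c_int(M),
        ctypes.c_int(N),
        ctypes.c_int(K),
        ctypes.c_int(algo_id),
    )
```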
1. Optimize the conversion and packing for 2xf32 -> 2xf16.
2. Split TMA store block into multiple slices of size 64x64.
3. Distribute the TMA store to all the warps.
4. Fix some naming issues.
For the warp-specialized persistent kernel, the instruction sequences for
the warp groups are:
```
// warp group 0
for wave in 0..num_waves:
  idx = wave * num_inner_loop_steps;
  for k_tile_idx in 0..num_k_tiles:
    mbarrier.wait EB[idx];
    W0;
    mbarrier.arrive FB[idx];
    idx++;
```
```
// warp group 1
for wave in 0..num_waves:
  idx = wave * num_inner_loop_steps;
  for k_tile_idx in 0..num_k_tiles:
    mbarrier.wait FB[idx];
    R0;
    mbarrier.arrive EB[idx];
    idx++;
```
This forms a sequence of morally-strong relations W0 -> R0 -> W1 -> R1
in causality order.
But if the GEMM K dimension is smaller than the K tile shape, then
num_inner_loop_steps of the persistent kernel is 0, so the buffer id and
mbarrier id will always be 0. This may form the order W0 -> W1 -> R0 ->
R1, which contradicts atomicity:
"If a read R precedes an overlapping write W in causality order, then R
cannot read from W."
Evaluating `if _unwrap_if_constexpr(cond)` and then entering `node.body`
is wrong when `cond` is a tensor, since we cannot statically evaluate a
dynamic tensor's value. A kernel shape that triggers the problem is
sketched below.
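This is an illustrative kernel, not a test case from the PR:
```python
import triton
import triton.language as tl

@triton.jit
def signum(x_ptr, y_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # `x > 0` is a dynamic tensor: unwrapping it as if it were a constexpr
    # and statically entering one branch of the ternary would silently pick
    # the same branch for every element.
    y = 1.0 if x > 0 else -1.0
    tl.store(y_ptr + offs, y)
```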
The right way to solve the problem is probably:
1. Visit the AST of the IfExp (do not build IR).
2. Get the type of the last statement.
3. Initialize the return value and assign it to the livein.
4. Call visit_If.
Add a new operation to make it possible to implement packed inline
assembly for elementwise operations. This way, inline assembly can be
used to control elementwise operations. It also allows packing elements
so that operations can be vectorized manually.
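A sketch of how the op can be driven from the Python side, assuming the `tl.inline_asm_elementwise` builtin that exposes it; the kernel itself is illustrative:
```python
import triton
import triton.language as tl

@triton.jit
def asm_mul(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    y = tl.load(y_ptr + offs)
    # One PTX instruction applied elementwise. pack=1 feeds one element per
    # asm invocation; larger pack values hand the asm a small vector of
    # elements so it can be vectorized manually.
    z = tl.inline_asm_elementwise(
        "mul.f32 $0, $1, $2;", "=r,r,r", [x, y],
        dtype=tl.float32, is_pure=True, pack=1,
    )
    tl.store(out_ptr + offs, z)
```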
Before this PR, whether `TritonGPUToLLVMIRPass` generates NVVM-compatible
or ROCDL-compatible LLVM IR is controlled by a boolean `isROCM`. This
approach is hard to scale.
This PR changes it to use an enum instead, so that new targets can be
added easily when needed.
---------
Signed-off-by: Tsang, Whitney <whitney.tsang@intel.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
When using a Kaggle GPU (https://www.kaggle.com/), I found that
`ldconfig -p` does not list libcuda.so; `ldconfig` must be run (with
sudo) to refresh the cache before libcuda.so can be found.
Therefore, I added this informative message to help users find
libcuda.so.
- These minor fixes are not specific to interface changes from LLVM main
or the official llvm-17 branch and can be applied to the triton main
branch.
- https://github.com/darkbuck/triton/tree/darkbuck/main/llvm-main-branch
has extra changes to build against the LLVM main branch, which enable me
to work on other backends on the main branch only. That's a hobby effort,
just FYR.
`maximum` used to generate a cmp/select even for floating-point types.
Always using the max op allows better code quality and avoids behaving
differently from `tl.math.max`.
libtriton.so is pretty large these days and hashing it is slow.
Switching the hash from md5 to sha1 shaves close to 300ms off the time
for me (as well as being a better hash, for whatever that's worth).
As far as I could tell, sha1 is the fastest stable hash in the Python
standard library, beating even things like zlib.crc32.
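A quick benchmark sketch (not part of the change; the path is a stand-in for your actual libtriton.so) for reproducing the comparison:
```python
import hashlib
import time
import zlib

# Stand-in path: point this at the libtriton.so being hashed.
data = open("libtriton.so", "rb").read()

for name, fn in [
    ("md5", lambda b: hashlib.md5(b).hexdigest()),
    ("sha1", lambda b: hashlib.sha1(b).hexdigest()),
    ("crc32", lambda b: format(zlib.crc32(b), "08x")),
]:
    start = time.perf_counter()
    digest = fn(data)
    elapsed = time.perf_counter() - start
    print(f"{name}: {elapsed * 1000:.1f}ms -> {digest[:12]}")
```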