mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
I've added an option to yapf to do what we want for long lines, see https://github.com/google/yapf/pull/1177. We can now have a real Python formatter, yay! To make this PR, I ran my modified yapf over the repository, then looked over the full diff. Where yapf was mangling the param list of long function decls/calls (mostly kernels), I manually added `#` to put linebreaks where we want. I fixed up other formatting too -- mostly adding or removing a trailing comma from lists. Overall, trailing `#` was sufficient to get formatting similar to our current code. I didn't have to disable yapf anywhere. --------- Co-authored-by: Phil Tillet <phil@openai.com>
19 lines
537 B
Python
19 lines
537 B
Python
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
# triton kernel
|
|
# Copy one BLOCK_M x BLOCK_N tile from X to Z.
#
# X is addressed with a row stride of `stride_xm` and a column stride of 1;
# Z is addressed with a row stride of 1 and a column stride of `stride_zn`.
# No program id is read and no mask is applied, so the kernel always touches
# exactly the tile starting at the base pointers.
@triton.jit
def kernel(X, stride_xm,  #
           Z, stride_zn,  #
           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Index vectors along the two tile dimensions.
    rows = tl.arange(0, BLOCK_M)
    cols = tl.arange(0, BLOCK_N)
    # Broadcast the index vectors into 2-D grids of source/destination pointers.
    src_ptrs = X + rows[:, None] * stride_xm + cols[None, :] * 1
    dst_ptrs = Z + rows[:, None] * 1 + cols[None, :] * stride_zn
    # Read the whole tile, then write it out through the destination pointers.
    tile = tl.load(src_ptrs)
    tl.store(dst_ptrs, tile)
|
|
|
|
|
|
# Compile `kernel` with both block sizes fixed to 64. The signature string has
# one entry per non-constexpr argument (presumably in declaration order:
# X, stride_xm, Z, stride_zn -- confirm against the Triton version in use).
ret = triton.compile(
    kernel,
    signature="*fp32,i32,*fp32,i32",
    constants={"BLOCK_M": 64, "BLOCK_N": 64},
)
# Dump the "ttgir" entry of the compiled artifact's `asm` mapping to stdout.
print(ret.asm["ttgir"])
|