mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
This adds a pass that tries to reduce the shape of tensor arguments to
element-wise operations by moving splat and broadcast operations later
in the graph. For example, suppose we have:
```python
import triton
import triton.language as tl


@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    # Load a single scalar and broadcast it to a [XBLOCK] tensor
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp2 = 0.017453292519943295
    # Element-wise ops on the broadcast tensor: every element holds the same value
    tmp3 = tmp1 * tmp2
    tmp4 = tl.sin(tmp3)
    tl.store(out_ptr0 + (x0), tmp4, None)
```
Today, this lowers to LLVM IR with duplicate `sin` calls:
```
%27 = llvm.fmul %26, %3 : f32
%28 = llvm.call @__nv_sinf(%27) : (f32) -> f32
%29 = llvm.call @__nv_sinf(%27) : (f32) -> f32
```
The duplicate `llvm.fmul` operations are eliminated by CSE, but the
`llvm.call` operations are not, because the callee might be impure.
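To see why only the `fmul` is deduplicated, here is a toy model in Python. It is a hypothetical sketch, not Triton's lowering or MLIR's CSE pass: it assumes each thread owns two elements of the broadcast tensor (matching the two `sin` calls above) and lowers the element-wise ops one element at a time. Because every element of a splat is the same SSA value, the per-element `fmul`s are identical; a CSE that only merges opcodes it knows to be pure collapses them but must keep each `call`.

```python
from typing import Dict, List, Tuple

# A scalar instruction: (result_name, opcode, operand_names)
Instr = Tuple[str, str, Tuple[str, ...]]

# Opcodes this toy CSE may merge; calls are treated as possibly impure.
PURE_OPCODES = {"fmul"}


def lower_elementwise(elems_per_thread: int) -> List[Instr]:
    """Lower sin(splat(x) * splat(c)) element by element for one thread.

    Every element of a splat holds the same scalar, so each per-element
    fmul reads the same operands %x and %c.
    """
    instrs: List[Instr] = []
    for i in range(elems_per_thread):
        instrs.append((f"%mul{i}", "fmul", ("%x", "%c")))
        instrs.append((f"%sin{i}", "call @__nv_sinf", (f"%mul{i}",)))
    return instrs


def cse(instrs: List[Instr]) -> List[Instr]:
    """Merge repeated pure instructions; leave calls untouched."""
    seen: Dict[Tuple[str, Tuple[str, ...]], str] = {}
    rename: Dict[str, str] = {}
    out: List[Instr] = []
    for name, opcode, operands in instrs:
        # Redirect operands to the surviving copy of any merged instruction.
        operands = tuple(rename.get(o, o) for o in operands)
        key = (opcode, operands)
        if opcode in PURE_OPCODES and key in seen:
            rename[name] = seen[key]  # reuse the earlier identical result
            continue
        if opcode in PURE_OPCODES:
            seen[key] = name
        out.append((name, opcode, operands))
    return out


lowered = cse(lower_elementwise(elems_per_thread=2))
for instr in lowered:
    print(instr)
```

Running it prints one `fmul` followed by two `call`s on the same operand `%mul0`, mirroring the LLVM IR above.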
After this change, `sin` is computed on a scalar value in the Triton IR
and the result is splatted at the very end, so no calculation is
duplicated within a thread.
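The rewrite itself can be sketched on a toy expression tree. This is a minimal illustration under stated assumptions, not the pass's actual implementation: `Node`, `sink_splats`, and `show` are hypothetical names, only `splat` is handled (not general `broadcast`), and the constant `tmp2` is modelled as a splat of a scalar, assuming it is materialized as a uniform tensor constant. An element-wise op whose operands are all splats of the same shape is replaced by a splat of the op applied to the scalars; rewriting operands first lets the splat sink past a whole chain of ops.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical toy IR node; not Triton's actual IR classes.
@dataclass(frozen=True)
class Node:
    op: str                        # "load", "const", "splat", "mul", "sin"
    args: Tuple["Node", ...] = ()
    attr: object = None            # load source, constant value, or splat size


ELEMENTWISE = {"mul", "sin"}


def sink_splats(node: Node) -> Node:
    """Rewrite elementwise(splat(a), splat(b), ...) to splat(elementwise(a, b, ...)).

    Operands are rewritten first (post-order), so a chain of element-wise
    ops is moved onto scalars one op at a time and the splat ends up last.
    """
    args = tuple(sink_splats(a) for a in node.args)
    node = Node(node.op, args, node.attr)
    if (
        node.op in ELEMENTWISE
        and args
        and all(a.op == "splat" for a in args)
        and len({a.attr for a in args}) == 1  # all splats have the same shape
    ):
        scalar = Node(node.op, tuple(a.args[0] for a in args), node.attr)
        return Node("splat", (scalar,), args[0].attr)
    return node


def show(node: Node) -> str:
    """Render a node as a nested call string."""
    if node.op in ("load", "const"):
        return f"{node.op}({node.attr})"
    inner = ", ".join(show(a) for a in node.args)
    suffix = f", [{node.attr}]" if node.op == "splat" else ""
    return f"{node.op}({inner}{suffix})"


# The kernel above with XBLOCK assumed to be 1024.
x = Node("load", attr="in_ptr0")
tmp1 = Node("splat", (x,), 1024)
tmp2 = Node("splat", (Node("const", attr=0.017453292519943295),), 1024)
tmp4 = Node("sin", (Node("mul", (tmp1, tmp2)),))
result = sink_splats(tmp4)
print(show(result))
# splat(sin(mul(load(in_ptr0), const(0.017453292519943295))), [1024])
```

The `sin(tmp1 * tmp2)` chain becomes a single splat of a scalar `sin`. An op with any operand that is not a splat (for example, a loaded vector) is left untouched, since its elements can differ.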
---------
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Philippe Tillet <phil@openai.com>