peterbell10 e3d9478d31 [OPTIMIZER] Add pass to move broadcasts after elementwise operations (#1811)
This adds a pass that tries to reduce the shape of tensor arguments to
element-wise operations by moving splat and broadcast operations later
in the graph. For example, consider this kernel:

```python
import triton
import triton.language as tl

@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp2 = 0.017453292519943295
    tmp3 = tmp1 * tmp2
    tmp4 = tl.sin(tmp3)
    tl.store(out_ptr0 + (x0), tmp4, None)
```

Today this results in duplicate `sin` calls:
```
    %27 = llvm.fmul %26, %3  : f32
    %28 = llvm.call @__nv_sinf(%27) : (f32) -> f32
    %29 = llvm.call @__nv_sinf(%27) : (f32) -> f32
```

The duplicate `llvm.fmul` calls are eliminated via CSE, but `llvm.call`
doesn't get CSE'd because it might be impure.
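
To illustrate why the multiply is deduplicated while the call is not, here is a
toy CSE over a flat instruction list. It is only a sketch: the tuple encoding,
the `PURE_OPS` set, and the `%a`..`%d` value names are hypothetical and do not
reflect how MLIR's CSE is implemented.

```python
# Toy local CSE over (result, op, operands) tuples.
# Hypothetical model for illustration only; not MLIR's actual CSE pass.
PURE_OPS = {"llvm.fmul"}  # assumption: ops known to be free of side effects


def cse(instrs):
    seen = {}     # (op, operands) -> first result that computed it
    replace = {}  # eliminated result -> surviving result
    out = []
    for res, op, operands in instrs:
        # Rewrite operands that referred to an eliminated duplicate.
        operands = tuple(replace.get(o, o) for o in operands)
        key = (op, operands)
        if op in PURE_OPS and key in seen:
            replace[res] = seen[key]  # drop the duplicate pure op
            continue
        if op in PURE_OPS:
            seen[key] = res
        out.append((res, op, operands))  # possibly impure ops are always kept
    return out


program = [
    ("%a", "llvm.fmul", ("%26", "%3")),
    ("%b", "llvm.fmul", ("%26", "%3")),
    ("%c", "llvm.call @__nv_sinf", ("%a",)),
    ("%d", "llvm.call @__nv_sinf", ("%b",)),
]
optimized = cse(program)
```

In this model, `optimized` keeps one `llvm.fmul` and both calls, with both calls
reading the same operand, which mirrors the IR shown above.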

After this change, `sin` is computed on a scalar value in the Triton IR
and splatted at the very end, so no duplicate computation happens within
a thread.
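
The effect of the rewrite can be sketched in plain Python. This is a toy model,
not the pass itself: `broadcast`, `before`, `after`, and the call counter are
hypothetical names, and a Python list stands in for the per-thread elements of
a tensor.

```python
import math

DEG_TO_RAD = 0.017453292519943295  # the constant from the kernel above


def broadcast(value, n):
    # Stand-in for tl.broadcast_to / splat: n copies of one scalar.
    return [value] * n


def before(x, n, counter):
    # Broadcast first, then apply the elementwise ops to every element.
    out = []
    for v in broadcast(x, n):
        counter[0] += 1  # one sin call per element
        out.append(math.sin(v * DEG_TO_RAD))
    return out


def after(x, n, counter):
    # Apply the elementwise ops to the scalar, then broadcast the result.
    counter[0] += 1  # a single sin call
    return broadcast(math.sin(x * DEG_TO_RAD), n)


calls_before, calls_after = [0], [0]
result_before = before(90.0, 4, calls_before)
result_after = after(90.0, 4, calls_after)
```

Both versions produce the same values, but the second evaluates `sin` once
instead of once per element.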

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
2023-07-10 11:44:38 -07:00