`arith::SelectOp` supports a form in which the condition is a scalar while the
operands and result are tensors. `tl.where` does not generate this form, but it
can still appear after canonicalization of `scf.if`.
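The two forms of `select` can be modeled in plain Python, with lists standing in for 1-D tensors. This sketch assumes that a scalar-condition select can be lowered by splatting (broadcasting) the condition to the result shape and then doing an ordinary elementwise select. That is one plausible way to support the form, not necessarily what this change does. `select_scalar_cond`, `select_elementwise`, and `splat` are hypothetical helpers, not Triton or MLIR APIs.

```python
# Hypothetical model of the two `select` forms, using plain Python lists as
# stand-ins for 1-D tensors. These helpers are illustrative, not Triton/MLIR API.
from typing import List, Sequence


def select_scalar_cond(cond: bool, a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Scalar condition: pick one whole tensor (the form produced from `scf.if`)."""
    return list(a) if cond else list(b)


def select_elementwise(cond: Sequence[bool], a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Elementwise condition: condition, operands, and result must share one shape."""
    if not (len(cond) == len(a) == len(b)):
        raise ValueError("requires the same shape for all operands and results")
    return [x if c else y for c, x, y in zip(cond, a, b)]


def splat(value: bool, n: int) -> List[bool]:
    """Broadcast a scalar condition to an n-element tensor (assumed lowering step)."""
    return [value] * n


tmp = [1.5, -2.0, 3.25, 0.5]
zeros = [0.0] * len(tmp)

# Splatting the scalar condition makes the elementwise form agree with the scalar form.
for cond in (True, False):
    expected = select_scalar_cond(cond, zeros, tmp)
    lowered = select_elementwise(splat(cond, len(tmp)), zeros, tmp)
    assert lowered == expected
    print(cond, lowered)
```

Under this assumption, the elementwise op never sees a rank-0 condition, so its same-shape requirement is satisfied.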
Currently, when this happens, the conversion to TritonGPU IR fails because
`triton_gpu.select` does not support this form. For example:
```python
import torch
import triton
import triton.language as tl


@triton.jit
def _triton_test(in_ptr, out_ptr, cond, XBLOCK: tl.constexpr):
    xindex = tl.arange(0, XBLOCK)
    tmp = tl.load(in_ptr + xindex)
    # `cond` is not a tl.constexpr, so this branch lowers to `scf.if`, which
    # canonicalization can fold into an `arith.select` with a scalar condition.
    if cond:
        a = tl.zeros_like(tmp)
    else:
        a = tmp
    tl.store(out_ptr + xindex, a)


t = torch.randn(128, device="cuda")
out = torch.empty(128, device="cuda")
_triton_test[(1,)](t, out, True, t.numel())
```
This fails with the error:
```
error: 'triton_gpu.select' op requires the same shape for all operands and results
```
Co-authored-by: Keren Zhou <kerenzhou@openai.com>