This exposes `semantic.expand_dims` in the public API and builds on it to support expanding multiple dimensions at once. For example:
```python
tl.expand_dims(tl.arange(0, N), (0, -1)) # shape = [1, N, 1]
```
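The shape arithmetic behind a multi-axis expand can be sketched in plain Python. This is a minimal illustration, not Triton's implementation: it assumes, consistent with the `[1, N, 1]` result above, that each axis indexes into the *result* shape, so `-1` refers to the last output dimension. The helper name `expand_dims_shape` is hypothetical, and rejecting duplicate axes is this sketch's own choice.

```python
def expand_dims_shape(shape, axes):
    """Return `shape` with a length-1 dim inserted at each axis in `axes`.

    Hypothetical helper: axes are interpreted relative to the rank of the
    *result*, so negative values count back from the end of the output.
    """
    if isinstance(axes, int):
        axes = (axes,)
    new_ndim = len(shape) + len(axes)

    # Normalize negative axes against the output rank.
    normalized = []
    for a in axes:
        if not -new_ndim <= a < new_ndim:
            raise ValueError(f"axis {a} out of range for rank {new_ndim}")
        normalized.append(a % new_ndim)
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"duplicate axes after normalization: {normalized}")

    result = list(shape)
    # Insert in ascending order so each index refers to its final position.
    for a in sorted(normalized):
        result.insert(a, 1)
    return tuple(result)


print(expand_dims_shape((8,), (0, -1)))  # (1, 8, 1)
```

Inserting in ascending order is what lets every axis be read as a position in the final shape: each earlier insertion lands strictly before the next target index.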
Unlike indexing with `None`, where the position of each new axis is hard-coded into the source, this API accepts the dimensions as `constexpr` values. As a basic example:
```python
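import triton
import triton.language as tl

@triton.jit
def max_keepdim(value, dim: tl.constexpr):
    # Reduce over `dim`, then re-insert it with length 1 (keeps the rank).
    res = tl.max(value, dim)
    return tl.expand_dims(res, dim)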
```
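For checking shapes and values on the host, `max_keepdim` roughly corresponds to the NumPy sketch below. It is an analogue, not Triton code, and `max_keepdim_ref` is a hypothetical name. It also illustrates the trade-off described above: `x.max(1)[:, None]` works only because the new axis position is written into the source, whereas `np.expand_dims` takes it as an argument.

```python
import numpy as np


def max_keepdim_ref(value, dim):
    """Host-side NumPy reference for the `max_keepdim` helper (a sketch)."""
    res = np.max(value, axis=dim)         # reduce away `dim`
    return np.expand_dims(res, axis=dim)  # re-insert it with length 1


x = np.arange(12).reshape(3, 4)
print(max_keepdim_ref(x, 1).shape)  # (3, 1)
# The None-indexing version has to hard-code where the new axis goes:
print(np.array_equal(max_keepdim_ref(x, 1), x.max(1)[:, None]))  # True
```

Because the reduced axis is re-inserted at the same index of a result with the original rank, the output matches `np.max(..., keepdims=True)` for any valid `dim`, including negative ones.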