[FRONTEND] Add tl.expand_dims (#1614)

This exposes `semantic.expand_dims` in the public API and builds upon it
with support for expanding multiple dimensions at once. e.g.
```python
tl.expand_dims(tl.arange(0, N), (0, -1))  # shape = [1, N, 1]
```

Compared to indexing with `None`, this API is useful because the
dimensions can be constexpr values rather than hard-coded into the
source. As a basic example
```python
@triton.jit
def max_keepdim(value, dim):
    res = tl.max(value, dim)
    return tl.expand_dims(res, dim)
```
This commit is contained in:
peterbell10
2023-05-04 17:46:24 +01:00
committed by GitHub
parent f387a6c863
commit deb2c71fb4
4 changed files with 122 additions and 4 deletions

View File

@@ -34,6 +34,7 @@ Shape Manipulation Ops
:nosignatures:
broadcast_to
expand_dims
reshape
ravel