mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Add tl.expand_dims (#1614)
This exposes `semantic.expand_dims` in the public API and builds upon it
with support for expanding multiple dimensions at once. e.g.
```python
tl.expand_dims(tl.arange(0, N), (0, -1)) # shape = [1, N, 1]
```
Compared to indexing with `None`, this API is useful because the
dimensions can be constexpr values rather than hard-coded into the
source. As a basic example
```python
@triton.jit
def max_keepdim(value, dim):
res = tl.max(value, dim)
return tl.expand_dims(res, dim)
```
This commit is contained in:
@@ -34,6 +34,7 @@ Shape Manipulation Ops
|
||||
:nosignatures:
|
||||
|
||||
broadcast_to
|
||||
expand_dims
|
||||
reshape
|
||||
ravel
|
||||
|
||||
|
||||
Reference in New Issue
Block a user