render lidx starting with 0 (#5478)

* render lidx starting with 0

changed from
```
  int gidx0 = gid.x; /* 4096 */
  int lidx4 = lid.x; /* 8 */
  int gidx1 = gid.y; /* 7 */
  int lidx5 = lid.y; /* 8 */
  int gidx2 = gid.z; /* 7 */
  int lidx6 = lid.z; /* 2 */
```
to
```
  int gidx0 = gid.x; /* 4096 */
  int lidx0 = lid.x; /* 8 */
  int gidx1 = gid.y; /* 7 */
  int lidx1 = lid.y; /* 8 */
  int gidx2 = gid.z; /* 7 */
  int lidx2 = lid.z; /* 2 */
```

the existing one started from pre-limited global dims which skip number if there are more than 3 global dims

* don't need start_dim

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
chenyu
2024-07-14 16:34:04 -04:00
committed by GitHub
parent 671779f280
commit 613a1dbeed
2 changed files with 5 additions and 5 deletions

View File

@@ -619,7 +619,7 @@ class TestLinearizer(unittest.TestCase):
def test_grouped_dims(self):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
# TODO: fix reverse_dims
idxs = get_grouped_dims(prefix, 0, dims, max_sizes)
idxs = get_grouped_dims(prefix, dims, max_sizes)
loop_idxs = dedup(flatten([[y for y in sorted(list(x.sparents)) if y.op is UOps.SPECIAL] for x in idxs]))
sizes = [x.arg[2] for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"