render lidx starting with 0 (#5478)

* render lidx starting with 0

changed from
```
  int gidx0 = gid.x; /* 4096 */
  int lidx4 = lid.x; /* 8 */
  int gidx1 = gid.y; /* 7 */
  int lidx5 = lid.y; /* 8 */
  int gidx2 = gid.z; /* 7 */
  int lidx6 = lid.z; /* 2 */
```
to
```
  int gidx0 = gid.x; /* 4096 */
  int lidx0 = lid.x; /* 8 */
  int gidx1 = gid.y; /* 7 */
  int lidx1 = lid.y; /* 8 */
  int gidx2 = gid.z; /* 7 */
  int lidx2 = lid.z; /* 2 */
```

the existing one started from pre-limited global dims which skip number if there are more than 3 global dims

* don't need start_dim

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
chenyu
2024-07-14 16:34:04 -04:00
committed by GitHub
parent 671779f280
commit 613a1dbeed
2 changed files with 5 additions and 5 deletions

View File

@@ -619,7 +619,7 @@ class TestLinearizer(unittest.TestCase):
def test_grouped_dims(self):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
# TODO: fix reverse_dims
idxs = get_grouped_dims(prefix, 0, dims, max_sizes)
idxs = get_grouped_dims(prefix, dims, max_sizes)
loop_idxs = dedup(flatten([[y for y in sorted(list(x.sparents)) if y.op is UOps.SPECIAL] for x in idxs]))
sizes = [x.arg[2] for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"

View File

@@ -53,11 +53,11 @@ else:
assert uvalid.dtype == dtypes.bool
return uidx, uvalid
def get_grouped_dims(prefix, start_dim, dims, max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
def get_grouped_dims(prefix, dims, max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
# TODO: this should be per dim max
maxdim = len(max_sizes) if max_sizes is not None else 0
local_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (),
(i, f"{prefix}{start_dim+i}", s)) for i,s in enumerate((prod(dims[:-(maxdim-1)]),) + dims[-(maxdim-1):] if len(dims) > maxdim else dims)]
(i, f"{prefix}{i}", s)) for i,s in enumerate((prod(dims[:-(maxdim-1)]),) + dims[-(maxdim-1):] if len(dims) > maxdim else dims)]
if maxdim != 0 and len(dims) > maxdim:
dd = local_idxs[0]
nli = []
@@ -85,8 +85,8 @@ class IndependentLowerer:
if opts.has_local:
# define indexes for GPU-like execution
self.idxs = get_grouped_dims("gidx", 0, full_shape[:global_dims], opts.global_max) + \
get_grouped_dims("lidx", global_dims, full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max) + \
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
self.idxs = [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, False))