Merge branch 'lidx0' into process_replay_limit

This commit is contained in:
qazal
2024-07-14 22:32:07 +03:00
2 changed files with 5 additions and 5 deletions

View File

@@ -619,7 +619,7 @@ class TestLinearizer(unittest.TestCase):
def test_grouped_dims(self):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
# TODO: fix reverse_dims
idxs = get_grouped_dims(prefix, 0, dims, max_sizes)
idxs = get_grouped_dims(prefix, dims, max_sizes)
loop_idxs = dedup(flatten([[y for y in sorted(list(x.sparents)) if y.op is UOps.SPECIAL] for x in idxs]))
sizes = [x.arg[2] for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"

View File

@@ -53,11 +53,11 @@ else:
assert uvalid.dtype == dtypes.bool
return uidx, uvalid
def get_grouped_dims(prefix, start_dim, dims, max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
def get_grouped_dims(prefix, dims, max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
# TODO: this should be per dim max
maxdim = len(max_sizes) if max_sizes is not None else 0
local_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (),
(i, f"{prefix}{start_dim+i}", s)) for i,s in enumerate((prod(dims[:-(maxdim-1)]),) + dims[-(maxdim-1):] if len(dims) > maxdim else dims)]
(i, f"{prefix}{i}", s)) for i,s in enumerate((prod(dims[:-(maxdim-1)]),) + dims[-(maxdim-1):] if len(dims) > maxdim else dims)]
if maxdim != 0 and len(dims) > maxdim:
dd = local_idxs[0]
nli = []
@@ -85,8 +85,8 @@ class IndependentLowerer:
if opts.has_local:
# define indexes for GPU-like execution
self.idxs = get_grouped_dims("gidx", 0, full_shape[:global_dims], opts.global_max) + \
get_grouped_dims("lidx", global_dims, full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max) + \
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
self.idxs = [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, False))