mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Merge branch 'lidx0' into process_replay_limit
This commit is contained in:
@@ -619,7 +619,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
def test_grouped_dims(self):
|
||||
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
|
||||
# TODO: fix reverse_dims
|
||||
idxs = get_grouped_dims(prefix, 0, dims, max_sizes)
|
||||
idxs = get_grouped_dims(prefix, dims, max_sizes)
|
||||
loop_idxs = dedup(flatten([[y for y in sorted(list(x.sparents)) if y.op is UOps.SPECIAL] for x in idxs]))
|
||||
sizes = [x.arg[2] for x in loop_idxs]
|
||||
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
|
||||
|
||||
@@ -53,11 +53,11 @@ else:
|
||||
assert uvalid.dtype == dtypes.bool
|
||||
return uidx, uvalid
|
||||
|
||||
def get_grouped_dims(prefix, start_dim, dims, max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
|
||||
def get_grouped_dims(prefix, dims, max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
|
||||
# TODO: this should be per dim max
|
||||
maxdim = len(max_sizes) if max_sizes is not None else 0
|
||||
local_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (),
|
||||
(i, f"{prefix}{start_dim+i}", s)) for i,s in enumerate((prod(dims[:-(maxdim-1)]),) + dims[-(maxdim-1):] if len(dims) > maxdim else dims)]
|
||||
(i, f"{prefix}{i}", s)) for i,s in enumerate((prod(dims[:-(maxdim-1)]),) + dims[-(maxdim-1):] if len(dims) > maxdim else dims)]
|
||||
if maxdim != 0 and len(dims) > maxdim:
|
||||
dd = local_idxs[0]
|
||||
nli = []
|
||||
@@ -85,8 +85,8 @@ class IndependentLowerer:
|
||||
|
||||
if opts.has_local:
|
||||
# define indexes for GPU-like execution
|
||||
self.idxs = get_grouped_dims("gidx", 0, full_shape[:global_dims], opts.global_max) + \
|
||||
get_grouped_dims("lidx", global_dims, full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
|
||||
self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max) + \
|
||||
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
|
||||
else:
|
||||
# all loops are RANGES
|
||||
self.idxs = [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, False))
|
||||
|
||||
Reference in New Issue
Block a user