diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py
index d8dca74602..eb1f97a6cf 100644
--- a/test/unit/test_assign.py
+++ b/test/unit/test_assign.py
@@ -745,7 +745,6 @@ class TestAssignOrdering(unittest.TestCase):
     np.testing.assert_equal(a.numpy(), [5, 6, 7, 8])
     np.testing.assert_equal(b.numpy(), [1, 2, 3, 4])
 
-  @unittest.expectedFailure # NOTE: we don't support binding to two different values in one schedule
   def test_variable_slice_ordering(self):
     """Variable-indexed slices - tests symbolic dependency tracking."""
     v_i = Variable("i", 0, 3)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index ac927b9a85..f52228a0cd 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -271,34 +271,24 @@ class Tensor(OpMixin):
   @disable_gc()
   def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
     """Triggers the computation needed to create these Tensor(s)."""
-    schedules: list[tuple[list[ExecItem], dict[str, int]]] = []
-    # collect pending assigns for relevant buffers
+    # side-realize pending assigns for buffers referenced by these tensors
     if _pending_assigns:
-      def _collect_pending(buf:UOp):
+      def _realize_pending(buf):
         for assign_uop in _pending_assigns.pop(buf, []):
           # recursively realize pending assigns that this assign's value depends on
           for u in assign_uop.toposort():
-            if u.op is Ops.BUFFER and u in _pending_assigns: _collect_pending(u)
+            if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
           becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
           _apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
-          schedules.append((schedule, var_vals))
+          run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
           # update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
           if becomes_map:
             for assigns in _pending_assigns.values():
               for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map)
       for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
-        if buf in _pending_assigns: _collect_pending(buf)
-
+        if buf in _pending_assigns: _realize_pending(buf)
     if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
-      schedules.append(Tensor.schedule_with_vars(*to_realize))
-
-    exec_items: list[ExecItem] = []
-    merged_var_vals: dict[str, int] = {}
-    for schedule, var_vals in schedules:
-      exec_items.extend(schedule)
-      merged_var_vals = merge_dicts((merged_var_vals, var_vals))
-
-    run_schedule(exec_items, merged_var_vals, do_update_stats=do_update_stats)
+      run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
     return self
 
   def replace(self, x:Tensor) -> Tensor: