Revert "one run_schedule for assign realize (#14835)" (#14837)

This reverts commit df7c37f611.
This commit is contained in:
chenyu
2026-02-17 14:34:26 -05:00
committed by GitHub
parent df7c37f611
commit aec8a6c85b
2 changed files with 6 additions and 17 deletions

View File

@@ -745,7 +745,6 @@ class TestAssignOrdering(unittest.TestCase):
np.testing.assert_equal(a.numpy(), [5, 6, 7, 8])
np.testing.assert_equal(b.numpy(), [1, 2, 3, 4])
@unittest.expectedFailure # NOTE: we don't support binding to two different values in one schedule
def test_variable_slice_ordering(self):
"""Variable-indexed slices - tests symbolic dependency tracking."""
v_i = Variable("i", 0, 3)

View File

@@ -271,34 +271,24 @@ class Tensor(OpMixin):
@disable_gc()
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
"""Triggers the computation needed to create these Tensor(s)."""
schedules: list[tuple[list[ExecItem], dict[str, int]]] = []
# collect pending assigns for relevant buffers
# side-realize pending assigns for buffers referenced by these tensors
if _pending_assigns:
def _collect_pending(buf:UOp):
def _realize_pending(buf):
for assign_uop in _pending_assigns.pop(buf, []):
# recursively realize pending assigns that this assign's value depends on
for u in assign_uop.toposort():
if u.op is Ops.BUFFER and u in _pending_assigns: _collect_pending(u)
if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
schedules.append((schedule, var_vals))
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
# update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
if becomes_map:
for assigns in _pending_assigns.values():
for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map)
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
if buf in _pending_assigns: _collect_pending(buf)
if buf in _pending_assigns: _realize_pending(buf)
if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
schedules.append(Tensor.schedule_with_vars(*to_realize))
exec_items: list[ExecItem] = []
merged_var_vals: dict[str, int] = {}
for schedule, var_vals in schedules:
exec_items.extend(schedule)
merged_var_vals = merge_dicts((merged_var_vals, var_vals))
run_schedule(exec_items, merged_var_vals, do_update_stats=do_update_stats)
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
return self
def replace(self, x:Tensor) -> Tensor: