mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
This reverts commit df7c37f611.
This commit is contained in:
@@ -745,7 +745,6 @@ class TestAssignOrdering(unittest.TestCase):
|
||||
np.testing.assert_equal(a.numpy(), [5, 6, 7, 8])
|
||||
np.testing.assert_equal(b.numpy(), [1, 2, 3, 4])
|
||||
|
||||
@unittest.expectedFailure # NOTE: we don't support binding to two different values in one schedule
|
||||
def test_variable_slice_ordering(self):
|
||||
"""Variable-indexed slices - tests symbolic dependency tracking."""
|
||||
v_i = Variable("i", 0, 3)
|
||||
|
||||
@@ -271,34 +271,24 @@ class Tensor(OpMixin):
|
||||
@disable_gc()
|
||||
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
|
||||
"""Triggers the computation needed to create these Tensor(s)."""
|
||||
schedules: list[tuple[list[ExecItem], dict[str, int]]] = []
|
||||
# collect pending assigns for relevant buffers
|
||||
# side-realize pending assigns for buffers referenced by these tensors
|
||||
if _pending_assigns:
|
||||
def _collect_pending(buf:UOp):
|
||||
def _realize_pending(buf):
|
||||
for assign_uop in _pending_assigns.pop(buf, []):
|
||||
# recursively realize pending assigns that this assign's value depends on
|
||||
for u in assign_uop.toposort():
|
||||
if u.op is Ops.BUFFER and u in _pending_assigns: _collect_pending(u)
|
||||
if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
|
||||
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
|
||||
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
|
||||
schedules.append((schedule, var_vals))
|
||||
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
|
||||
# update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
|
||||
if becomes_map:
|
||||
for assigns in _pending_assigns.values():
|
||||
for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map)
|
||||
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
|
||||
if buf in _pending_assigns: _collect_pending(buf)
|
||||
|
||||
if buf in _pending_assigns: _realize_pending(buf)
|
||||
if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
|
||||
schedules.append(Tensor.schedule_with_vars(*to_realize))
|
||||
|
||||
exec_items: list[ExecItem] = []
|
||||
merged_var_vals: dict[str, int] = {}
|
||||
for schedule, var_vals in schedules:
|
||||
exec_items.extend(schedule)
|
||||
merged_var_vals = merge_dicts((merged_var_vals, var_vals))
|
||||
|
||||
run_schedule(exec_items, merged_var_vals, do_update_stats=do_update_stats)
|
||||
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
|
||||
return self
|
||||
|
||||
def replace(self, x:Tensor) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user