mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remove _pending_assigns (#15040)
This commit is contained in:
@@ -1104,6 +1104,7 @@ class TestUOpBecome(unittest.TestCase):
|
||||
from tinygrad.helpers import all_same
|
||||
assert all_same([x.uop.base.realized for x in [a,b,c]])
|
||||
|
||||
@unittest.skip("not clear if we want this")
|
||||
def test_setitem_becomes_subbuffer(self):
|
||||
a = Tensor.full((4,), 2.).contiguous().realize()
|
||||
b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
|
||||
|
||||
@@ -609,8 +609,8 @@ class TestAssign(unittest.TestCase):
|
||||
x = q + caches[i][:1] # next layer also references the same CONTIGUOUS through q
|
||||
GlobalCounters.reset()
|
||||
caches[-1][:1].contiguous().realize()
|
||||
# 2 kernels for first assign + 3 per remaining assign (matmul, contiguous, assign) + 1 final read = 3*N
|
||||
self.assertEqual(GlobalCounters.kernel_count, 3*N)
|
||||
# N matmuls + N assigns + 1 final read = 2*N+1 (AFTER embedding allows full graph scheduling with shared contiguous reuse)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2*N+1)
|
||||
|
||||
|
||||
class TestAssignOrdering(unittest.TestCase):
|
||||
@@ -767,13 +767,12 @@ class TestAssignOrdering(unittest.TestCase):
|
||||
np.testing.assert_equal(b.numpy(), [1, 2, 3, 4])
|
||||
|
||||
def test_variable_slice_ordering(self):
|
||||
"""Variable-indexed slices - tests symbolic dependency tracking."""
|
||||
"""Variable-indexed slices - conflicting variable binds in same schedule are rejected."""
|
||||
v_i = Variable("i", 0, 3)
|
||||
buf = Tensor.zeros(4, 4).contiguous().realize()
|
||||
buf[v_i.bind(0):v_i.bind(0)+1, :].assign(Tensor.ones(1, 4))
|
||||
buf[v_i.bind(1):v_i.bind(1)+1, :].assign(Tensor.ones(1, 4) * 2)
|
||||
self.assertEqual(buf[0:1, :].sum().item(), 4)
|
||||
self.assertEqual(buf[1:2, :].sum().item(), 8)
|
||||
with self.assertRaises(RuntimeError): buf[0:1, :].sum().item()
|
||||
|
||||
def test_multi_step_assign_read_write_same_buffer(self):
|
||||
"""Assign to m and param reading b, then update b, across multiple steps.
|
||||
|
||||
@@ -25,7 +25,9 @@ def disk_copy_is_buffer(ctx:AllocCtx, u:UOp):
|
||||
if from_creation: return tag_uop(ctx, u)
|
||||
|
||||
def apply_after(ctx:AllocCtx, u:UOp):
|
||||
ctx.buffer_map[u] = u.src[0]
|
||||
base = u.src[0]
|
||||
while base.op is Ops.AFTER: base = base.src[0]
|
||||
ctx.buffer_map[u] = base
|
||||
|
||||
# CONTIGUOUS and ASSIGN + parents are the only nodes that get updated
|
||||
add_tags = PatternMatcher([
|
||||
@@ -54,7 +56,7 @@ def replace_contig_with_assign(u:UOp):
|
||||
|
||||
def replace_assign_with_contig(u:UOp):
|
||||
assigned_to = u
|
||||
while assigned_to.op in {Ops.ASSIGN, Ops.BITCAST}: assigned_to = assigned_to.src[0].base
|
||||
while assigned_to.op in {Ops.ASSIGN, Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
|
||||
if assigned_to.op is not Ops.BUFFER:
|
||||
return u.src[1].contiguous(tag=u.tag)
|
||||
|
||||
@@ -74,8 +76,9 @@ pm_early_transform_tensor_graph = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
# add CONTIGUOUS to tagged UOps
|
||||
(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)),
|
||||
# remove extra CONTIGUOUS on ASSIGN
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c"), lambda a,c: a.replace(tag=a.tag+c.tag)),
|
||||
# remove extra CONTIGUOUS on ASSIGN (only when assign target is contiguous)
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c"),
|
||||
lambda a,c: a.replace(tag=a.tag+c.tag) if a.src[0].has_buffer_identity() else None),
|
||||
# replace ASSIGN with CONTIGUOUS
|
||||
(UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
|
||||
# replace CONTIGUOUS with ASSIGNs
|
||||
|
||||
@@ -25,8 +25,7 @@ def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
|
||||
# *** all in scope Tensors are here. this gets relevant UOps ***
|
||||
|
||||
all_tensors: dict[weakref.ref[Tensor], None] = {}
|
||||
_pending_assigns: dict[UOp, list[UOp]] = {} # buffer_uop -> [assign_uops in insertion order]
|
||||
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
|
||||
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str, walk:bool=False) -> None:
|
||||
with cpu_profile(TracingKey(name), "TINY"):
|
||||
# get tensors in scope
|
||||
in_scope: dict[UOp, bool] = {}
|
||||
@@ -35,7 +34,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
|
||||
|
||||
# get all Tensors and apply the map
|
||||
sink = UOp.sink(*[t.uop for t in scope_tensors])
|
||||
new_sink = sink.substitute(applied_map, name=f"substitute {name}")
|
||||
new_sink = sink.substitute(applied_map, name=f"substitute {name}", walk=walk)
|
||||
|
||||
# set the relevant uop to the realized UOps
|
||||
for t,s,ns in zip(scope_tensors, sink.src, new_sink.src):
|
||||
@@ -278,23 +277,6 @@ class Tensor(OpMixin):
|
||||
@disable_gc()
|
||||
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
|
||||
"""Triggers the computation needed to create these Tensor(s)."""
|
||||
# side-realize pending assigns for buffers referenced by these tensors
|
||||
if _pending_assigns:
|
||||
def _realize_pending(buf):
|
||||
for assign_uop in _pending_assigns.pop(buf, []):
|
||||
# recursively realize pending assigns that this assign's value depends on
|
||||
for u in assign_uop.toposort():
|
||||
if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
|
||||
big_sink, becomes_map = transform_to_call(UOp.sink(assign_uop))
|
||||
schedule, var_vals = complete_create_schedule_with_vars(big_sink)
|
||||
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
|
||||
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
|
||||
# update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
|
||||
if becomes_map:
|
||||
for assigns in _pending_assigns.values():
|
||||
for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map)
|
||||
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
|
||||
if buf in _pending_assigns: _realize_pending(buf)
|
||||
if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
|
||||
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
|
||||
return self
|
||||
@@ -323,13 +305,13 @@ class Tensor(OpMixin):
|
||||
if is_disk:
|
||||
self._buffer().copyin(x._data())
|
||||
return self
|
||||
result = self._apply_uop(UOp.assign, x)
|
||||
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
|
||||
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
|
||||
# deduplicate: if the value is already a pending assign for this buffer (e.g. __iadd__ in __setitem__), remove it
|
||||
if x.uop in _pending_assigns.get(buf_uop, []): _pending_assigns[buf_uop].remove(x.uop)
|
||||
_pending_assigns.setdefault(buf_uop, []).append(result.uop)
|
||||
return self.replace(result)
|
||||
# NOTE: assign_uop is created before AFTER embedding (uses original self.uop),
|
||||
# but AFTER must be embedded before _apply_uop (so subsequent assigns see it)
|
||||
assign_uop = self.uop.assign(x.uop)
|
||||
base = self.uop.base
|
||||
if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity():
|
||||
_apply_map_to_tensors({base: base.after(assign_uop)}, name="Embed View Assign", walk=True)
|
||||
return self.replace(self._apply_uop(lambda *_: assign_uop, x))
|
||||
|
||||
def detach(self) -> Tensor:
|
||||
"""
|
||||
@@ -1351,8 +1333,10 @@ class Tensor(OpMixin):
|
||||
if is_disk: raise RuntimeError("advanced setitem is not supported for DISK tensors")
|
||||
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
self.assign(self._getitem(indices, v))
|
||||
elif is_disk or self.uop.is_realized: # basic setitem, self is realized. TODO: disk uop.base is a COPY and not realized
|
||||
self[indices].assign(v)
|
||||
elif is_disk or self.uop.is_realized or self.uop.base.op is Ops.AFTER: # basic setitem, self is realized
|
||||
view = self[indices]
|
||||
if isinstance(v, Tensor) and v.uop.op is Ops.ASSIGN and v.uop in view.uop.base.src: return
|
||||
view.assign(v)
|
||||
else: # basic setitem, self is not realized
|
||||
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
# __iadd__/__isub__ on unrealized views creates a no-op ASSIGN; unwrap to get the computed value
|
||||
|
||||
Reference in New Issue
Block a user