view assign replaces at buffer identity (#15298)

matches what functions capture
This commit is contained in:
chenyu
2026-03-16 19:58:38 -04:00
committed by GitHub
parent 346596cdce
commit 3e2b7803e6
3 changed files with 22 additions and 18 deletions

View File

@@ -213,6 +213,17 @@ class TestFunction(unittest.TestCase):
np.testing.assert_equal(a.numpy(), [11,21,31]) # TODO: should be [1,2,3]
np.testing.assert_equal(b.numpy(), [10,20,30])
def test_view_assign_explicit_buffer(self):
    """view assign on an explicit param's buffer should not create implicit inputs."""
    # Regression test: assigning into a slice (a view) of a buffer held by an
    # explicitly captured object must be attributed to that object's buffer,
    # so @function(allow_implicit=False) does not reject it as an implicit input.
    class State:
        # NOTE(review): .contiguous().realize() presumably forces a real device
        # buffer so the view assign targets concrete storage — confirm against
        # Tensor semantics; not verifiable from this excerpt.
        def __init__(self): self.buf = Tensor.zeros(2, 4).contiguous().realize()
        @function(allow_implicit=False)
        def __call__(self, x:Tensor) -> Tensor:
            # write x into the first two columns, then return that same view
            self.buf[:, 0:2].assign(x)
            return self.buf[:, 0:2]
    s = State()
    # the returned view must reflect exactly the assigned values
    np.testing.assert_equal(s(Tensor([[5., 6.], [7., 8.]])).numpy(), [[5., 6.], [7., 8.]])
@unittest.expectedFailure
def test_assign_slice(self):
@function

View File

@@ -132,15 +132,9 @@ class TransformerBlock:
q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])
# TODO: fix assign to behave like this
assigned_kv = self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.assign(Tensor.stack(k, v).contiguous().uop))
tensor_assigned_kv = Tensor(assigned_kv, device=assigned_kv.device)
k = tensor_assigned_kv[0, :, :, 0:start_pos+T, :]
v = tensor_assigned_kv[1, :, :, 0:start_pos+T, :]
#self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
#k = self.cache_kv[0, :, :, 0:start_pos+T, :]
#v = self.cache_kv[1, :, :, 0:start_pos+T, :]
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
k = self.cache_kv[0, :, :, 0:start_pos+T, :]
v = self.cache_kv[1, :, :, 0:start_pos+T, :]
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
# TODO: this if statement should be removed and it shouldn't generate extra kernels

View File

@@ -312,16 +312,15 @@ class Tensor(OpMixin):
store_uop = self.uop.store(x.uop)
base = self.uop.base
if base.op in {Ops.BUFFER, Ops.AFTER} and self.uop is not base and not self.uop.has_buffer_identity():
# view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(base, inner) for dependency
original_uop = self.uop
view_after = self.uop.after(store_uop)
assigned_base = base.after(view_after)
_apply_map_to_tensors({base: assigned_base}, name="Embed View Assign", walk=True)
# view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(ib, inner) for dependency
# replace at the buffer-identity level (e.g. RESHAPE(BUFFER)) so @function's substitution catches it
ib = self.uop
while not ib.has_buffer_identity() and ib is not base: ib = ib.src[0]
assigned_ib = ib.after(self.uop.after(store_uop))
_apply_map_to_tensors({ib: assigned_ib}, name="Embed View Assign", walk=True)
def replace_view_base(u:UOp) -> UOp:
return u.replace(src=((assigned_base if u.src[0] is base else replace_view_base(u.src[0])),)+u.src[1:])
ret = Tensor(replace_view_base(original_uop), device=self.device, requires_grad=self.requires_grad)
self.replace(self._apply_uop(lambda *_: replace_view_base(original_uop), x))
return ret
return u.replace(src=((assigned_ib if u.src[0] is ib else replace_view_base(u.src[0])),)+u.src[1:])
return Tensor(replace_view_base(self.uop), device=self.device, requires_grad=self.requires_grad)
# simple assign: AFTER wraps self.uop (may be RESHAPE'd buffer) with STORE effect
return self.replace(self._apply_uop(lambda *_: self.uop.after(store_uop), x))