view assign replaces at buffer identity (#15298)
matches what functions capture
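In short: assigning through a view of a realized buffer now rewrites the graph at the buffer-identity node (the buffer up to reshapes) rather than at the raw base buffer, so the replacement lands on the same uop that @function's substitution captures. A minimal sketch of the user-facing pattern, using only the Tensor API exercised by the new test below:

from tinygrad import Tensor

buf = Tensor.zeros(2, 4).contiguous().realize()   # realized base buffer
buf[:, 0:2].assign(Tensor([[5., 6.], [7., 8.]]))  # store through a view (slice) of it
print(buf[:, 0:2].numpy())                        # reads back [[5., 6.], [7., 8.]]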
@@ -213,6 +213,17 @@ class TestFunction(unittest.TestCase):
     np.testing.assert_equal(a.numpy(), [11,21,31]) # TODO: should be [1,2,3]
     np.testing.assert_equal(b.numpy(), [10,20,30])
 
+  def test_view_assign_explicit_buffer(self):
+    """view assign on an explicit param's buffer should not create implicit inputs."""
+    class State:
+      def __init__(self): self.buf = Tensor.zeros(2, 4).contiguous().realize()
+      @function(allow_implicit=False)
+      def __call__(self, x:Tensor) -> Tensor:
+        self.buf[:, 0:2].assign(x)
+        return self.buf[:, 0:2]
+    s = State()
+    np.testing.assert_equal(s(Tensor([[5., 6.], [7., 8.]])).numpy(), [[5., 6.], [7., 8.]])
+
   @unittest.expectedFailure
   def test_assign_slice(self):
     @function
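Per its docstring, the new test checks that a view assign on an explicit param's buffer creates no implicit inputs under @function(allow_implicit=False): if the rewrite landed on the raw base buffer rather than the buffer identity the function captured, the store would surface as an extra, implicit input.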
@@ -132,15 +132,9 @@ class TransformerBlock:
     q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
     k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])
 
-    # TODO: fix assign to behave like this
-    assigned_kv = self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.assign(Tensor.stack(k, v).contiguous().uop))
-    tensor_assigned_kv = Tensor(assigned_kv, device=assigned_kv.device)
-    k = tensor_assigned_kv[0, :, :, 0:start_pos+T, :]
-    v = tensor_assigned_kv[1, :, :, 0:start_pos+T, :]
-
-    #self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
-    #k = self.cache_kv[0, :, :, 0:start_pos+T, :]
-    #v = self.cache_kv[1, :, :, 0:start_pos+T, :]
+    self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
+    k = self.cache_kv[0, :, :, 0:start_pos+T, :]
+    v = self.cache_kv[1, :, :, 0:start_pos+T, :]
 
     # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
     # TODO: this if statement should be removed and it shouldn't generate extra kernels
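The restored lines in TransformerBlock are the plain slice-assign pattern that the removed uop.after plumbing was emulating by hand. A self-contained sketch of the same cache update, with illustrative stand-in shapes (B, H, T_max, D are ours, not the model's real dimensions):

from tinygrad import Tensor

B, H, T_max, D = 1, 2, 8, 4
cache_kv = Tensor.zeros(2, B, H, T_max, D).contiguous().realize()  # stacked k/v planes
start_pos, T = 0, 3
k, v = Tensor.ones(B, H, T, D), Tensor.full((B, H, T, D), 2.0)
# write the new timesteps into the cache through a view of it...
cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
# ...then read the full k/v prefixes back out of the cache
k = cache_kv[0, :, :, 0:start_pos+T, :]
v = cache_kv[1, :, :, 0:start_pos+T, :]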
@@ -312,16 +312,15 @@ class Tensor(OpMixin):
     store_uop = self.uop.store(x.uop)
     base = self.uop.base
     if base.op in {Ops.BUFFER, Ops.AFTER} and self.uop is not base and not self.uop.has_buffer_identity():
-      # view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(base, inner) for dependency
-      original_uop = self.uop
-      view_after = self.uop.after(store_uop)
-      assigned_base = base.after(view_after)
-      _apply_map_to_tensors({base: assigned_base}, name="Embed View Assign", walk=True)
+      # view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(ib, inner) for dependency
+      # replace at the buffer-identity level (e.g. RESHAPE(BUFFER)) so @function's substitution catches it
+      ib = self.uop
+      while not ib.has_buffer_identity() and ib is not base: ib = ib.src[0]
+      assigned_ib = ib.after(self.uop.after(store_uop))
+      _apply_map_to_tensors({ib: assigned_ib}, name="Embed View Assign", walk=True)
       def replace_view_base(u:UOp) -> UOp:
-        return u.replace(src=((assigned_base if u.src[0] is base else replace_view_base(u.src[0])),)+u.src[1:])
-      ret = Tensor(replace_view_base(original_uop), device=self.device, requires_grad=self.requires_grad)
-      self.replace(self._apply_uop(lambda *_: replace_view_base(original_uop), x))
-      return ret
+        return u.replace(src=((assigned_ib if u.src[0] is ib else replace_view_base(u.src[0])),)+u.src[1:])
+      return Tensor(replace_view_base(self.uop), device=self.device, requires_grad=self.requires_grad)
     # simple assign: AFTER wraps self.uop (may be RESHAPE'd buffer) with STORE effect
     return self.replace(self._apply_uop(lambda *_: self.uop.after(store_uop), x))
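The heart of the new path is the walk from the viewed uop down src[0] until it reaches a node with buffer identity (e.g. RESHAPE(BUFFER)) or the base; that node, not the raw base, is what gets swapped for its AFTER-wrapped version. A hedged sketch of that walk in isolation (the helper name is ours; UOp, has_buffer_identity, and src mirror the diff, treated here as pseudocode over tinygrad internals):

def find_replace_target(view_uop, base):
  # peel one movement op (e.g. SHRINK, PERMUTE) per step via src[0] until we hit
  # the buffer identity (the buffer up to reshapes) or fall back to the base
  ib = view_uop
  while not ib.has_buffer_identity() and ib is not base:
    ib = ib.src[0]
  return ib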