Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
add failing assign test (#3796)
* that was a hack
* tests to reveal the issue
* add assert for realized assign
@@ -84,6 +84,29 @@ class TestAssign(unittest.TestCase):
     for _ in range(4): f(y)
     assert y.item() == 4
 
+  def test_assign_changes(self):
+    a = Tensor.ones(4).contiguous().realize()
+    old_a = a
+    a.assign(Tensor.full((4,), 2.).contiguous())
+    # NOTE: old_a is now 2, and this would match the behavior of pytorch
+    new = a + old_a
+    np.testing.assert_allclose(new.numpy(), 4)
+
+  @unittest.expectedFailure
+  def test_assign_diamond(self):
+    a = Tensor.ones(4).contiguous().realize()
+    times_a = a*3
+    a.assign(Tensor.full((4,), 2.).contiguous())
+    new = a + times_a
+    np.testing.assert_allclose(new.numpy(), 5)
+
+  def test_assign_diamond_alt(self):
+    a = Tensor.ones(4).contiguous().realize()
+    a.assign(Tensor.full((4,), 2.).contiguous())
+    times_a = a*3
+    new = a + times_a
+    np.testing.assert_allclose(new.numpy(), 8)
+
   def test_assign_kv_cache(self):
     bsz, max_context = 2, 8
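The two diamond tests are what the commit title calls failing: in test_assign_diamond, times_a is created from a before the assign, so forcing it afterwards should still see the old contents (new = 2 + 1*3 = 5), but the scheduler can order the read after the in-place write, hence @unittest.expectedFailure. In test_assign_diamond_alt the read is issued after the assign, so 2 + 2*3 = 8 is the legitimate result and that test passes. A minimal standalone repro, assuming a tinygrad checkout at this revision:

  import numpy as np
  from tinygrad.tensor import Tensor

  a = Tensor.ones(4).contiguous().realize()
  times_a = a * 3                               # lazy read of a, created before the assign
  a.assign(Tensor.full((4,), 2.).contiguous())  # in-place overwrite of a's buffer
  new = a + times_a                             # should see the old a: 2 + 1*3 = 5
  np.testing.assert_allclose(new.numpy(), 5)    # expected to fail at this commit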
@@ -59,10 +59,7 @@ class LazyBuffer:
     shape = self.shape if shape is None else shape
     return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=cast_scalar(val, self.dtype)).reshape((1,)*len(shape)).expand(shape)
 
-  def assign(self, x:LazyBuffer) -> LazyBuffer:
-    if self.base.realized is not None or self is not self.base: new_self = self
-    else: new_self = create_lazybuffer(self.device, self.st, self.dtype, self.op, self.arg, self.srcs, enable_cache=False)
-    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, new_self))
+  def assign(self, x:LazyBuffer) -> LazyBuffer: return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self))
   def contiguous(self):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
       ret = self.e(LoadOps.CONTIGUOUS)
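The removed branch is the hack named in the commit message: for an unrealized base buffer, assign rebuilt an uncached copy of self (create_lazybuffer with enable_cache=False) so the ASSIGN targeted a snapshot of the pre-assign graph rather than the live node. The replacement one-liner always hangs LoadOps.ASSIGN off self directly and leaves ordering to the scheduler. As a loose analogy for the hazard this exposes (plain numpy, not tinygrad's machinery):

  import numpy as np

  buf = np.ones(4)            # stands in for a realized device buffer
  deferred = lambda: buf * 3  # a deferred read, like the lazy a*3 above
  buf[:] = 2.                 # the in-place assign overwrites the storage
  print(deferred())           # [6. 6. 6. 6.] -- the late read sees the new value: the diamond hazard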
@@ -109,6 +109,7 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Var
   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
   if buf.op in {LoadOps.CONTIGUOUS, LoadOps.ASSIGN}:
     assert first
+    assert buf.op is not LoadOps.ASSIGN or buf.srcs[1].base.realized is not None, "assign must be already realized to schedule"
     return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False,
                              assign_to=buf.srcs[1].base if buf.op is LoadOps.ASSIGN else None)
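The added assert makes the scheduler's assumption explicit: the kernel emitted for an ASSIGN writes into the target's existing device buffer, so the target must already be realized when the graph is lowered; an unrealized target now fails loudly with "assign must be already realized to schedule" instead of silently mis-scheduling. The new tests all follow the required pattern of realizing the target first, e.g.:

  from tinygrad.tensor import Tensor

  a = Tensor.ones(4).contiguous().realize()     # target realized first: it owns a buffer to write into
  a.assign(Tensor.full((4,), 2.).contiguous())  # the ASSIGN reuses that buffer in place
  print(a.numpy())                              # [2. 2. 2. 2.]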