Files
tinygrad/test/unit/test_assign.py
chenyu 0c63f63ee4 recursive resolve assign dependency (#14688)
remove the .realize in llm.py
2026-02-11 17:41:05 -05:00

769 lines
30 KiB
Python

#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import temp, CI, CPU_LVP, Context
N: int = 200  # side length of the NxN matrices used below; original note: "has to be bigger than the cache to fail"
class TestAssign(unittest.TestCase):
  """Tests for Tensor.assign and in-place ops such as `+=`.

  Covers: buffer reuse when assigning into a realized tensor, TinyJit replay of assigns, detection of
  read/assign cycles ("diamond" graphs), assigns through views (permute, flip, reshape, shrink, bitcast),
  dtype-mismatch and shape-broadcast checks, disk-backed assigns, and the number of kernels some assign
  patterns are expected to produce (checked via GlobalCounters.kernel_count).
  """
  def test_simple_assignment(self):
    # `a += b` on a realized tensor must write into a's existing buffer (ba1 == ba2), not into b's
    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    ba1 = a.uop.base.realized
    bb1 = b.uop.base.realized
    a += b
    a.realize()
    ba2 = a.uop.base.realized
    assert ba1 == ba2 and ba1 != bb1
    np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))
  # the next two tests check that assigning into `a` leaves an independently created zeros tensor `b` untouched;
  # they differ only in whether `b` is created after or before the assign
  def test_assign_zeros_good(self):
    a = Tensor.zeros(10,10).contiguous()
    a.assign(Tensor.ones(10,10))
    b = Tensor.zeros(10,10).contiguous()
    a.realize()
    np.testing.assert_allclose(b.numpy(), 0)
  def test_assign_zeros(self):
    a = Tensor.zeros(10,10).contiguous()
    b = Tensor.zeros(10,10).contiguous()
    a.assign(Tensor.ones(10,10))
    a.realize()
    np.testing.assert_allclose(b.numpy(), 0)
  def test_assign_add(self):
    # `x += 1` inside a helper must update the caller's tensor in place
    def f(x):
      x += 1
      x.realize()
    x = Tensor([0])
    f(x)
    assert x.item() == 1
  def test_assign_add_twice(self):
    # NOTE: this has two kernels
    def f(x):
      x += 1
      x += 1
      x.realize()
    x = Tensor([0])
    f(x)
    assert x.item() == 2
  def test_assign_add_double(self):
    # running the same in-place helper on two fresh tensors must give 1 both times
    def f(x):
      x += 1
      x.realize()
    x = Tensor([0])
    f(x)
    out = x.item()
    assert out == 1, f"expected 1, got {out}"
    x = Tensor([0])
    f(x)
    out = x.item()
    assert out == 1, f"expected 1, got {out}"
  def test_assign_add_jit(self):
    # every call of the jitted in-place add must update x: 5 calls -> 5
    @TinyJit
    def f(x):
      x += 1
      x.realize()
    x = Tensor([0])
    for _ in range(5): f(x)
    assert x.item() == 5
  def test_assign_add_jit_other(self):
    # after being called with x, the same jitted function must also update a different tensor y
    @TinyJit
    def f(x):
      x += 1
      x.realize()
    x = Tensor([0])
    for _ in range(5): f(x)
    assert x.item() == 5
    y = Tensor([0])
    for _ in range(4): f(y)
    assert y.item() == 4
  def test_assign_other_jit(self):
    # jitted assign of a fresh input on each call: x must hold the latest value i after every call
    @TinyJit
    def f(x, a):
      x.assign(a)
      x.realize()
    x = Tensor([0])
    for i in range(1, 6):
      f(x, x.full_like(i).contiguous()) # const would be implicitly folded without contiguous
      assert x.item() == i
  def test_assign_add_other_jit(self):
    # jitted `x += a` with a fresh input on each call: x must equal the running sum 1+2+...+i
    @TinyJit
    def f(x, a):
      x += a
      x.realize()
    x = Tensor([0])
    a = 0
    for i in range(1, 6):
      a += i
      f(x, x.full_like(i).contiguous())
      assert x.item() == a
  def test_assign_changes(self):
    a = Tensor.ones(4).contiguous().realize()
    old_a = a
    a.assign(Tensor.full((4,), 2.).contiguous())
    # NOTE: old_a is now 2, and this would match the behavior of pytorch
    new = a + old_a
    np.testing.assert_allclose(new.numpy(), 4)
  def test_assign_changes_alt(self, realize=False):
    # `realize` lets test_assign_changes_realized_alt reuse this body with `a` realized up front;
    # with the default (unrealized a), assigning into the contiguous() child b must not change a
    a = Tensor(1).contiguous()
    if realize: a.realize()
    b = a.contiguous() # b returns a new Tensor
    b.assign(2)
    b.realize()
    self.assertNotEqual(a.item(), b.item())
  # on a realized Tensor contiguous child changes the source
  @unittest.expectedFailure
  def test_assign_changes_realized_alt(self): return self.test_assign_changes_alt(realize=True)
  @unittest.skip("assign to contiguous shouldn't change the base buffer")
  def test_assign_changes_buffer_alt(self):
    a, b = [Tensor(Tensor(0).contiguous().realize().uop.buf_uop) for _ in range(2)]
    Tensor.realize(a.contiguous().assign(1), b.contiguous().assign(2))
    self.assertEqual((a + b).item(), 3)
  def test_assign_diamond_cycle(self):
    # times_a is a lazy read of a's pre-assign value, consumed only after a is assigned;
    # this graph is expected to be rejected with a RuntimeError whose message mentions "cycle"
    # NOTE: should *not* raise AssertionError from numpy
    with self.assertRaisesRegex(RuntimeError, "cycle"):
      a = Tensor.ones(4).contiguous().realize()
      times_a = a*3
      a.assign(Tensor.full((4,), 2.).contiguous())
      new = a + (times_a-1)
      np.testing.assert_allclose(new.numpy(), 4)
  def test_assign_diamond_contiguous_cycle(self):
    # same pattern, reading the assigned a through contiguous() while times_a stays lazy; still expected to raise
    with self.assertRaisesRegex(RuntimeError, "cycle"):
      a = Tensor.ones(4).contiguous().realize()
      times_a = a*3
      a.assign(Tensor.full((4,), 2.))
      new = a.contiguous() + times_a-1
      np.testing.assert_allclose(new.numpy(), 4)
  # the next three variants make the old-value read contiguous, which is accepted: new = 2 + (1*3 - 1) = 4
  def test_assign_diamond_possible(self):
    a = Tensor.ones(4).contiguous().realize()
    times_a = a*3
    a.assign(Tensor.full((4,), 2.))
    new = a + (times_a-1).contiguous()
    np.testing.assert_allclose(new.numpy(), 4)
  def test_assign_diamond_possible_contiguous(self):
    a = Tensor.ones(4).contiguous().realize()
    times_a = a*3
    a.assign(Tensor.full((4,), 2.).contiguous())
    new = a + (times_a-1).contiguous()
    np.testing.assert_allclose(new.numpy(), 4)
  def test_assign_diamond_both_contiguous(self):
    a = Tensor.ones(4).contiguous().realize()
    times_a = a*3
    a.assign(Tensor.full((4,), 2.))
    new = a.contiguous() + (times_a-1).contiguous()
    np.testing.assert_allclose(new.numpy(), 4)
  def test_assign_diamond_alt(self):
    # times_a is built after the assign, so it sees the new value: 2 + 2*3 = 8
    a = Tensor.ones(4).contiguous().realize()
    a.assign(Tensor.full((4,), 2.).contiguous())
    times_a = a*3
    new = a + times_a
    np.testing.assert_allclose(new.numpy(), 8)
  @unittest.skipIf(CI and CPU_LVP, "flaky in CI")
  def test_double_assign(self):
    # two chained in-place adds on a realized tensor: 1 + 1 + 1 = 3
    a = Tensor.ones(4).contiguous().realize()
    a += 1
    a += 1
    np.testing.assert_allclose(a.numpy(), 3)
  def test_crossover_assign(self):
    # `b += a` must see the updated a: a = 2 + 3 = 5, then b = 3 + 5 = 8
    a = Tensor.full((4,), 2).contiguous().realize()
    b = Tensor.full((4,), 3).contiguous().realize()
    a += b
    b += a
    Tensor.realize(a,b)
    np.testing.assert_allclose(a.numpy(), 5)
    np.testing.assert_allclose(b.numpy(), 8)
  def test_assign_double_diamond(self):
    # a_prev/b_prev are built from the pre-assign values: b = 3 + 2*4 = 11, a = 2 + (3+3) = 8
    a = Tensor.full((4,), 2).contiguous().realize()
    b = Tensor.full((4,), 3).contiguous().realize()
    a_prev = a*4
    b_prev = b+3
    b += a_prev.contiguous()
    a += b_prev.contiguous()
    Tensor.realize(a, b)
    np.testing.assert_equal(b.numpy(), 11)
    np.testing.assert_equal(a.numpy(), 8)
  def test_assign_double_diamond_reduce(self):
    # r0/r1 read the pre-assign b1/b0: b0 = (10-2)*16 * 1 = 128, b1 = (20-1)*16 * 2 = 608
    a0 = Tensor.full((16, 16), 10).contiguous().realize()
    a1 = Tensor.full((16, 16), 20).contiguous().realize()
    b0 = Tensor.full((16, ), 1).contiguous().realize()
    b1 = Tensor.full((16, ), 2).contiguous().realize()
    r0 = (a0 - b1.contiguous()).sum(1)
    r1 = (a1 - b0.contiguous()).sum(1)
    b0.assign(r0 * b0)
    b1.assign(r1 * b1)
    Tensor.realize(b0, b1)
    np.testing.assert_equal(b0.numpy(), 128)
    np.testing.assert_equal(b1.numpy(), 608)
  @unittest.skip("TODO: bring this assert back")
  def test_crossunder_assign(self):
    # c reads a's pre-assign value but is consumed after a is assigned; intended to raise a "cycle" RuntimeError
    # NOTE: should *not* raise AssertionError from numpy
    with self.assertRaisesRegex(RuntimeError, "cycle"):
      a = Tensor.full((4,), 2).contiguous().realize()
      b = Tensor.full((4,), 3).contiguous().realize()
      c = a+9
      a += b
      b += c
      Tensor.realize(a,b)
      np.testing.assert_allclose(a.numpy(), 2+3)
      np.testing.assert_allclose(b.numpy(), 3+2+9)
  def test_assign_kv_cache(self):
    # KV-cache style update: the first call writes 3 rows at start_pos 0, then three jitted calls each append one
    # row at start_pos 3, 4, 5, giving 6 ones followed by 2 zeros per batch entry
    bsz, max_context = 2, 8

    class Attn:
      @TinyJit
      def __call__(self, xk:Tensor, start_pos:Variable):
        seqlen = xk.shape[1]
        if not hasattr(self, "cache_k"):
          self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).contiguous()
        # keep the first start_pos cached rows and append xk after them
        keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
        # zero-pad back to max_context rows and write into the cache in place
        self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize()

    attn = Attn()
    xk = Tensor.ones(bsz, 3, 1, 1).contiguous()
    attn(xk, 0)
    for i in range(3,6):
      # copied from LLaMA
      start_pos = Variable("start_pos", 1, max_context).bind(i)
      xk = Tensor.ones(bsz, 1, 1, 1).contiguous()
      attn(xk, start_pos)
    out = attn.cache_k.flatten().numpy()
    np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])
  def test_assign_contiguous(self):
    # assigning a contiguous() rhs into b is expected to take exactly 2 kernels
    b = Tensor.arange(16).reshape(4,4).contiguous().realize()
    a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1)
    kc = GlobalCounters.kernel_count
    b.assign(a.contiguous()).realize()
    assert GlobalCounters.kernel_count - kc == 2
  def test_assign_contiguous_permute(self):
    # same 2-kernel budget when the rhs is a permuted view made contiguous
    b = Tensor.arange(16).reshape(4,4).contiguous().realize()
    a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1).permute((1,0))
    kc = GlobalCounters.kernel_count
    b.assign(a.contiguous()).realize()
    assert GlobalCounters.kernel_count - kc == 2
  def test_permuted_assignment(self):
    # `+=` through a permuted view of a must still write into a's original buffer
    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    ba1 = a.uop.base.realized
    bb1 = b.uop.base.realized
    a = a.permute(1,0)
    a += b
    a.realize()
    ba2 = a.uop.base.realized
    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
    # permute and base are the same buffer
    assert ba1 == ba2 and ba1 != bb1
  def test_post_permuted_assignment(self):
    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    #GlobalCounters.cache = []
    ba1 = a.uop.base.realized # noqa: F841
    bb1 = b.uop.base.realized # noqa: F841
    # NOTE(review): despite the inline note on the next line, the assertion at the end expects this
    # self-assign through a permuted read of a to produce a.T + b
    a.assign(a.permute(1,0) + b) # this should not work!
    a.realize()
    ba2 = a.uop.base.realized # noqa: F841
    # NOTE: don't test that it's assigned
    #assert ba1 == ba2 and ba1 != bb1
    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
  # the next tests self-assign a through a permuted/flipped/reshaped read of a and compare against the same
  # expression computed before the assign
  def test_post_permuted_assignment_alt(self):
    a = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    b = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    new_a = (a.T+b).numpy()
    a.assign(a.T+b)
    np.testing.assert_allclose(a.numpy(), new_a)
  def test_post_flipped_assignment(self):
    a = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    b = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    new_a = (a.flip(0)+b).numpy()
    a.assign(a.flip(0)+b)
    np.testing.assert_allclose(a.numpy(), new_a)
  def test_post_flipped_assignment_axis1(self):
    a = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    b = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    new_a = (a.flip(1)+b).numpy()
    a.assign(a.flip(1)+b)
    np.testing.assert_allclose(a.numpy(), new_a)
  def test_post_reshape_assignment_fine(self):
    a = Tensor.arange(N*N).reshape(N, N).contiguous().realize()
    b = Tensor.arange(N*N).reshape(N, N).contiguous().realize()
    rhs = a.reshape(-1).reshape(N, N)
    new_a = (rhs+b).numpy()
    a.assign(rhs+b) # self-assign with reshape view is fine
    np.testing.assert_allclose(a.numpy(), new_a)
  @unittest.skip("multi output not supported anymore")
  def test_simple_assignment_multioutput(self):
    a = Tensor.arange(32*32).reshape(32, 32).contiguous().realize()
    b = Tensor.full((32, ), 1.).contiguous().realize()
    c = Tensor.full((32, ), 2.).contiguous().realize()
    d = Tensor.full((32, ), 3.).contiguous().realize()
    r = a.sum(axis=1)
    b.assign(r + b)
    c.assign(r + c)
    d.assign(r + d)
    kc = GlobalCounters.kernel_count
    Tensor.realize(b, c, d)
    assert GlobalCounters.kernel_count - kc == 1
    np.testing.assert_allclose(b.numpy(), a.sum(1).numpy()+1)
    np.testing.assert_allclose(c.numpy(), a.sum(1).numpy()+2)
    np.testing.assert_allclose(d.numpy(), a.sum(1).numpy()+3)
  # NOTE: if the assign target is read/write in a single kernel, it should be contiguous
  def test_permuted_assignment_correct(self):
    a = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
    b = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
    a = a.permute(1, 0)
    new_val = a + b
    a.assign(new_val)
    np.testing.assert_equal(a.numpy(), np.arange(4 * 4).reshape(4, 4).transpose(1, 0) + np.arange(4 * 4).reshape(4, 4))
  def test_permuted_reduceop_child_dual_use(self):
    # b is assigned from a reduction plus a permuted read of b itself
    a = Tensor.arange(32*32*32).reshape(32, 32, 32).contiguous().realize()
    b = Tensor.ones(32, 32, dtype=dtypes.int).contiguous().realize()
    r = a.sum(axis=1)
    b.assign(r + b.permute(1, 0))
    b.realize()
    np.testing.assert_equal(b.numpy(), a.numpy().sum(axis=1)+np.ones((32, 32), dtype=np.int32).transpose(1, 0))
  @unittest.skip("multi output not supported anymore")
  def test_permuted_reduceop_multioutput_dual_use(self):
    a = Tensor.arange(32*32*32).reshape(32, 32, 32).contiguous().realize()
    b = Tensor.full((32, 32), 1.).contiguous().realize()
    c = Tensor.full((32, 32), 2.).contiguous().realize()
    with self.assertRaisesRegex(RuntimeError, "contiguous"):
      r = a.sum(axis=1)
      b_perm = b.permute(1, 0)
      b.assign(r + b)
      c.assign(r + b_perm)
      Tensor.realize(b, c)
  @unittest.skip("multi output not supported anymore")
  def test_permuted_reduceop_multioutput_dual_use_possible(self):
    a = Tensor.arange(32*32*32).reshape(32, 32, 32).contiguous().realize()
    b = Tensor.arange(32 * 32).reshape(32, 32).realize()
    c = Tensor.arange(32 * 32).reshape(32, 32).realize()
    kc = GlobalCounters.kernel_count
    r = a.sum(axis=1)
    b_perm = b.permute(1, 0)
    b.assign(r + b)
    c.assign(r + b_perm.contiguous())
    Tensor.realize(b, c)
    assert GlobalCounters.kernel_count - kc == 2
    np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32))
    np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0))
  def test_permuted_assignment_masked_view_possible(self):
    # b keeps columns 0-1 of a (ones) and pads columns 2-3 with 2, so a + b is 2 in cols 0-1 and 3 in cols 2-3;
    # expected to fit in a single kernel
    a = Tensor.ones(4, 4).contiguous().realize()
    b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2)
    a.assign(a + b)
    kc = GlobalCounters.kernel_count
    a.realize()
    assert GlobalCounters.kernel_count - kc == 1
    np.testing.assert_equal(a.numpy(), np.ones((4, 4))+np.pad(np.ones((4, 4))[:, 0:2], ((0, 0), (0, 2)), constant_values=2))
  def test_permuted_assignment_masked_view_not_contiguous(self):
    # same masked view but permuted, so the result is 2 in rows 0-1 and 3 in rows 2-3
    a = Tensor.ones(4, 4).contiguous().realize()
    b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2).permute(1, 0)
    a.assign(a + b)
    a.realize()
    self.assertListEqual(a.tolist(), [[2.,2.,2.,2.],[2.,2.,2.,2.],[3.,3.,3.,3.], [3.,3.,3.,3.]])
  # TODO: is there a way to sneak in a permute such that it returns the wrong answer?
  @unittest.skip("this test is crashing!")
  def test_overlapping_shrink_assignment_forward(self):
    # Forward shift: read index > write index in overlap
    # (the local N shadows the module-level N)
    N = 100000
    shift = 1000
    a = Tensor.arange(N).float().contiguous().realize()
    expected = np.arange(N, dtype=np.float32)
    expected[:N-shift] = expected[shift:].copy()
    with Context(NOOPT=1): a[0:N-shift].assign(a[shift:N]).realize()
    np.testing.assert_allclose(a.numpy(), expected)
  @unittest.skip("this test is crashing!")
  def test_overlapping_shrink_assignment_reverse(self):
    # Reverse shift: write index > read index in overlap
    # (the local N shadows the module-level N)
    N = 100000
    shift = 1000
    a = Tensor.arange(N).float().contiguous().realize()
    expected = np.arange(N, dtype=np.float32)
    expected[shift:] = expected[:N-shift].copy()
    with Context(NOOPT=1): a[shift:N].assign(a[0:N-shift]).realize()
    np.testing.assert_allclose(a.numpy(), expected)
  @unittest.skip("this test is crashing!")
  def test_nonoverlapping_shrink_assignment(self):
    # TODO: non-overlapping shrinks don't actually need contiguous, could be 1 kernel with smarter range analysis
    a = Tensor.arange(100).float().contiguous().realize()
    expected = np.arange(100, dtype=np.float32)
    expected[0:10] = expected[50:60].copy()
    kc = GlobalCounters.kernel_count
    a[0:10].assign(a[50:60]).realize()
    assert GlobalCounters.kernel_count - kc == 2, "currently conservative, forces contiguous"
    np.testing.assert_allclose(a.numpy(), expected)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_setitem_half(self):
    # assigning into the first half of a half-precision buffer must leave the second half untouched
    a = Tensor.full((8,), 1.0, dtype=dtypes.half).contiguous().realize()
    b = Tensor.full((4,), 2.0, dtype=dtypes.half).contiguous().realize()
    assign = a[:4].assign(b)
    assign.realize()
    np.testing.assert_allclose(a.numpy(), [2., 2., 2., 2., 1., 1., 1., 1.])
  def test_setitem_list(self):
    # __setitem__ with a plain Python list on the right-hand side
    a = Tensor.zeros(8).contiguous().realize()
    a[2:5] = [1, 2, 3]
    np.testing.assert_allclose(a.numpy(), [0., 0., 1., 2., 3., 0., 0., 0.])
  def test_assign_bitcast(self):
    # assign to a bitcast view should modify the underlying buffer
    a = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
    # IEEE 754: 1.0f = 0x3f800000, 2.0f = 0x40000000, 3.0f = 0x40400000, 4.0f = 0x40800000
    a.bitcast(dtypes.uint32).assign(Tensor([0x40800000, 0x40400000, 0x40000000, 0x3f800000], dtype=dtypes.uint32)).realize()
    np.testing.assert_allclose(a.numpy(), [4.0, 3.0, 2.0, 1.0])
    # double bitcast
    b = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
    b.bitcast(dtypes.uint32).bitcast(dtypes.int32).assign(Tensor([0x40800000, 0x40400000, 0x40000000, 0x3f800000], dtype=dtypes.int32)).realize()
    np.testing.assert_allclose(b.numpy(), [4.0, 3.0, 2.0, 1.0])
    # shrink then bitcast
    c = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
    c[0:2].bitcast(dtypes.uint32).assign(Tensor([0x40800000, 0x40400000], dtype=dtypes.uint32)).realize()
    np.testing.assert_allclose(c.numpy(), [4.0, 3.0, 3.0, 4.0])
  def test_assign_bitcast_different_size(self):
    # different-size bitcast creates a new tensor, not a view, so assign doesn't modify the original
    a = Tensor([0]*8, dtype=dtypes.uint8).realize()
    a.bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int64)).realize()
    np.testing.assert_equal(a.numpy(), [0]*8)
  @unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
  def test_cast_assignment(self):
    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    oba1 = a.uop.base.output_buffer
    a.assign(a.cast(dtypes.int32).realize())
    a.realize()
    oba2 = a.uop.base.output_buffer
    assert oba1 is None and oba2 is None
    np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
  def test_assign_dtype_mismatch(self):
    # assign should not implicitly cast dtypes - this can lose precision
    a = Tensor.zeros(4, dtype=dtypes.float32).contiguous().realize()
    b = Tensor([1, 2, 3, 4], dtype=dtypes.int32)
    with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"):
      a.assign(b)
  def test_assign_dtype_mismatch_int64_to_float32(self):
    # int64 -> float32 loses precision for large values, should not be implicit
    a = Tensor.zeros(1, dtype=dtypes.float32).contiguous().realize()
    b = Tensor([16777217], dtype=dtypes.int64) # 2^24 + 1, not exactly representable in float32
    with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"):
      a.assign(b)
  def test_assign_shape_broadcast(self):
    # shape broadcasting should work when dtypes match
    a = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
    b = Tensor([1., 2., 3., 4., 5.], dtype=dtypes.float32)
    a.assign(b)
    a.realize()
    expected = np.array([[1., 2., 3., 4., 5.]] * 3)
    np.testing.assert_allclose(a.numpy(), expected)
  def test_assign_shape_broadcast_2d(self):
    # broadcast (1, 5) to (3, 5)
    a = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
    b = Tensor([[1., 2., 3., 4., 5.]], dtype=dtypes.float32)
    a.assign(b)
    a.realize()
    expected = np.array([[1., 2., 3., 4., 5.]] * 3)
    np.testing.assert_allclose(a.numpy(), expected)
  def test_disk_assignment(self):
    # assign into a tensor on a "disk:" device backed by a temp file path, then read it back
    a = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}").assign(Tensor.ones(5)).numpy()
    np.testing.assert_equal(a, np.ones(5))
  @unittest.skip("this test is crashing!")
  def test_assign_slice_then_read(self):
    """Assign to slice then read from buffer - read should see the assigned values.
    This is the KV cache pattern from llm.py.
    """
    v_pos = Variable("pos", 0, 3).bind(0)
    cache = Tensor.zeros(4, 4).contiguous().realize()
    cache[v_pos:v_pos+1, :].assign(Tensor.ones(1, 4))
    self.assertEqual(cache.sum().item(), 4.0)
  def test_chained_assign_slice_then_read(self):
    """Three caches with chained assign-then-read: each block writes to its cache and reads back,
    feeding the result to the next block's assign. Without proper dependency tracking, block N's read
    may see stale data from block N-1's cache (pre-assign zeros instead of the assigned values).
    This is the multi-layer KV cache pattern from llm.py._attention.
    """
    D, max_ctx = 4, 8
    cache1 = Tensor.zeros(max_ctx, D).contiguous().realize()
    cache2 = Tensor.zeros(max_ctx, D).contiguous().realize()
    cache3 = Tensor.zeros(max_ctx, D).contiguous().realize()
    cache1[:3].assign(Tensor.ones(3, D)).realize()
    cache2[:3].assign(Tensor.ones(3, D) * 2).realize()
    cache3[:3].assign(Tensor.ones(3, D) * 3).realize()
    # block 1: assign [10]*D at position 3, read sum -> c1=[13]*D
    cache1[3:4].assign(Tensor.ones(1, D) * 10)
    c1 = cache1[:4].sum(0, keepdim=True)
    # block 2: assign c1 at position 3, read sum -> c2=[19]*D
    cache2[3:4].assign(c1)
    c2 = cache2[:4].sum(0, keepdim=True)
    # block 3: assign c2 at position 3, read sum -> 112
    # (per column 3+3+3+19 = 28, times D = 4 columns)
    cache3[3:4].assign(c2)
    self.assertEqual(cache3[:4].sum().item(), 112.0)
  def test_chained_assign_kernel_count(self):
    """Chained pending assigns must not produce excessive kernels (tests recursive transitive processing)."""
    D, N = 4, 5
    caches = [Tensor.zeros(8, D).contiguous().realize() for _ in range(N)]
    caches[0][0:1].assign(Tensor.ones(1, D) * 10)
    x = caches[0][:1].sum(0, keepdim=True)
    for i in range(1, N):
      caches[i][0:1].assign(x)
      x = caches[i][:1].sum(0, keepdim=True)
    GlobalCounters.reset()
    x.realize()
    # N assigns (1 kernel each) producing N kernels total
    self.assertEqual(GlobalCounters.kernel_count, N)
  def test_shared_computation_assign_kernel_count(self):
    """When a .contiguous() is shared between an assign value and the next layer's input (like QKV projection in LLM),
    substitute optimization replaces already-realized sub-graphs in remaining pending assigns, preventing kernel escalation.
    Without substitute, pending assign graphs grow linearly and produce 153 kernels instead of 48."""
    D, N = 16, 16
    caches = [Tensor.zeros(4, D).contiguous().realize() for _ in range(N)]
    W = [Tensor.full((D, D*2), 0.01).contiguous().realize() for _ in range(N)]
    x = Tensor.ones(1, D).contiguous().realize()
    for i in range(N):
      shared = (x @ W[i]).contiguous() # .contiguous() UOp is shared between assign (k) and next layer (q)
      k, q = shared[:, :D], shared[:, D:]
      caches[i][0:1].assign(k) # assign references the CONTIGUOUS
      x = q + caches[i][:1] # next layer also references the same CONTIGUOUS through q
    GlobalCounters.reset()
    caches[-1][:1].contiguous().realize()
    # 2 kernels for first assign + 3 per remaining assign (matmul, contiguous, assign) + 1 final read = 3*N
    self.assertEqual(GlobalCounters.kernel_count, 3*N)
class TestAssignOrdering(unittest.TestCase):
  """Tests for complex assign orderings that could differ between lazy and eager execution.
  The key principle: tinygrad's lazy execution with RAW/WAR dependency tracking should
  produce the same results as eager (immediate) execution for valid programs.
  (RAW = read-after-write, WAR = write-after-read.)
  These tests exercise edge cases where incorrect dependency tracking could cause:
  - Stale reads (reading before write completes)
  - Lost writes (write ordering reversed)
  - Race conditions (concurrent access to same buffer)
  Each test starts from buffers created with `.contiguous().realize()`.
  """
  def test_overlapping_slice_assigns(self):
    """Overlapping slice assigns - later write should win for overlapping elements."""
    buf = Tensor.zeros(8).contiguous().realize()
    buf[0:4].assign(Tensor.ones(4))
    buf[2:6].assign(Tensor.ones(4) * 2)
    # [0:2] keeps the first write, [2:6] takes the second, [6:8] is untouched
    np.testing.assert_equal(buf.numpy(), [1,1,2,2,2,2,0,0])
  def test_overlapping_slice_assigns_reverse(self):
    """Overlapping slice assigns in reverse order."""
    buf = Tensor.zeros(8).contiguous().realize()
    buf[2:6].assign(Tensor.ones(4) * 2)
    buf[0:4].assign(Tensor.ones(4))
    # [0:4] takes the later write, [4:6] keeps the earlier one
    np.testing.assert_equal(buf.numpy(), [1,1,1,1,2,2,0,0])
  def test_read_between_writes(self):
    """Read should see first write before second write happens."""
    buf = Tensor.zeros(4).contiguous().realize()
    buf.assign(Tensor.ones(4))
    r1 = buf.sum().realize() # should see ones = 4
    buf.assign(Tensor.ones(4) * 2)
    r2 = buf.sum().realize() # should see twos = 8
    self.assertEqual(r1.item(), 4)
    self.assertEqual(r2.item(), 8)
  def test_write_read_write_chain(self):
    """Write, read, write chain - middle read must complete before second write."""
    buf = Tensor.zeros(4).contiguous().realize()
    buf.assign(Tensor.ones(4) * 3)
    mid_sum = buf.sum() # lazy read, should be 12
    buf.assign(Tensor.ones(4) * 5)
    final_sum = buf.sum() # lazy read, should be 20
    # Realize in "wrong" order - final first
    self.assertEqual(final_sum.realize().item(), 20)
    self.assertEqual(mid_sum.realize().item(), 12)
  def test_slice_read_then_full_write(self):
    """Read from slice, then overwrite full buffer - WAR dependency works for full buffer assigns."""
    buf = Tensor([1.,2.,3.,4.]).contiguous().realize()
    partial = buf[0:2].sum() # lazy read
    buf.assign(Tensor.ones(4) * 10) # overwrite everything
    full = buf.sum()
    # WAR dependency correctly tracked - partial sees original data
    self.assertEqual(partial.realize().item(), 3) # 1+2
    self.assertEqual(full.realize().item(), 40)
  def test_slice_write_then_full_read(self):
    """Write to slice, then read full buffer."""
    buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
    buf[1:3].assign(Tensor([5, 6]))
    np.testing.assert_equal(buf.numpy(), [0, 5, 6, 0])
  def test_chained_slice_copies(self):
    """Copy from one slice to another within same buffer."""
    # source [0:4] and destination [4:8] do not overlap
    buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
    buf[4:8].assign(buf[0:4].contiguous())
    np.testing.assert_equal(buf.numpy(), [1, 2, 3, 4, 1, 2, 3, 4])
  def test_swap_slices(self):
    """Swap two non-overlapping slices - requires reading both before writing."""
    # without .realize() on temps: values not captured before overwriting
    buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
    left = buf[0:4].contiguous() # lazy - not captured yet
    right = buf[4:8].contiguous() # lazy - not captured yet
    buf[0:4].assign(right).realize() # this works
    buf[4:8].assign(left).realize() # left now reads from modified buf!
    np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 5, 6, 7, 8]) # TODO: wrong! should be [5,6,7,8,1,2,3,4]
    # with .realize() on temps: values captured before writes
    buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
    left = buf[0:4].contiguous().realize()
    right = buf[4:8].contiguous().realize()
    buf[0:4].assign(right).realize()
    buf[4:8].assign(left).realize()
    np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 1, 2, 3, 4])
  def test_reduction_after_partial_assign(self):
    """Reduction over buffer after partial assign - must see the assigned values."""
    buf = Tensor.zeros(4, 4).contiguous().realize()
    buf[0:2, :].assign(Tensor.ones(2, 4)) # top half = 1
    total = buf.sum() # 2 rows x 4 cols of ones = 8
    self.assertEqual(total.item(), 8)
  def test_multiple_reductions_different_views(self):
    """Multiple reductions over different views of same buffer after assign."""
    buf = Tensor.zeros(4, 4).contiguous().realize()
    buf.assign(Tensor.arange(16).reshape(4, 4).float())
    row_sums = buf.sum(axis=1) # [6, 22, 38, 54]
    col_sums = buf.sum(axis=0) # [24, 28, 32, 36]
    total = buf.sum() # 120
    # All should see the assigned values
    np.testing.assert_equal(row_sums.numpy(), [6, 22, 38, 54])
    np.testing.assert_equal(col_sums.numpy(), [24, 28, 32, 36])
    self.assertEqual(total.item(), 120)
  def test_assign_from_self_transformed(self):
    """Assign to buffer from transformed view of itself."""
    buf = Tensor([1, 2, 3, 4]).contiguous().realize()
    # Read and transform, then write back (requires reading before writing)
    buf.assign((buf * 2).contiguous())
    np.testing.assert_equal(buf.numpy(), [2, 4, 6, 8])
  def test_two_buffers_cross_assign(self):
    """Two buffers each reading from the other before writing."""
    a = Tensor([1, 2, 3, 4]).contiguous().realize()
    b = Tensor([10, 20, 30, 40]).contiguous().realize()
    # Both read from each other's original values
    a_new = (a + b).contiguous()
    b_new = (a * b).contiguous()
    a.assign(a_new)
    b.assign(b_new)
    Tensor.realize(a, b)
    # b_new must use the original a ([1,2,3,4]), not a_new
    np.testing.assert_equal(a.numpy(), [11, 22, 33, 44])
    np.testing.assert_equal(b.numpy(), [10, 40, 90, 160])
  def test_three_buffer_chain(self):
    """Chain: A depends on B, B depends on C - ordering matters."""
    a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
    b = Tensor([1, 2, 3, 4]).contiguous().realize()
    c = Tensor([10, 10, 10, 10]).contiguous().realize()
    # b reads from c, a reads from b
    b.assign((b + c).contiguous()) # b = [11, 12, 13, 14]
    a.assign((a + b).contiguous()) # a should see new b = [11, 12, 13, 14]
    Tensor.realize(a, b)
    np.testing.assert_equal(b.numpy(), [11, 12, 13, 14])
    np.testing.assert_equal(a.numpy(), [11, 12, 13, 14])
  def test_interleaved_assign_read_patterns(self):
    """Complex interleaved pattern: write A, read A into B, write B, read B."""
    a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
    b = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
    a.assign(Tensor([1, 2, 3, 4]))
    b.assign(a.contiguous()) # b should get [1,2,3,4]
    a.assign(Tensor([5, 6, 7, 8]))
    result = b.sum() # should be 10, not 26
    self.assertEqual(result.item(), 10)
    np.testing.assert_equal(a.numpy(), [5, 6, 7, 8])
    np.testing.assert_equal(b.numpy(), [1, 2, 3, 4])
  def test_variable_slice_ordering(self):
    """Variable-indexed slices - tests symbolic dependency tracking."""
    # the same Variable bound to 0 and then to 1 must address rows 0 and 1 respectively
    v_i = Variable("i", 0, 3)
    buf = Tensor.zeros(4, 4).contiguous().realize()
    buf[v_i.bind(0):v_i.bind(0)+1, :].assign(Tensor.ones(1, 4))
    buf[v_i.bind(1):v_i.bind(1)+1, :].assign(Tensor.ones(1, 4) * 2)
    self.assertEqual(buf[0:1, :].sum().item(), 4)
    self.assertEqual(buf[1:2, :].sum().item(), 8)
  def test_multiple_slice_assigns_then_read(self):
    """Multiple non-overlapping slice assigns then read."""
    buf = Tensor.zeros(4).contiguous().realize()
    buf[0:1].assign(Tensor.ones(1))
    buf[1:2].assign(Tensor.full((1,), 2.0))
    buf[2:3].assign(Tensor.full((1,), 3.0))
    # 1 + 2 + 3 + 0 (last element untouched) = 6
    self.assertEqual(buf.sum().realize().item(), 6.0)
if __name__ == "__main__":
  # only runs when this file is executed directly as a script, not when it is imported
  unittest.main(verbosity=1)