# mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00; 1012 lines, 40 KiB, Python)
#!/usr/bin/env python
|
|
import unittest
|
|
import numpy as np
|
|
from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
|
|
from tinygrad.uop.ops import Ops
|
|
from tinygrad.device import is_dtype_supported
|
|
from tinygrad.helpers import temp, CI, DEV, Context
|
|
|
|
N = 200 # has to be bigger than the cache to fail
|
|
|
|
class TestAssign(unittest.TestCase):
|
|
  def test_simple_assignment(self):
    # in-place add into a realized tensor must reuse a's buffer, not allocate a new one
    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    ba1 = a.uop.base.realized
    bb1 = b.uop.base.realized
    a += b
    a.realize()
    ba2 = a.uop.base.realized
    # same buffer before and after the assign, and distinct from b's buffer
    assert ba1 == ba2 and ba1 != bb1
    np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))

  def test_assign_zeros_good(self):
    # assigning into a must not disturb the unrelated buffer b (created after the assign)
    a = Tensor.zeros(10,10).contiguous()
    a.assign(Tensor.ones(10,10))
    b = Tensor.zeros(10,10).contiguous()
    a.realize()
    np.testing.assert_allclose(b.numpy(), 0)

  @unittest.skip("TODO: this often crashes in CI")
  def test_assign_zeros(self):
    # same as test_assign_zeros_good but b is created *before* the assign
    a = Tensor.zeros(10,10).contiguous()
    b = Tensor.zeros(10,10).contiguous()
    a.assign(Tensor.ones(10,10))
    a.realize()
    np.testing.assert_allclose(b.numpy(), 0)

  def test_assign_copy(self):
    # assigning a cross-device copy into an empty tensor should be a single kernel
    a = Tensor([1.,2,3], device="PYTHON")
    c = Tensor.empty(3).assign(a.to(None))
    # it should copy into the empty buffer
    GlobalCounters.reset()
    c.realize()
    self.assertEqual(GlobalCounters.kernel_count, 1)
  def test_assign_add(self):
    # repeated += realizes in-place: value accumulates and the realized buffer is stable
    for T in (1, 2, 10):#, 100): # this crashes in CI, not sure why
      x = Tensor([0]).realize()
      buf = x.uop.base.realized
      for _ in range(T):
        x += 1
        x.realize()
      assert x.item() == T
      assert x.uop.base.realized is buf

  def test_assign_slice_add(self):
    # += through a slice view only touches that element, same underlying buffer throughout
    for T in (1, 2, 10, 100):
      x = Tensor([0, 0]).realize()
      buf = x.uop.base.realized
      for _ in range(T):
        x[0] += 1
        x.realize()
      assert x.tolist() == [T, 0]
      assert x.uop.base.realized is buf

  def test_assign_add_double(self):
    # calling the same (non-jitted) helper on two fresh tensors must not double-apply
    def f(x):
      x += 1
      x.realize()
    x = Tensor([0])
    f(x)
    out = x.item()
    assert out == 1, f"expected 1, got {out}"
    x = Tensor([0])
    f(x)
    out = x.item()
    assert out == 1, f"expected 1, got {out}"
  def test_assign_add_jit(self):
    # += inside a TinyJit function keeps accumulating across jitted calls
    @TinyJit
    def f(x):
      x += 1
      x.realize()
    x = Tensor([0])
    for _ in range(5): f(x)
    assert x.item() == 5

  def test_assign_add_jit_other(self):
    # a second tensor run through the same jit gets its own independent count
    @TinyJit
    def f(x):
      x += 1
      x.realize()
    x = Tensor([0])
    for _ in range(5): f(x)
    assert x.item() == 5

    y = Tensor([0])
    for _ in range(4): f(y)
    assert y.item() == 4

  def test_assign_other_jit(self):
    # jitted assign from a fresh input tensor each call
    @TinyJit
    def f(x, a):
      x.assign(a)
      x.realize()
    x = Tensor([0])
    for i in range(1, 6):
      f(x, x.full_like(i).contiguous()) # const would be implicitly folded without contiguous
      assert x.item() == i

  def test_assign_add_other_jit(self):
    # jitted += from a fresh input tensor each call; a tracks the expected running total
    @TinyJit
    def f(x, a):
      x += a
      x.realize()
    x = Tensor([0])
    a = 0
    for i in range(1, 6):
      a += i
      f(x, x.full_like(i).contiguous())
    assert x.item() == a
  def test_assign_changes(self):
    # after assign, an old reference to the same tensor observes the new value
    a = Tensor.ones(4).contiguous().realize()
    old_a = a
    a.assign(Tensor.full((4,), 2.).contiguous())
    # NOTE: old_a is now 2, and this would match the behavior of pytorch
    new = a + old_a
    np.testing.assert_allclose(new.numpy(), 4)

  @unittest.skip("TODO: this is broken")
  def test_assign_changes_alt(self, realize=False):
    # a contiguous child is a new Tensor: assigning to it should not change the source
    a = Tensor(1).contiguous()
    if realize: a.realize()
    b = a.contiguous() # b returns a new Tensor
    b.assign(2)
    b.realize()
    self.assertNotEqual(a.item(), b.item())

  # on a realized Tensor contiguous child changes the source
  @unittest.expectedFailure
  def test_assign_changes_realized_alt(self): return self.test_assign_changes_alt(realize=True)

  @unittest.skip("assign to contiguous shouldn't change the base buffer")
  def test_assign_changes_buffer_alt(self):
    # two tensors over distinct realized zero buffers, each assigned through a contiguous view
    a, b = [Tensor(Tensor(0).contiguous().realize().uop.buf_uop) for _ in range(2)]
    Tensor.realize(a.contiguous().assign(1), b.contiguous().assign(2))
    self.assertEqual((a + b).item(), 3)
  def test_assign_diamond_cycle(self):
    # reading a pre-assign view (times_a) after the assign needs a's old buffer -> scheduling cycle
    # NOTE: should *not* raise AssertionError from numpy
    with self.assertRaisesRegex(RuntimeError, "cycle"):
      a = Tensor.ones(4).contiguous().realize()
      times_a = a*3
      a.assign(Tensor.full((4,), 2.).contiguous())
      new = a + (times_a-1)
      np.testing.assert_allclose(new.numpy(), 4)

  def test_assign_diamond_contiguous_cycle(self):
    # a contiguous on the post-assign read alone does not break the cycle
    with self.assertRaisesRegex(RuntimeError, "cycle"):
      a = Tensor.ones(4).contiguous().realize()
      times_a = a*3
      a.assign(Tensor.full((4,), 2.))
      new = a.contiguous() + times_a-1
      np.testing.assert_allclose(new.numpy(), 4)

  def test_assign_diamond_possible(self):
    # making the old-value branch contiguous realizes it first, so this schedules (2 + (3-1) = 4)
    a = Tensor.ones(4).contiguous().realize()
    times_a = a*3
    a.assign(Tensor.full((4,), 2.))
    new = a + (times_a-1).contiguous()
    np.testing.assert_allclose(new.numpy(), 4)

  def test_assign_diamond_possible_contiguous(self):
    # same as test_assign_diamond_possible, with a contiguous assign value too
    a = Tensor.ones(4).contiguous().realize()
    times_a = a*3
    a.assign(Tensor.full((4,), 2.).contiguous())
    new = a + (times_a-1).contiguous()
    np.testing.assert_allclose(new.numpy(), 4)

  def test_assign_diamond_both_contiguous(self):
    # contiguous on both branches of the diamond also schedules fine
    a = Tensor.ones(4).contiguous().realize()
    times_a = a*3
    a.assign(Tensor.full((4,), 2.))
    new = a.contiguous() + (times_a-1).contiguous()
    np.testing.assert_allclose(new.numpy(), 4)

  def test_assign_diamond_alt(self):
    # assign happens before times_a exists, so both reads see the new value (2 + 2*3 = 8)
    a = Tensor.ones(4).contiguous().realize()
    a.assign(Tensor.full((4,), 2.).contiguous())
    times_a = a*3
    new = a + times_a
    np.testing.assert_allclose(new.numpy(), 8)
  @unittest.skipIf(CI and DEV.renderer == "LVP", "flaky in CI")
  def test_double_assign(self):
    # two chained += on the same buffer both land
    a = Tensor.ones(4).contiguous().realize()
    a += 1
    a += 1
    np.testing.assert_allclose(a.numpy(), 3)

  def test_crossover_assign(self):
    # b += a must see a's *new* value since a's assign was issued first: a=5, b=3+5=8
    a = Tensor.full((4,), 2).contiguous().realize()
    b = Tensor.full((4,), 3).contiguous().realize()
    a += b
    b += a
    Tensor.realize(a,b)
    np.testing.assert_allclose(a.numpy(), 5)
    np.testing.assert_allclose(b.numpy(), 8)

  def test_assign_double_diamond(self):
    # each assign reads the other's *old* value through a contiguous copy: b=3+8, a=2+6
    a = Tensor.full((4,), 2).contiguous().realize()
    b = Tensor.full((4,), 3).contiguous().realize()
    a_prev = a*4
    b_prev = b+3
    b += a_prev.contiguous()
    a += b_prev.contiguous()
    Tensor.realize(a, b)
    np.testing.assert_equal(b.numpy(), 11)
    np.testing.assert_equal(a.numpy(), 8)

  def test_assign_double_diamond_reduce(self):
    # reductions over contiguous copies of old values feed cross-assigns of those buffers
    a0 = Tensor.full((16, 16), 10).contiguous().realize()
    a1 = Tensor.full((16, 16), 20).contiguous().realize()
    b0 = Tensor.full((16, ), 1).contiguous().realize()
    b1 = Tensor.full((16, ), 2).contiguous().realize()

    # r0 reads b1's old value, r1 reads b0's old value (resolved by the contiguous copies)
    r0 = (a0 - b1.contiguous()).sum(1)
    r1 = (a1 - b0.contiguous()).sum(1)
    b0.assign(r0 * b0)
    b1.assign(r1 * b1)
    Tensor.realize(b0, b1)
    # b0 = (10-2)*16 * 1 = 128, b1 = (20-1)*16 * 2 = 608
    np.testing.assert_equal(b0.numpy(), 128)
    np.testing.assert_equal(b1.numpy(), 608)

  def test_crossunder_assign(self):
    # c keeps a's old value alive across a's assign; combined with b += c this is unschedulable
    # NOTE: should *not* raise AssertionError from numpy
    with self.assertRaisesRegex(RuntimeError, "cycle"):
      a = Tensor.full((4,), 2).contiguous().realize()
      b = Tensor.full((4,), 3).contiguous().realize()
      c = a+9
      a += b
      b += c
      Tensor.realize(a,b)
      np.testing.assert_allclose(a.numpy(), 2+3)
      np.testing.assert_allclose(b.numpy(), 3+2+9)
  def test_assign_kv_cache(self):
    # jitted LLaMA-style cache update: shrink to the valid region, cat the new keys, pad back, assign
    bsz, max_context = 2, 8

    class Attn:
      @TinyJit
      def __call__(self, xk:Tensor, start_pos:Variable):
        seqlen = xk.shape[1]
        if not hasattr(self, "cache_k"):
          # cache is lazily created on the first call
          self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).contiguous()
        keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
        self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize()

    attn = Attn()
    xk = Tensor.ones(bsz, 3, 1, 1).contiguous()
    attn(xk, 0)
    for i in range(3,6):
      # copied from LLaMA
      start_pos = Variable("start_pos", 1, max_context).bind(i)
      xk = Tensor.ones(bsz, 1, 1, 1).contiguous()
      attn(xk, start_pos)

    # 3 prefill + 3 decode steps = 6 filled positions per batch row, 2 zero-padded
    out = attn.cache_k.flatten().numpy()
    np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])

  def test_assign_after(self):
    # .after with a store makes later reads of t observe the written value
    t = Tensor.zeros(10).contiguous().realize()
    t.uop = t.uop.after(t.uop.store((t+1).uop))
    np.testing.assert_allclose(t.numpy(), [1.,1.,1.,1.,1.,1.,1.,1.,1.,1.])

  def test_assign_after_partial(self):
    # store through a shrunk view only updates the first 5 elements
    t = Tensor.zeros(10).contiguous().realize()
    t.uop = t.uop.after(t[:5].uop.after(t[:5].uop.store(Tensor.ones(5).uop)))
    np.testing.assert_allclose(t.numpy(), [1.,1.,1.,1.,1.,0.,0.,0.,0.,0.])

  def test_assign_after_target_chain(self):
    # assign where the target itself is behind a movement-op chain (permute -> contiguous)
    t = Tensor.arange(16).reshape(4, 4).permute(1, 0).contiguous()
    t.assign(t + 100)
    np.testing.assert_equal(t.numpy(), [[100, 104, 108, 112], [101, 105, 109, 113], [102, 106, 110, 114], [103, 107, 111, 115]])
  def test_assign_contiguous(self):
    # assign of a contiguous value: one kernel for a+1, one for the copy into b
    b = Tensor.arange(16).reshape(4,4).contiguous().realize()
    a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1)
    GlobalCounters.reset()
    b.assign(a.contiguous()).realize()
    self.assertEqual(GlobalCounters.kernel_count, 2)

  def test_assign_contiguous_permute(self):
    # same as test_assign_contiguous but the value is permuted before the contiguous
    b = Tensor.arange(16).reshape(4,4).contiguous().realize()
    a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1).permute((1,0))
    GlobalCounters.reset()
    b.assign(a.contiguous()).realize()
    self.assertEqual(GlobalCounters.kernel_count, 2)

  def test_permuted_assignment(self):
    # += on a permuted view still writes into the original buffer
    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    ba1 = a.uop.base.realized
    bb1 = b.uop.base.realized
    a = a.permute(1,0)
    a += b
    a.realize()
    ba2 = a.uop.base.realized
    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
    # permute and base are the same buffer
    assert ba1 == ba2 and ba1 != bb1
  def test_post_permuted_assignment(self):
    # assigning a.permute(1,0) + b back into a: a self-read through a permuted view
    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    #GlobalCounters.cache = []
    ba1 = a.uop.base.realized # noqa: F841
    bb1 = b.uop.base.realized # noqa: F841
    a.assign(a.permute(1,0) + b) # this should not work!
    a.realize()
    ba2 = a.uop.base.realized # noqa: F841
    # NOTE: don't test that it's assigned
    #assert ba1 == ba2 and ba1 != bb1
    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))

  def test_post_permuted_assignment_alt(self):
    # reference value is computed eagerly first, then compared to the lazy self-assign result
    a = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    b = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    new_a = (a.T+b).numpy()
    a.assign(a.T+b)
    np.testing.assert_allclose(a.numpy(), new_a)

  def test_post_flipped_assignment(self):
    # self-assign through a flip on axis 0
    a = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    b = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    new_a = (a.flip(0)+b).numpy()
    a.assign(a.flip(0)+b)
    np.testing.assert_allclose(a.numpy(), new_a)

  def test_post_flipped_assignment_axis1(self):
    # self-assign through a flip on axis 1
    a = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    b = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
    new_a = (a.flip(1)+b).numpy()
    a.assign(a.flip(1)+b)
    np.testing.assert_allclose(a.numpy(), new_a)

  def test_post_reshape_assignment_fine(self):
    # a reshape round-trip preserves element order, so this self-assign is safe
    a = Tensor.arange(N*N).reshape(N, N).contiguous().realize()
    b = Tensor.arange(N*N).reshape(N, N).contiguous().realize()
    rhs = a.reshape(-1).reshape(N, N)
    new_a = (rhs+b).numpy()
    a.assign(rhs+b) # self-assign with reshape view is fine
    np.testing.assert_allclose(a.numpy(), new_a)
  @unittest.skip("multi output not supported anymore")
  def test_simple_assignment_multioutput(self):
    # three assigns sharing one reduction used to fuse into a single multi-output kernel
    a = Tensor.arange(32*32).reshape(32, 32).contiguous().realize()
    b = Tensor.full((32, ), 1.).contiguous().realize()
    c = Tensor.full((32, ), 2.).contiguous().realize()
    d = Tensor.full((32, ), 3.).contiguous().realize()

    r = a.sum(axis=1)
    b.assign(r + b)
    c.assign(r + c)
    d.assign(r + d)

    GlobalCounters.reset()
    Tensor.realize(b, c, d)
    self.assertEqual(GlobalCounters.kernel_count, 1)
    np.testing.assert_allclose(b.numpy(), a.sum(1).numpy()+1)
    np.testing.assert_allclose(c.numpy(), a.sum(1).numpy()+2)
    np.testing.assert_allclose(d.numpy(), a.sum(1).numpy()+3)

  # NOTE: if the assign target is read/write in a single kernel, it should be contiguous

  def test_permuted_assignment_correct(self):
    # assign to a permuted view with a value read through the same view
    a = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
    b = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
    a = a.permute(1, 0)
    new_val = a + b
    a.assign(new_val)
    np.testing.assert_equal(a.numpy(), np.arange(4 * 4).reshape(4, 4).transpose(1, 0) + np.arange(4 * 4).reshape(4, 4))

  def test_permuted_reduceop_child_dual_use(self):
    # b is both a permuted read and the assign target of a reduce result in one expression
    a = Tensor.arange(32*32*32).reshape(32, 32, 32).contiguous().realize()
    b = Tensor.ones(32, 32, dtype=dtypes.int).contiguous().realize()
    r = a.sum(axis=1)
    b.assign(r + b.permute(1, 0))
    b.realize()
    np.testing.assert_equal(b.numpy(), a.numpy().sum(axis=1)+np.ones((32, 32), dtype=np.int32).transpose(1, 0))
  @unittest.skip("multi output not supported anymore")
  def test_permuted_reduceop_multioutput_dual_use(self):
    # a permuted read of an assign target inside the same multi-output kernel must be rejected
    a = Tensor.arange(32*32*32).reshape(32, 32, 32).contiguous().realize()
    b = Tensor.full((32, 32), 1.).contiguous().realize()
    c = Tensor.full((32, 32), 2.).contiguous().realize()

    with self.assertRaisesRegex(RuntimeError, "contiguous"):
      r = a.sum(axis=1)
      b_perm = b.permute(1, 0)
      b.assign(r + b)
      c.assign(r + b_perm)
      Tensor.realize(b, c)

  @unittest.skip("multi output not supported anymore")
  def test_permuted_reduceop_multioutput_dual_use_possible(self):
    # the .contiguous() on b_perm makes the dual use schedulable in 2 kernels
    a = Tensor.arange(32*32*32).reshape(32, 32, 32).contiguous().realize()
    b = Tensor.arange(32 * 32).reshape(32, 32).realize()
    c = Tensor.arange(32 * 32).reshape(32, 32).realize()

    GlobalCounters.reset()
    r = a.sum(axis=1)
    b_perm = b.permute(1, 0)
    b.assign(r + b)
    c.assign(r + b_perm.contiguous())
    Tensor.realize(b, c)
    self.assertEqual(GlobalCounters.kernel_count, 2)
    np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32))
    np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0))

  def test_permuted_assignment_masked_view_possible(self):
    # a masked (shrink+pad) self-read keeps element positions, so a single kernel suffices
    a = Tensor.ones(4, 4).contiguous().realize()
    b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2)
    a.assign(a + b)
    GlobalCounters.reset()
    a.realize()
    self.assertEqual(GlobalCounters.kernel_count, 1)
    np.testing.assert_equal(a.numpy(), np.ones((4, 4))+np.pad(np.ones((4, 4))[:, 0:2], ((0, 0), (0, 2)), constant_values=2))

  def test_permuted_assignment_masked_view_not_contiguous(self):
    # adding a permute to the masked view moves elements, so old values must be read first
    a = Tensor.ones(4, 4).contiguous().realize()
    b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2).permute(1, 0)
    a.assign(a + b)
    a.realize()
    self.assertListEqual(a.tolist(), [[2.,2.,2.,2.],[2.,2.,2.,2.],[3.,3.,3.,3.], [3.,3.,3.,3.]])

  # TODO: is there a way to sneak in a permute such that it returns the wrong answer?
  def test_overlapping_shrink_assignment_forward(self):
    # Forward shift: read index > write index in overlap
    N = 100000
    shift = 1000
    a = Tensor.arange(N).float().contiguous().realize()
    expected = np.arange(N, dtype=np.float32)
    expected[:N-shift] = expected[shift:].copy()
    # NOOPT keeps the kernel unoptimized; overlapping self-assign must still be correct
    with Context(NOOPT=1): a[0:N-shift].assign(a[shift:N]).realize()
    np.testing.assert_allclose(a.numpy(), expected)

  def test_overlapping_shrink_assignment_reverse(self):
    # Reverse shift: write index > read index in overlap
    N = 100000
    shift = 1000
    a = Tensor.arange(N).float().contiguous().realize()
    expected = np.arange(N, dtype=np.float32)
    expected[shift:] = expected[:N-shift].copy()
    with Context(NOOPT=1): a[shift:N].assign(a[0:N-shift]).realize()
    np.testing.assert_allclose(a.numpy(), expected)

  def test_nonoverlapping_shrink_assignment(self):
    # TODO: non-overlapping shrinks don't actually need contiguous, could be 1 kernel with smarter range analysis
    a = Tensor.arange(100).float().contiguous().realize()
    expected = np.arange(100, dtype=np.float32)
    expected[0:10] = expected[50:60].copy()
    GlobalCounters.reset()
    a[0:10].assign(a[50:60]).realize()
    self.assertEqual(GlobalCounters.kernel_count, 2) # currently conservative, forces contiguous
    np.testing.assert_allclose(a.numpy(), expected)

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_setitem_half(self):
    # slice-assign with half dtype leaves the untouched tail intact
    a = Tensor.full((8,), 1.0, dtype=dtypes.half).contiguous().realize()
    b = Tensor.full((4,), 2.0, dtype=dtypes.half).contiguous().realize()
    assign = a[:4].assign(b)
    assign.realize()
    np.testing.assert_allclose(a.numpy(), [2., 2., 2., 2., 1., 1., 1., 1.])

  def test_setitem_list(self):
    # __setitem__ from a plain python list
    a = Tensor.zeros(8).contiguous().realize()
    a[2:5] = [1, 2, 3]
    np.testing.assert_allclose(a.numpy(), [0., 0., 1., 2., 3., 0., 0., 0.])
  def test_assign_bitcast(self):
    # assign to a bitcast view should modify the underlying buffer
    a = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
    # IEEE 754: 1.0f = 0x3f800000, 2.0f = 0x40000000, 3.0f = 0x40400000, 4.0f = 0x40800000
    a.bitcast(dtypes.uint32).assign(Tensor([0x40800000, 0x40400000, 0x40000000, 0x3f800000], dtype=dtypes.uint32)).realize()
    np.testing.assert_allclose(a.numpy(), [4.0, 3.0, 2.0, 1.0])
    # double bitcast
    b = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
    b.bitcast(dtypes.uint32).bitcast(dtypes.int32).assign(Tensor([0x40800000, 0x40400000, 0x40000000, 0x3f800000], dtype=dtypes.int32)).realize()
    np.testing.assert_allclose(b.numpy(), [4.0, 3.0, 2.0, 1.0])
    # shrink then bitcast
    c = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
    c[0:2].bitcast(dtypes.uint32).assign(Tensor([0x40800000, 0x40400000], dtype=dtypes.uint32)).realize()
    np.testing.assert_allclose(c.numpy(), [4.0, 3.0, 3.0, 4.0])

  def test_assign_bitcast_different_size(self):
    # assign to a shape-changing bitcast view (only works on DISK currently)
    a = Tensor([0]*8, dtype=dtypes.uint8).realize()
    a.bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int64)).realize()
    try:
      # 12345 = 0x3039, little-endian byte layout -> [57, 48, 0, ...]
      np.testing.assert_equal(a.numpy(), [57, 48, 0, 0, 0, 0, 0, 0])
    except AssertionError:
      # TODO: broken now
      np.testing.assert_equal(a.numpy(), [0]*8)

  @unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
  def test_cast_assignment(self):
    # assigning a cast (different-dtype) value used to avoid the output buffer entirely
    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    oba1 = a.uop.base.output_buffer
    a.assign(a.cast(dtypes.int32).realize())
    a.realize()
    oba2 = a.uop.base.output_buffer
    assert oba1 is None and oba2 is None
    np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
  def test_assign_dtype_mismatch(self):
    # assign should not implicitly cast dtypes - this can lose precision
    a = Tensor.zeros(4, dtype=dtypes.float32).contiguous().realize()
    b = Tensor([1, 2, 3, 4], dtype=dtypes.int32)
    with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"):
      a.assign(b)

  def test_assign_dtype_mismatch_int64_to_float32(self):
    # int64 -> float32 loses precision for large values, should not be implicit
    a = Tensor.zeros(1, dtype=dtypes.float32).contiguous().realize()
    b = Tensor([16777217], dtype=dtypes.int64) # 2^24 + 1, not exactly representable in float32
    with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"):
      a.assign(b)

  def test_assign_shape_broadcast(self):
    # shape broadcasting should work when dtypes match
    a = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
    b = Tensor([1., 2., 3., 4., 5.], dtype=dtypes.float32)
    a.assign(b)
    a.realize()
    expected = np.array([[1., 2., 3., 4., 5.]] * 3)
    np.testing.assert_allclose(a.numpy(), expected)

  def test_assign_shape_broadcast_2d(self):
    # broadcast (1, 5) to (3, 5)
    a = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
    b = Tensor([[1., 2., 3., 4., 5.]], dtype=dtypes.float32)
    a.assign(b)
    a.realize()
    expected = np.array([[1., 2., 3., 4., 5.]] * 3)
    np.testing.assert_allclose(a.numpy(), expected)

  def test_disk_assignment(self):
    # assign works on the disk device backed by a temp file
    a = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}").assign(Tensor.ones(5)).numpy()
    np.testing.assert_equal(a, np.ones(5))
  def test_assign_slice_then_read(self):
    """Assign to slice then read from buffer - read should see the assigned values.

    This is the KV cache pattern from llm.py.
    """
    v_pos = Variable("pos", 0, 3).bind(0)
    cache = Tensor.zeros(4, 4).contiguous().realize()
    cache[v_pos:v_pos+1, :].assign(Tensor.ones(1, 4))
    self.assertEqual(cache.sum().item(), 4.0)

  def test_chained_assign_slice_then_read(self):
    """Three caches with chained assign-then-read: each block writes to its cache and reads back,
    feeding the result to the next block's assign. Without proper dependency tracking, block N's read
    may see stale data from block N-1's cache (pre-assign zeros instead of the assigned values).

    This is the multi-layer KV cache pattern from llm.py._attention.
    """
    D, max_ctx = 4, 8
    cache1 = Tensor.zeros(max_ctx, D).contiguous().realize()
    cache2 = Tensor.zeros(max_ctx, D).contiguous().realize()
    cache3 = Tensor.zeros(max_ctx, D).contiguous().realize()
    cache1[:3].assign(Tensor.ones(3, D)).realize()
    cache2[:3].assign(Tensor.ones(3, D) * 2).realize()
    cache3[:3].assign(Tensor.ones(3, D) * 3).realize()
    # block 1: assign [10]*D at position 3, read sum -> c1=[13]*D
    cache1[3:4].assign(Tensor.ones(1, D) * 10)
    c1 = cache1[:4].sum(0, keepdim=True)
    # block 2: assign c1 at position 3, read sum -> c2=[19]*D
    cache2[3:4].assign(c1)
    c2 = cache2[:4].sum(0, keepdim=True)
    # block 3: assign c2 at position 3, read sum -> 112
    cache3[3:4].assign(c2)
    self.assertEqual(cache3[:4].sum().item(), 112.0)

  def test_chained_assign_kernel_count(self):
    """Chained pending assigns must not produce excessive kernels (tests recursive transitive processing)."""
    D, N = 4, 5
    caches = [Tensor.zeros(8, D).contiguous().realize() for _ in range(N)]
    caches[0][0:1].assign(Tensor.ones(1, D) * 10)
    x = caches[0][:1].sum(0, keepdim=True)
    for i in range(1, N):
      caches[i][0:1].assign(x)
      x = caches[i][:1].sum(0, keepdim=True)
    GlobalCounters.reset()
    x.realize()
    # N assigns (1 kernel each) producing N kernels total
    self.assertEqual(GlobalCounters.kernel_count, N)

  def test_shared_computation_assign_kernel_count(self):
    """When a .contiguous() is shared between an assign value and the next layer's input (like QKV projection in LLM),
    substitute optimization replaces already-realized sub-graphs in remaining pending assigns, preventing kernel escalation.
    Without substitute, pending assign graphs grow linearly and produce 153 kernels instead of 48."""
    D, N = 16, 16
    caches = [Tensor.zeros(4, D).contiguous().realize() for _ in range(N)]
    W = [Tensor.full((D, D*2), 0.01).contiguous().realize() for _ in range(N)]
    x = Tensor.ones(1, D).contiguous().realize()
    for i in range(N):
      shared = (x @ W[i]).contiguous() # .contiguous() UOp is shared between assign (k) and next layer (q)
      k, q = shared[:, :D], shared[:, D:]
      caches[i][0:1].assign(k) # assign references the CONTIGUOUS
      x = q + caches[i][:1] # next layer also references the same CONTIGUOUS through q
    GlobalCounters.reset()
    caches[-1][:1].contiguous().realize()
    # N matmuls + N assigns + 1 final read = 2*N+1 (AFTER embedding allows full graph scheduling with shared contiguous reuse)
    self.assertEqual(GlobalCounters.kernel_count, 2*N+1)
  def test_double_assign_from_const(self):
    # two identical assigns to the same unrealized buffer collapse into one kernel
    a = Tensor.empty(2)
    a.assign(Tensor.ones(2))
    a.assign(Tensor.ones(2))
    GlobalCounters.reset()
    a.realize()
    self.assertEqual(GlobalCounters.kernel_count, 1)
    self.assertEqual(a.tolist(), [1.,1.])

  def test_nested_after_contiguous_store(self):
    # Mirrors the nested contiguous-write-then-assign-back shape from torch backend view updates.
    base = Tensor.empty(3, dtype=dtypes.int64)
    base.assign(Tensor([1, 2, 3], dtype=dtypes.int64))
    contig = base.contiguous()
    contig.assign(Tensor([1, 4, 3], dtype=dtypes.int64))
    GlobalCounters.reset()
    base.assign(contig).realize()
    self.assertEqual(GlobalCounters.kernel_count, 2) # TODO: first copy is dead, could be 1
    self.assertEqual(base.tolist(), [1,4,3])

  def test_nested_after_contiguous_store_no_init(self):
    # Same shape as test_nested_after_contiguous_store, but without the initial assign.
    base = Tensor.empty(3, dtype=dtypes.int64)
    contig = base.contiguous()
    contig.assign(Tensor([1, 4, 3], dtype=dtypes.int64))
    GlobalCounters.reset()
    base.assign(contig).realize()
    self.assertEqual(GlobalCounters.kernel_count, 1)
    self.assertEqual(base.tolist(), [1,4,3])
class TestAssignOrdering(unittest.TestCase):
  """Tests for complex assign orderings that could differ between lazy and eager execution.

  The key principle: tinygrad's lazy execution with RAW/WAR dependency tracking should
  produce the same results as eager (immediate) execution for valid programs.

  These tests exercise edge cases where incorrect dependency tracking could cause:
  - Stale reads (reading before write completes)
  - Lost writes (write ordering reversed)
  - Race conditions (concurrent access to same buffer)
  """

  def test_overlapping_slice_assigns(self):
    """Overlapping slice assigns - later write should win for overlapping elements."""
    buf = Tensor.zeros(8).contiguous().realize()
    buf[0:4].assign(Tensor.ones(4))
    buf[2:6].assign(Tensor.ones(4) * 2)
    np.testing.assert_equal(buf.numpy(), [1,1,2,2,2,2,0,0])

  def test_overlapping_slice_assigns_reverse(self):
    """Overlapping slice assigns in reverse order."""
    buf = Tensor.zeros(8).contiguous().realize()
    buf[2:6].assign(Tensor.ones(4) * 2)
    buf[0:4].assign(Tensor.ones(4))
    np.testing.assert_equal(buf.numpy(), [1,1,1,1,2,2,0,0])

  def test_read_between_writes(self):
    """Read should see first write before second write happens."""
    buf = Tensor.zeros(4).contiguous().realize()
    buf.assign(Tensor.ones(4))
    r1 = buf.sum().realize() # should see ones = 4
    buf.assign(Tensor.ones(4) * 2)
    r2 = buf.sum().realize() # should see twos = 8
    self.assertEqual(r1.item(), 4)
    self.assertEqual(r2.item(), 8)

  @unittest.skip("TODO: this is broken")
  def test_write_read_write_chain(self):
    """Write, read, write chain - middle read must complete before second write."""
    buf = Tensor.zeros(4).contiguous().realize()
    buf.assign(Tensor.ones(4) * 3)
    mid_sum = buf.sum() # lazy read, should be 12
    buf.assign(Tensor.ones(4) * 5)
    final_sum = buf.sum() # lazy read, should be 20
    # Realize in "wrong" order - final first
    self.assertEqual(final_sum.realize().item(), 20)
    self.assertEqual(mid_sum.realize().item(), 12)
  def test_slice_read_then_full_write(self):
    """Read from slice, then overwrite full buffer - WAR dependency works for full buffer assigns."""
    buf = Tensor([1.,2.,3.,4.]).contiguous().realize()
    partial = buf[0:2].sum() # lazy read
    buf.assign(Tensor.ones(4) * 10) # overwrite everything
    full = buf.sum()
    # WAR dependency correctly tracked - partial sees original data
    self.assertEqual(partial.realize().item(), 3) # 1+2
    self.assertEqual(full.realize().item(), 40)

  def test_slice_write_then_full_read(self):
    """Write to slice, then read full buffer."""
    buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
    buf[1:3].assign(Tensor([5, 6]))
    np.testing.assert_equal(buf.numpy(), [0, 5, 6, 0])

  def test_chained_slice_copies(self):
    """Copy from one slice to another within same buffer."""
    buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
    buf[4:8].assign(buf[0:4].contiguous())
    np.testing.assert_equal(buf.numpy(), [1, 2, 3, 4, 1, 2, 3, 4])

  def test_swap_slices(self):
    """Swap two non-overlapping slices - requires reading both before writing."""
    # without .realize() on temps: values not captured before overwriting
    buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
    left = buf[0:4].clone() # lazy - not captured yet
    right = buf[4:8].clone() # lazy - not captured yet
    buf[0:4].assign(right).realize() # this works
    buf[4:8].assign(left).realize() # left now reads from modified buf!
    try:
      np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 1, 2, 3, 4])
    except AssertionError:
      # TODO: broken now
      np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 5, 6, 7, 8])

    # with .realize() on temps: values captured before writes
    buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
    left = buf[0:4].clone().realize()
    right = buf[4:8].clone().realize()
    buf[0:4].assign(right).realize()
    buf[4:8].assign(left).realize()
    np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 1, 2, 3, 4])
  def test_reduction_after_partial_assign(self):
    """Reduction over buffer after partial assign - must see the assigned values."""
    buf = Tensor.zeros(4, 4).contiguous().realize()
    buf[0:2, :].assign(Tensor.ones(2, 4)) # top half = 1
    total = buf.sum()
    self.assertEqual(total.item(), 8)

  def test_multiple_reductions_different_views(self):
    """Multiple reductions over different views of same buffer after assign."""
    buf = Tensor.zeros(4, 4).contiguous().realize()
    buf.assign(Tensor.arange(16).reshape(4, 4).float())
    row_sums = buf.sum(axis=1) # [6, 22, 38, 54]
    col_sums = buf.sum(axis=0) # [24, 28, 32, 36]
    total = buf.sum() # 120
    # All should see the assigned values
    np.testing.assert_equal(row_sums.numpy(), [6, 22, 38, 54])
    np.testing.assert_equal(col_sums.numpy(), [24, 28, 32, 36])
    self.assertEqual(total.item(), 120)

  def test_assign_from_self_transformed(self):
    """Assign to buffer from transformed view of itself."""
    buf = Tensor([1, 2, 3, 4]).contiguous().realize()
    # Read and transform, then write back (requires reading before writing)
    buf.assign((buf * 2).contiguous())
    np.testing.assert_equal(buf.numpy(), [2, 4, 6, 8])
def test_two_buffers_cross_assign(self):
|
|
"""Two buffers each reading from the other before writing."""
|
|
a = Tensor([1, 2, 3, 4]).contiguous().realize()
|
|
b = Tensor([10, 20, 30, 40]).contiguous().realize()
|
|
# Both read from each other's original values
|
|
a_new = (a + b).contiguous()
|
|
b_new = (a * b).contiguous()
|
|
a.assign(a_new)
|
|
b.assign(b_new)
|
|
Tensor.realize(a, b)
|
|
np.testing.assert_equal(a.numpy(), [11, 22, 33, 44])
|
|
np.testing.assert_equal(b.numpy(), [10, 40, 90, 160])
|
|
|
|
def test_three_buffer_chain(self):
|
|
"""Chain: A depends on B, B depends on C - ordering matters."""
|
|
a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
|
b = Tensor([1, 2, 3, 4]).contiguous().realize()
|
|
c = Tensor([10, 10, 10, 10]).contiguous().realize()
|
|
# b reads from c, a reads from b
|
|
b.assign((b + c).contiguous()) # b = [11, 12, 13, 14]
|
|
a.assign((a + b).contiguous()) # a should see new b = [11, 12, 13, 14]
|
|
Tensor.realize(a, b)
|
|
np.testing.assert_equal(b.numpy(), [11, 12, 13, 14])
|
|
np.testing.assert_equal(a.numpy(), [11, 12, 13, 14])
|
|
|
|
def test_interleaved_assign_read_patterns(self):
|
|
"""Complex interleaved pattern: write A, read A into B, write B, read B."""
|
|
a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
|
b = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
|
|
|
a.assign(Tensor([1, 2, 3, 4]))
|
|
b.assign(a.contiguous()) # b should get [1,2,3,4]
|
|
a.assign(Tensor([5, 6, 7, 8]))
|
|
result = b.sum() # should be 10, not 26
|
|
|
|
self.assertEqual(result.item(), 10)
|
|
np.testing.assert_equal(a.numpy(), [5, 6, 7, 8])
|
|
np.testing.assert_equal(b.numpy(), [1, 2, 3, 4])
|
|
|
|
  def test_variable_slice_ordering(self):
    """Variable-indexed slices - conflicting variable binds in same schedule are rejected."""
    v_i = Variable("i", 0, 3)
    buf = Tensor.zeros(4, 4).contiguous().realize()
    # the SAME Variable is bound to two different values (0 and 1) below; both
    # assigns end up in one schedule, so the variable bindings conflict
    buf[v_i.bind(0):v_i.bind(0)+1, :].assign(Tensor.ones(1, 4))
    buf[v_i.bind(1):v_i.bind(1)+1, :].assign(Tensor.ones(1, 4) * 2)
    # realizing the read must raise rather than silently pick one binding
    with self.assertRaises(RuntimeError): buf[0:1, :].sum().item()
|
|
|
|
def test_multi_step_assign_read_write_same_buffer(self):
|
|
"""Assign to m and param reading b, then update b, across multiple steps.
|
|
This is the optimizer bias-correction pattern from issue #13600: m accumulates,
|
|
param is updated using m/(1-b), and b is updated via *= after the reads."""
|
|
b = Tensor([0.5]).contiguous().realize()
|
|
m = Tensor([0.0]).contiguous().realize()
|
|
param = Tensor([1.0]).contiguous().realize()
|
|
for _ in range(10):
|
|
m.assign(0.9 * m + 0.1)
|
|
param.assign(param - m / (1 - b))
|
|
b *= 0.9
|
|
Tensor.realize(param, m, b)
|
|
# numpy reference
|
|
b_np, m_np, p_np = 0.5, 0.0, 1.0
|
|
for _ in range(10):
|
|
m_np = 0.9 * m_np + 0.1
|
|
p_np = p_np - m_np / (1 - b_np)
|
|
b_np *= 0.9
|
|
np.testing.assert_allclose(param.item(), p_np, atol=1e-5)
|
|
|
|
def test_multiple_slice_assigns_then_read(self):
|
|
"""Multiple non-overlapping slice assigns then read."""
|
|
buf = Tensor.zeros(4).contiguous().realize()
|
|
buf[0:1].assign(Tensor.ones(1))
|
|
buf[1:2].assign(Tensor.full((1,), 2.0))
|
|
buf[2:3].assign(Tensor.full((1,), 3.0))
|
|
self.assertEqual(buf.sum().realize().item(), 6.0)
|
|
|
|
# TODO: assigns into views of unrealized non-BUFFER bases are silently dropped
class TestAssignToUnrealizedView(unittest.TestCase):
  """Assigning into a view whose base is an unrealized non-BUFFER op.

  Every test follows the same shape: build an unrealized base (COPY, CONTIGUOUS,
  ALU, REDUCE, CAST, CONST, ...), assign into a slice of it, then read it back.
  The `try` branch asserts the intended result; the `except AssertionError`
  branch documents the current broken behavior, where the assign is silently
  dropped (see the TODO above this class)."""
  def test_copy(self):
    t = Tensor.zeros(2,2, dtype=dtypes.int).to("CPU:0").contiguous().realize()
    c = t.to("CPU:1") # unrealized COPY
    self.assertIs(c.uop.base.op, Ops.COPY)
    c[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).to("CPU:1").contiguous().realize())
    try:
      # intended: second column shows the assigned ones
      self.assertEqual(c.tolist(), [[0,1],[0,1]])
    except AssertionError:
      # TODO: broken now
      self.assertEqual(c.tolist(), [[0,0],[0,0]])

  def test_contiguous(self):
    t = Tensor([[1,2],[3,4]]).contiguous().realize()
    c = t.permute(1,0).contiguous() # unrealized CONTIGUOUS
    self.assertIs(c.uop.base.op, Ops.CONTIGUOUS)
    c[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).contiguous().realize())
    try:
      # intended: transposed values with the second column overwritten
      self.assertEqual(c.tolist(), [[1,1],[2,1]])
    except AssertionError:
      # TODO: broken now
      self.assertEqual(c.tolist(), [[1,3],[2,4]])

  def test_contiguous_backward(self):
    t = Tensor([[1,2],[3,4]]).contiguous().realize()
    cb = t.contiguous_backward() # unrealized CONTIGUOUS_BACKWARD
    self.assertIs(cb.uop.base.op, Ops.CONTIGUOUS_BACKWARD)
    cb[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).contiguous().realize())
    try:
      # intended: second column overwritten with ones
      self.assertEqual(cb.tolist(), [[1,1],[3,1]])
    except AssertionError:
      # TODO: broken now
      self.assertEqual(cb.tolist(), [[1,2],[3,4]])

  def test_detach_copy(self):
    t = Tensor.zeros(2,2, dtype=dtypes.int).to("CPU:0").contiguous().realize()
    d = t.to("CPU:1").detach() # DETACH(unrealized COPY)
    # note: .base resolves through the DETACH to the underlying COPY
    self.assertIs(d.uop.base.op, Ops.COPY)
    d[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).to("CPU:1").contiguous().realize())
    try:
      # intended: second column shows the assigned ones
      self.assertEqual(d.tolist(), [[0,1],[0,1]])
    except AssertionError:
      # TODO: broken now
      self.assertEqual(d.tolist(), [[0,0],[0,0]])

  def test_detach_contiguous(self):
    t = Tensor([[1,2],[3,4]]).contiguous().realize()
    d = t.permute(1,0).contiguous().detach() # DETACH(unrealized CONTIGUOUS)
    self.assertIs(d.uop.base.op, Ops.CONTIGUOUS)
    d[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).contiguous().realize())
    try:
      # intended: transposed values with the second column overwritten
      self.assertEqual(d.tolist(), [[1,1],[2,1]])
    except AssertionError:
      # TODO: broken now
      self.assertEqual(d.tolist(), [[1,3],[2,4]])

  def test_alu(self):
    a = Tensor([1,2,3,4]).contiguous().realize()
    b = Tensor([5,6,7,8]).contiguous().realize()
    c = a + b # unrealized ADD
    self.assertIs(c.uop.base.op, Ops.ADD)
    c[:2].assign(Tensor([99, 99]).realize())
    try:
      # intended: first two elements replaced with 99
      self.assertEqual(c.tolist(), [99,99,10,12])
    except AssertionError:
      # TODO: broken now, silently dropped
      self.assertEqual(c.tolist(), [6,8,10,12])

  def test_reduce(self):
    a = Tensor([[1,2],[3,4]]).contiguous().realize()
    r = a.sum(axis=0) # unrealized REDUCE_AXIS
    self.assertIs(r.uop.base.op, Ops.REDUCE_AXIS)
    r[:1].assign(Tensor([99]).realize())
    try:
      # intended: first column sum replaced with 99
      self.assertEqual(r.tolist(), [99,6])
    except AssertionError:
      # TODO: broken now, silently dropped
      self.assertEqual(r.tolist(), [4,6])

  def test_cast(self):
    a = Tensor([1,2,3,4]).contiguous().realize()
    c = a.float() # unrealized CAST
    self.assertIs(c.uop.base.op, Ops.CAST)
    c[:2].assign(Tensor([99, 99], dtype=dtypes.float).realize())
    try:
      # intended: first two elements replaced with 99
      self.assertEqual(c.tolist(), [99,99,3,4])
    except AssertionError:
      # TODO: broken now, silently dropped
      self.assertEqual(c.tolist(), [1,2,3,4])

  def test_const(self):
    c = Tensor(5).reshape(1, 1).expand(2, 2)
    self.assertIs(c.uop.base.op, Ops.CONST)
    c[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).contiguous().realize())
    try:
      # intended: second column of the expanded const overwritten with ones
      self.assertEqual(c.tolist(), [[5,1],[5,1]])
    except AssertionError:
      # TODO: broken now, silently dropped
      self.assertEqual(c.tolist(), [[5,5],[5,5]])
|
|
|
|
class TestPartialAssignToSharedBuffer(unittest.TestCase):
  """In-place updates to many disjoint reshaped slices of one backing buffer."""
  def test_five_slices(self):
    """Five equal-size (2,5) slices of a 50-element buffer, each incremented."""
    backing = Tensor.zeros(50).contiguous().realize()
    slices = [backing[k*10:(k+1)*10].reshape(2, 5) for k in range(5)]
    for s in slices:
      s.assign(s + 1)
    Tensor.realize(*slices)
    for s in slices:
      np.testing.assert_allclose(s.numpy(), np.ones((2, 5)))

  def test_many_slices(self):
    """Ten (3,4) slices of one buffer, mimicking flattened parameter grads."""
    count = 10
    backing = Tensor.zeros(count * 12).contiguous().realize()
    params = [backing[k*12:(k+1)*12].reshape(3, 4) for k in range(count)]
    for p in params:
      p.assign(p + 1)
    Tensor.realize(*params)
    for p in params:
      np.testing.assert_allclose(p.numpy(), np.ones((3, 4)))

  def test_mixed_shapes(self):
    """Slices of varying shapes packed back-to-back into one 100-element buffer."""
    backing = Tensor.zeros(100).contiguous().realize()
    shapes = [(3, 4), (4, 6), (6, 4), (2, 5), (4, 3)]
    # compute the start offset of each slice, then carve out the views
    offsets, off = [], 0
    for rows, cols in shapes:
      offsets.append(off)
      off += rows * cols
    views = [backing[o:o+r*c].reshape(r, c) for o, (r, c) in zip(offsets, shapes)]
    for v in views:
      v.assign(v + 1)
    Tensor.realize(*views)
    for v, shp in zip(views, shapes):
      np.testing.assert_allclose(v.numpy(), np.ones(shp))
|
|
|
|
|
|
class TestAfterCachePatterns(unittest.TestCase):
  """Patterns built from the low-level UOp.store / UOp.after API: realizing a
  tensor sequenced `.after(...)` extra stores must execute those stores too."""
  def test_double_store_after(self):
    a = Tensor.zeros(10).contiguous()
    b = Tensor.zeros(10).contiguous()
    c = Tensor.ones(10).contiguous()
    Tensor.realize(a, b, c)

    # raw stores of c's data into a's and b's buffers (not yet scheduled)
    a_store = a.uop.store(c.uop)
    b_store = b.uop.store(c.uop)

    # rebuild a to come after BOTH stores; realizing a must run b's store too
    a = Tensor(a.uop.after(a_store, b_store))
    a.realize()
    np.testing.assert_array_equal(a.numpy(), 1)
    np.testing.assert_array_equal(b.numpy(), 1)

  def test_double_store_after_different_sizes(self):
    # same pattern, but the two stored buffers have different shapes
    full = Tensor.zeros(2).contiguous()
    head = Tensor.zeros(1).contiguous()
    full_src = Tensor([1, 2], dtype=dtypes.float).contiguous()
    head_src = Tensor([3], dtype=dtypes.float).contiguous()
    Tensor.realize(full, head, full_src, head_src)

    full_store = full.uop.store(full_src.uop)
    head_store = head.uop.store(head_src.uop)

    # realizing head (sequenced after both stores) must also fill `full`
    head = Tensor(head.uop.after(head_store, full_store))
    head.realize()
    np.testing.assert_array_equal(head.numpy(), [3])
    np.testing.assert_array_equal(full.numpy(), [1, 2])
|
|
|
|
# allow running this test file directly: `python test_assign.py`
if __name__ == "__main__":
  unittest.main()
|