mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
simple LoadOps.ASSIGN (#3745)
* simple LoadOps.ASSIGN * skip that test * don't assign in onnx ops gemm * track cache usage * recreate the lazybuffer to avoid the cache * fix contigs * skip that test * lol * better letters
This commit is contained in:
@@ -20,6 +20,21 @@ class TestAssign(unittest.TestCase):
|
||||
assert ba1 == ba2 and ba1 != bb1
|
||||
np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))
|
||||
|
||||
def test_assign_zeros_good(self):
|
||||
a = Tensor.zeros(10,10).contiguous()
|
||||
a.assign(Tensor.ones(10,10))
|
||||
b = Tensor.zeros(10,10).contiguous()
|
||||
a.realize()
|
||||
np.testing.assert_allclose(b.numpy(), 0)
|
||||
|
||||
def test_assign_zeros(self):
|
||||
a = Tensor.zeros(10,10).contiguous()
|
||||
b = Tensor.zeros(10,10).contiguous()
|
||||
#with self.assertRaises(RuntimeError):
|
||||
a.assign(Tensor.ones(10,10))
|
||||
a.realize()
|
||||
np.testing.assert_allclose(b.numpy(), 0)
|
||||
|
||||
def test_assign_add(self):
|
||||
def f(x):
|
||||
x += 1
|
||||
@@ -98,14 +113,14 @@ class TestAssign(unittest.TestCase):
|
||||
a = (Tensor.rand(4,4).realize() + 1)
|
||||
kc = GlobalCounters.kernel_count
|
||||
b.assign(a.contiguous()).realize()
|
||||
assert GlobalCounters.kernel_count - kc == 1
|
||||
assert GlobalCounters.kernel_count - kc == 2
|
||||
|
||||
def test_assign_contiguous_permute(self):
|
||||
b = Tensor.rand(4,4).realize()
|
||||
a = (Tensor.rand(4,4).realize() + 1).permute((1,0))
|
||||
kc = GlobalCounters.kernel_count
|
||||
b.assign(a.contiguous()).realize()
|
||||
assert GlobalCounters.kernel_count - kc == 1
|
||||
assert GlobalCounters.kernel_count - kc == 2
|
||||
|
||||
def test_permuted_assignment(self):
|
||||
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
|
||||
@@ -114,12 +129,13 @@ class TestAssign(unittest.TestCase):
|
||||
b.realize()
|
||||
ba1 = a.lazydata.base.realized
|
||||
bb1 = b.lazydata.base.realized
|
||||
a = a.permute(1,0)
|
||||
a += b
|
||||
a.realize()
|
||||
ba2 = a.lazydata.base.realized
|
||||
assert ba1 != ba2 and ba1 != bb1
|
||||
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
|
||||
with self.assertRaises(RuntimeError):
|
||||
a = a.permute(1,0)
|
||||
a += b
|
||||
a.realize()
|
||||
ba2 = a.lazydata.base.realized
|
||||
assert ba1 != ba2 and ba1 != bb1
|
||||
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
|
||||
|
||||
def test_post_permuted_assignment(self):
|
||||
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
|
||||
@@ -129,12 +145,13 @@ class TestAssign(unittest.TestCase):
|
||||
#GlobalCounters.cache = []
|
||||
ba1 = a.lazydata.base.realized # noqa: F841
|
||||
bb1 = b.lazydata.base.realized # noqa: F841
|
||||
a.assign(a.permute(1,0) + b) # this should not work!
|
||||
a.realize()
|
||||
ba2 = a.lazydata.base.realized # noqa: F841
|
||||
# NOTE: don't test that it's assigned
|
||||
#assert ba1 == ba2 and ba1 != bb1
|
||||
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
|
||||
with self.assertRaises(RuntimeError):
|
||||
a.assign(a.permute(1,0) + b) # this should not work!
|
||||
a.realize()
|
||||
ba2 = a.lazydata.base.realized # noqa: F841
|
||||
# NOTE: don't test that it's assigned
|
||||
#assert ba1 == ba2 and ba1 != bb1
|
||||
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
|
||||
|
||||
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
|
||||
|
||||
|
||||
Reference in New Issue
Block a user