simple LoadOps.ASSIGN (#3745)

* simple LoadOps.ASSIGN

* skip that test

* don't assign in onnx ops gemm

* track cache usage

* recreate the lazybuffer to avoid the cache

* fix contigs

* skip that test

* lol

* better letters
This commit is contained in:
George Hotz
2024-03-14 20:44:34 -07:00
committed by GitHub
parent 75d4344cda
commit 641f347232
12 changed files with 73 additions and 49 deletions

View File

@@ -20,6 +20,21 @@ class TestAssign(unittest.TestCase):
assert ba1 == ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))
def test_assign_zeros_good(self):
a = Tensor.zeros(10,10).contiguous()
a.assign(Tensor.ones(10,10))
b = Tensor.zeros(10,10).contiguous()
a.realize()
np.testing.assert_allclose(b.numpy(), 0)
def test_assign_zeros(self):
a = Tensor.zeros(10,10).contiguous()
b = Tensor.zeros(10,10).contiguous()
#with self.assertRaises(RuntimeError):
a.assign(Tensor.ones(10,10))
a.realize()
np.testing.assert_allclose(b.numpy(), 0)
def test_assign_add(self):
def f(x):
x += 1
@@ -98,14 +113,14 @@ class TestAssign(unittest.TestCase):
a = (Tensor.rand(4,4).realize() + 1)
kc = GlobalCounters.kernel_count
b.assign(a.contiguous()).realize()
assert GlobalCounters.kernel_count - kc == 1
assert GlobalCounters.kernel_count - kc == 2
def test_assign_contiguous_permute(self):
b = Tensor.rand(4,4).realize()
a = (Tensor.rand(4,4).realize() + 1).permute((1,0))
kc = GlobalCounters.kernel_count
b.assign(a.contiguous()).realize()
assert GlobalCounters.kernel_count - kc == 1
assert GlobalCounters.kernel_count - kc == 2
def test_permuted_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
@@ -114,12 +129,13 @@ class TestAssign(unittest.TestCase):
b.realize()
ba1 = a.lazydata.base.realized
bb1 = b.lazydata.base.realized
a = a.permute(1,0)
a += b
a.realize()
ba2 = a.lazydata.base.realized
assert ba1 != ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
with self.assertRaises(RuntimeError):
a = a.permute(1,0)
a += b
a.realize()
ba2 = a.lazydata.base.realized
assert ba1 != ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
def test_post_permuted_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
@@ -129,12 +145,13 @@ class TestAssign(unittest.TestCase):
#GlobalCounters.cache = []
ba1 = a.lazydata.base.realized # noqa: F841
bb1 = b.lazydata.base.realized # noqa: F841
a.assign(a.permute(1,0) + b) # this should not work!
a.realize()
ba2 = a.lazydata.base.realized # noqa: F841
# NOTE: don't test that it's assigned
#assert ba1 == ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
with self.assertRaises(RuntimeError):
a.assign(a.permute(1,0) + b) # this should not work!
a.realize()
ba2 = a.lazydata.base.realized # noqa: F841
# NOTE: don't test that it's assigned
#assert ba1 == ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?