add unique to const, fix longstanding bug (#12965)

* add unique to const, fix longstanding bug

* _force_unique=True

* fix tests

* fix more tests
This commit is contained in:
George Hotz
2025-10-28 15:11:37 +08:00
committed by GitHub
parent e110f4632a
commit b0da173f2f
6 changed files with 54 additions and 13 deletions

View File

@@ -919,5 +919,38 @@ class TestIdxUpcast(unittest.TestCase):
a = Tensor.empty(2**11, 2**11, 1, dtype=dtypes.int8).permute((2, 0, 1)).expand((2**9+10, -1, -1)).contiguous()
a.realize()
class TestTensorUnique(unittest.TestCase):
def test_empty_bufs_unique(self):
a = Tensor.empty(10, 10).contiguous()
b = Tensor.empty(10, 10).contiguous()
Tensor.realize(a,b)
self.assertIsNot(a.uop.buffer, b.uop.buffer)
def test_zeros_bufs_unique_sep(self):
a = Tensor.zeros(10, 10).contiguous()
Tensor.realize(a)
b = Tensor.zeros(10, 10).contiguous()
Tensor.realize(b)
self.assertIsNot(a.uop.buffer, b.uop.buffer)
def test_zeros_bufs_unique(self):
a = Tensor.zeros(10, 10).contiguous()
b = Tensor.zeros(10, 10).contiguous()
Tensor.realize(a,b)
self.assertIsNot(a.uop.buffer, b.uop.buffer)
def test_eye_bufs_unique(self):
a = Tensor.eye(10).contiguous()
b = Tensor.eye(10).contiguous()
Tensor.realize(a,b)
self.assertIsNot(a.uop.buffer, b.uop.buffer)
def test_times_2_not_unique(self):
a = Tensor.zeros(10, 10).contiguous()
b = a * 2
c = a * 2
Tensor.realize(b,c)
self.assertIs(b.uop.buffer, c.uop.buffer)
if __name__ == '__main__':
unittest.main()