add unique to const, fix longstanding bug (#12965)

* add unique to const, fix longstanding bug

* _force_unique=True

* fix tests

* fix more tests
This commit is contained in:
George Hotz
2025-10-28 15:11:37 +08:00
committed by GitHub
parent e110f4632a
commit b0da173f2f
6 changed files with 54 additions and 13 deletions

View File

@@ -112,7 +112,7 @@ class TestRealWorld(unittest.TestCase):
loss.backward()
optimizer.step()
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 102)
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 103)
@unittest.skipIf(CI and Device.DEFAULT in {"CPU", "CL"}, "slow")
def test_forward_cifar(self):
@@ -176,7 +176,7 @@ class TestRealWorld(unittest.TestCase):
for v in data.values(): v.to_(Device.DEFAULT)
helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.31, 358)
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.31, 427)
if __name__ == '__main__':
unittest.main()

View File

@@ -370,6 +370,7 @@ class TestSchedule(unittest.TestCase):
# NOTE: this is causing "LAZYCACHE=1 incorrectly reuses contiguous const" #4562
# should contiguous dedup?
@unittest.skip("we do the exact opposite now")
def test_dedup_contiguous(self):
a = Tensor.ones(4).contiguous()
b = Tensor.ones(4).contiguous()
@@ -446,7 +447,7 @@ class TestSchedule(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
def test_fold_conv_batchnorm_optim(self):
# this is too high
for optim, cnt in [(nn.optim.Adam, 21), (nn.optim.SGD, 8)]:
for optim, cnt in [(nn.optim.Adam, 28), (nn.optim.SGD, 8)]:
with self.subTest(optim=optim.__name__):
with Tensor.train():
img = Tensor.ones(1,3,4,4)
@@ -1220,7 +1221,7 @@ class TestSchedule(unittest.TestCase):
_realize_weights(layer)
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
layer(x).relu().sum().backward()
check_schedule(opt.schedule_step(), 16)
check_schedule(opt.schedule_step(), 19)
def test_adam_conv_fuse(self):
with Tensor.train():
@@ -1230,7 +1231,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
opt.zero_grad()
c1(img).relu().sum().backward()
check_schedule(opt.schedule_step(), 16)
check_schedule(opt.schedule_step(), 19)
def test_adam_2convs_fuse(self):
with Tensor.train():
@@ -1241,7 +1242,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 18)
check_schedule(opt.schedule_step(), 21)
def test_sgd_conv_fuse(self):
with Tensor.train():

View File

@@ -919,5 +919,38 @@ class TestIdxUpcast(unittest.TestCase):
a = Tensor.empty(2**11, 2**11, 1, dtype=dtypes.int8).permute((2, 0, 1)).expand((2**9+10, -1, -1)).contiguous()
a.realize()
class TestTensorUnique(unittest.TestCase):
def test_empty_bufs_unique(self):
a = Tensor.empty(10, 10).contiguous()
b = Tensor.empty(10, 10).contiguous()
Tensor.realize(a,b)
self.assertIsNot(a.uop.buffer, b.uop.buffer)
def test_zeros_bufs_unique_sep(self):
a = Tensor.zeros(10, 10).contiguous()
Tensor.realize(a)
b = Tensor.zeros(10, 10).contiguous()
Tensor.realize(b)
self.assertIsNot(a.uop.buffer, b.uop.buffer)
def test_zeros_bufs_unique(self):
a = Tensor.zeros(10, 10).contiguous()
b = Tensor.zeros(10, 10).contiguous()
Tensor.realize(a,b)
self.assertIsNot(a.uop.buffer, b.uop.buffer)
def test_eye_bufs_unique(self):
a = Tensor.eye(10).contiguous()
b = Tensor.eye(10).contiguous()
Tensor.realize(a,b)
self.assertIsNot(a.uop.buffer, b.uop.buffer)
def test_times_2_not_unique(self):
a = Tensor.zeros(10, 10).contiguous()
b = a * 2
c = a * 2
Tensor.realize(b,c)
self.assertIs(b.uop.buffer, c.uop.buffer)
if __name__ == '__main__':
unittest.main()