fix: fix null tests to actually use null device (#15104)

This commit is contained in:
wozeparrot
2026-03-03 18:05:47 +08:00
committed by GitHub
parent 7d025089e3
commit 529318259c

View File

@@ -66,16 +66,17 @@ class TestMultiRamUsage(unittest.TestCase):
self.assertUsed(256 * 4) # TODO: can be zero
def test_zeros_per_device(self):
_ = Tensor.zeros(self.N, self.N).contiguous().realize()
_ = Tensor.zeros(self.N, self.N, device="NULL").contiguous().realize()
self.assertDeviceUsed({"NULL": self.N*self.N*4})
def test_zeros_del_per_device(self):
_ = Tensor.zeros(self.N, self.N).contiguous().realize()
_ = Tensor.zeros(self.N, self.N, device="NULL").contiguous().realize()
del _
self.assertDeviceUsed({"NULL": 0})
def test_zeros_copy_per_device(self):
_ = Tensor.zeros(self.N, self.N).contiguous().to(("NULL:1", "NULL:2")).realize()
devices_2 = ("NULL:1", "NULL:2")
_ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
self.assertDeviceUsed({"NULL:1": self.N*self.N*4, "NULL:2": self.N*self.N*4})
def test_zeros_shard_per_device(self):
@@ -85,7 +86,7 @@ class TestMultiRamUsage(unittest.TestCase):
def test_sharded_memory_replicated_per_device(self):
devices_4 = tuple(f"NULL:{i+1}" for i in range(4))
X = Tensor.ones(256).contiguous().realize()
X = Tensor.ones(256, device="NULL").contiguous().realize()
self.assertDeviceUsed({"NULL": 256*4})
X.shard_(devices_4).realize()
for d in devices_4: