mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix: fix null tests to actually use null device (#15104)
This commit is contained in:
@@ -66,16 +66,17 @@ class TestMultiRamUsage(unittest.TestCase):
|
||||
self.assertUsed(256 * 4) # TODO: can be zero
|
||||
|
||||
def test_zeros_per_device(self):
|
||||
_ = Tensor.zeros(self.N, self.N).contiguous().realize()
|
||||
_ = Tensor.zeros(self.N, self.N, device="NULL").contiguous().realize()
|
||||
self.assertDeviceUsed({"NULL": self.N*self.N*4})
|
||||
|
||||
def test_zeros_del_per_device(self):
|
||||
_ = Tensor.zeros(self.N, self.N).contiguous().realize()
|
||||
_ = Tensor.zeros(self.N, self.N, device="NULL").contiguous().realize()
|
||||
del _
|
||||
self.assertDeviceUsed({"NULL": 0})
|
||||
|
||||
def test_zeros_copy_per_device(self):
|
||||
_ = Tensor.zeros(self.N, self.N).contiguous().to(("NULL:1", "NULL:2")).realize()
|
||||
devices_2 = ("NULL:1", "NULL:2")
|
||||
_ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
|
||||
self.assertDeviceUsed({"NULL:1": self.N*self.N*4, "NULL:2": self.N*self.N*4})
|
||||
|
||||
def test_zeros_shard_per_device(self):
|
||||
@@ -85,7 +86,7 @@ class TestMultiRamUsage(unittest.TestCase):
|
||||
|
||||
def test_sharded_memory_replicated_per_device(self):
|
||||
devices_4 = tuple(f"NULL:{i+1}" for i in range(4))
|
||||
X = Tensor.ones(256).contiguous().realize()
|
||||
X = Tensor.ones(256, device="NULL").contiguous().realize()
|
||||
self.assertDeviceUsed({"NULL": 256*4})
|
||||
X.shard_(devices_4).realize()
|
||||
for d in devices_4:
|
||||
|
||||
Reference in New Issue
Block a user