multi ram usage tests on the NULL device (#14457)
@@ -1273,19 +1273,20 @@ class TestMultiRamUsage(unittest.TestCase):
     _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
     self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
 
-  def _test_matmul_half(self, devs):
+  def _test_matmul_half(self, dev_count:int):
     N = 32
     total_mem = {}
+    devs = tuple(f"NULL:{i}" for i in range(dev_count))
     for dtype in {dtypes.float, dtypes.half}:
       GlobalCounters.reset()
-      a = Tensor.empty((N, N), dtype=dtype).shard(devs, axis=0)
-      b = Tensor.empty((N, N), dtype=dtype).shard(devs, axis=None)
+      a = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=0)
+      b = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=None)
       (a @ b).realize()
       total_mem[dtype] = GlobalCounters.global_mem
     self.assertEqual(total_mem[dtypes.half], total_mem[dtypes.float] // 2)
 
-  def test_matmul_half(self): self._test_matmul_half(devices_2)
-  def test_matmul_half_alt(self): self._test_matmul_half(devices_4)
+  def test_matmul_half(self): self._test_matmul_half(dev_count=2)
+  def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
 
 @unittest.skipIf(not_support_multi_device(), "need multi")
 class TestMultiFromUnrenderable(unittest.TestCase):
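For reference, the reworked helper is self-contained enough to run outside the test suite. Below is a minimal standalone sketch of the same pattern, assuming a tinygrad checkout where Tensor, dtypes, and GlobalCounters are importable from the top-level package and the NULL backend is available (per the test above, NULL devices do no real work, but buffer traffic is still tracked in GlobalCounters):

# standalone sketch of the new test pattern on fabricated NULL devices
from tinygrad import Tensor, dtypes, GlobalCounters

N, dev_count = 32, 2
devs = tuple(f"NULL:{i}" for i in range(dev_count))

total_mem = {}
for dtype in (dtypes.float, dtypes.half):
  GlobalCounters.reset()
  # device=devs[0] pins the pre-shard tensor to a NULL device, so nothing
  # touches the default backend before the shard
  a = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=0)     # rows split across devices
  b = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=None)  # replicated on every device
  (a @ b).realize()
  total_mem[dtype] = GlobalCounters.global_mem

# half buffers are 2 bytes per element vs 4 for float, so traffic should halve exactly
assert total_mem[dtypes.half] == total_mem[dtypes.float] // 2

Taking dev_count rather than a prebuilt device tuple lets the helper fabricate any number of NULL devices, which is why the callers switch from the devices_2/devices_4 tuples to dev_count=2 and dev_count=4.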