mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
reshape is lazy now, so better to raise from the .axis call and not have caller to handle invalid case
72 lines
2.7 KiB
Python
72 lines
2.7 KiB
Python
import unittest
|
|
from tinygrad import Tensor, GlobalCounters, dtypes
|
|
|
|
class TestMultiRamUsage(unittest.TestCase):
|
|
def setUp(self):
|
|
self.baseline = GlobalCounters.mem_used
|
|
self.N = 100
|
|
def assertUsed(self, amt, strict=True):
|
|
used = GlobalCounters.mem_used - self.baseline
|
|
print(f"used {used} bytes")
|
|
if strict: self.assertEqual(used, amt)
|
|
else: self.assertLessEqual(used, amt)
|
|
|
|
def test_zeros(self):
|
|
_ = Tensor.zeros(self.N, self.N).contiguous().realize()
|
|
self.assertUsed(self.N*self.N*4)
|
|
|
|
def test_zeros_del(self):
|
|
_ = Tensor.zeros(self.N, self.N).contiguous().realize()
|
|
del _
|
|
self.assertUsed(0)
|
|
|
|
@unittest.skip("flaky")
|
|
def test_zeros_copy(self):
|
|
devices_2 = ("NULL:1", "NULL:2")
|
|
_ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
|
|
# NOTE: the first one on the DEFAULT device should be freed
|
|
self.assertUsed(self.N*self.N*4*2)
|
|
|
|
@unittest.skip("flaky")
|
|
def test_zeros_shard(self, devices=("NULL:1", "NULL:2")):
|
|
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices, axis=0).realize()
|
|
self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
|
|
def test_zeros_shard_self(self): self.test_zeros_shard(("NULL:0", "NULL:1"))
|
|
|
|
@unittest.skip("flaky")
|
|
def test_zeros_contiguous_shard(self):
|
|
devices_2 = ("NULL:1", "NULL:2")
|
|
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
|
|
self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
|
|
|
|
def _test_matmul_half(self, dev_count:int):
|
|
N = 32
|
|
total_mem = {}
|
|
devs = tuple(f"NULL:{i}" for i in range(dev_count))
|
|
for dtype in {dtypes.float, dtypes.half}:
|
|
GlobalCounters.reset()
|
|
a = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=0)
|
|
b = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=None)
|
|
(a @ b).realize()
|
|
total_mem[dtype] = GlobalCounters.global_mem
|
|
self.assertEqual(total_mem[dtypes.half], total_mem[dtypes.float] // 2)
|
|
|
|
def test_matmul_half(self): self._test_matmul_half(dev_count=2)
|
|
def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
|
|
|
|
class TestMultiAxis(unittest.TestCase):
|
|
def test_reshape_shard_invalid(self):
|
|
devices = ("NULL:0", "NULL:1")
|
|
t = Tensor.ones(4, 3).shard(devices, axis=0)
|
|
with self.assertRaises(RuntimeError, msg="reshape cannot move items between shards"):
|
|
t.reshape(3, 4).uop.axis
|
|
|
|
def test_reshape_shard_valid(self):
|
|
devices = ("NULL:0", "NULL:1")
|
|
t = Tensor.ones(4, 8).shard(devices, axis=0)
|
|
self.assertEqual(t.reshape(2, 16).uop.axis, 0)
|
|
self.assertEqual(t.reshape(2, 2, 8).uop.axis, 0)
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|