mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
165 lines · 6.0 KiB · Python
import gc, unittest
|
|
from tinygrad import Tensor, GlobalCounters, dtypes
|
|
from tinygrad.engine.jit import TinyJit
|
|
|
|
class TestMultiRamUsage(unittest.TestCase):
  """Verify GlobalCounters memory accounting for multi-device tensors.

  Runs against NULL devices, so these tests exercise the accounting
  (mem_used / mem_used_per_device / global_mem), not real allocations.
  Expected byte counts below multiply element counts by 4, i.e. they
  assume the default 4-byte (float32) element size.
  """
  def setUp(self):
    # snapshot the counters so each test measures only what it allocates itself
    gc.collect()
    self.baseline = GlobalCounters.mem_used
    self.baseline_per_device = dict(GlobalCounters.mem_used_per_device)
    self.N = 100  # side length of the square test tensors

  def assertUsed(self, amt:int, strict:bool=True):
    """Assert exactly `amt` bytes are in use beyond the setUp baseline (at most `amt` when strict=False)."""
    # collect first so tensors the test dropped actually release their buffers
    gc.collect()
    used = GlobalCounters.mem_used - self.baseline
    print(f"used {used} bytes")
    if strict: self.assertEqual(used, amt)
    else: self.assertLessEqual(used, amt)

  def assertDeviceUsed(self, expected:dict[str, int]):
    """Assert, for each device in `expected`, the bytes in use beyond the setUp baseline."""
    gc.collect()
    for dev, amt in expected.items():
      used = GlobalCounters.mem_used_per_device[dev] - self.baseline_per_device.get(dev, 0)
      self.assertEqual(used, amt, f"device {dev}: expected {amt} bytes used, got {used}")

  def test_zeros(self):
    # one realized contiguous N*N tensor occupies one N*N*4-byte buffer
    _ = Tensor.zeros(self.N, self.N).contiguous().realize()
    self.assertUsed(self.N*self.N*4)

  def test_zeros_del(self):
    # dropping the only reference must return the buffer to the allocator
    _ = Tensor.zeros(self.N, self.N).contiguous().realize()
    del _
    self.assertUsed(0)

  def test_zeros_copy(self):
    devices_2 = ("NULL:1", "NULL:2")
    _ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
    # NOTE: the first one on the DEFAULT device should be freed
    self.assertUsed(self.N*self.N*4*2)

  def test_zeros_shard(self, devices=("NULL:1", "NULL:2")):
    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices, axis=0).realize()
    self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
  # same check, but one shard lands on the source device itself
  def test_zeros_shard_self(self): self.test_zeros_shard(("NULL:0", "NULL:1"))

  def test_zeros_contiguous_shard(self):
    devices_2 = ("NULL:1", "NULL:2")
    # an extra contiguous after the shard must not duplicate the data
    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
    self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage

  def test_sharded_memory_replicated(self):
    devices_4 = tuple(f"NULL:{i+1}" for i in range(4))
    X = Tensor.ones(256).contiguous().realize()
    self.assertUsed(256 * 4)
    # axis-less shard_ replicates: each of the 4 devices holds a full copy
    X.shard_(devices_4).realize()
    self.assertUsed(256 * 4 * 4)

  def test_sharded_memory_replicated_const(self):
    devices_4 = tuple(f"NULL:{i+1}" for i in range(4))
    # without contiguous, realized ones allocates nothing (stays a const)
    X = Tensor.ones(256).realize()
    self.assertUsed(0)
    X.shard_(devices_4).realize()
    self.assertUsed(256 * 4 * 4) # TODO: can be zero

  def test_sharded_memory_axis_const(self):
    devices_4 = tuple(f"NULL:{i+1}" for i in range(4))
    X = Tensor.ones(256).realize()
    self.assertUsed(0)
    # axis=0 shard of a const: only one tensor's worth of bytes total
    X.shard_(devices_4, axis=0).realize()
    self.assertUsed(256 * 4) # TODO: can be zero

  def test_zeros_per_device(self):
    _ = Tensor.zeros(self.N, self.N, device="NULL").contiguous().realize()
    self.assertDeviceUsed({"NULL": self.N*self.N*4})

  def test_zeros_del_per_device(self):
    _ = Tensor.zeros(self.N, self.N, device="NULL").contiguous().realize()
    del _
    self.assertDeviceUsed({"NULL": 0})

  def test_zeros_copy_per_device(self):
    devices_2 = ("NULL:1", "NULL:2")
    _ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
    # each copy target holds the full tensor
    self.assertDeviceUsed({"NULL:1": self.N*self.N*4, "NULL:2": self.N*self.N*4})

  def test_zeros_shard_per_device(self):
    devices_2 = ("NULL:1", "NULL:2")
    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).realize()
    # axis=0 shard splits the rows evenly: half the tensor on each device
    self.assertDeviceUsed({"NULL:1": self.N*(self.N//2)*4, "NULL:2": self.N*(self.N//2)*4})

  def test_sharded_memory_replicated_per_device(self):
    devices_4 = tuple(f"NULL:{i+1}" for i in range(4))
    X = Tensor.ones(256, device="NULL").contiguous().realize()
    self.assertDeviceUsed({"NULL": 256*4})
    X.shard_(devices_4).realize()
    # replication: every target device carries a full copy
    for d in devices_4:
      self.assertDeviceUsed({d: 256*4})

  def _test_matmul_half(self, dev_count:int):
    """A sharded matmul in half must move exactly half the bytes of the float32 one."""
    N = 32
    total_mem = {}
    devs = tuple(f"NULL:{i}" for i in range(dev_count))
    for dtype in {dtypes.float, dtypes.half}:
      GlobalCounters.reset()
      a = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=0)
      b = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=None)
      (a @ b).realize()
      total_mem[dtype] = GlobalCounters.global_mem  # bytes moved, per the counters
    self.assertEqual(total_mem[dtypes.half], total_mem[dtypes.float] // 2)

  def test_matmul_half(self): self._test_matmul_half(dev_count=2)
  def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)

  def test_multi_layer_allreduce(self):
    """Memory held by a JITed sharded multi-layer matmul must not grow with layer count."""
    N = 32
    devices_2 = ("NULL:1", "NULL:2")

    def make_inp():
      # x replicated, w1 column-sharded, w2 row-sharded: x@w1@w2 needs an allreduce
      x = Tensor.zeros(N, N).contiguous().shard(devices_2, axis=None).realize()
      w1 = Tensor.zeros(N, N).contiguous().shard(devices_2, axis=1).realize()
      w2 = Tensor.zeros(N, N).contiguous().shard(devices_2, axis=0).realize()
      return x, w1, w2

    def run_layers(n_layers):
      GlobalCounters.reset()

      @TinyJit
      def f(x, w1, w2):
        for _ in range(n_layers):
          x = (x @ w1 @ w2)
        return x.contiguous()

      # call several times — presumably enough for TinyJit to capture and replay
      for _ in range(3):
        a = make_inp()
        r = f(*a)
        del a, r

      # drop unreferenced tensors so only jit-retained memory is measured
      gc.collect()
      return GlobalCounters.mem_used

    mem_2 = run_layers(2)
    mem_4 = run_layers(4)
    self.assertEqual(mem_2, mem_4, f"graph memory should not grow with layers: 2 layers={mem_2}, 4 layers={mem_4}")
|
|
|
|
class TestMultiAxis(unittest.TestCase):
  """Checks how the shard axis behaves under reshape and empty_like."""
  def test_reshape_shard_invalid(self):
    # a reshape that would move elements across shard boundaries must be rejected
    dev_pair = ("NULL:0", "NULL:1")
    sharded = Tensor.ones(4, 3).shard(dev_pair, axis=0)
    with self.assertRaises(RuntimeError, msg="reshape cannot move items between shards"):
      sharded.reshape(3, 4).uop.axis

  def test_reshape_shard_valid(self):
    # reshapes that keep shard boundaries intact preserve axis 0
    dev_pair = ("NULL:0", "NULL:1")
    sharded = Tensor.ones(4, 8).shard(dev_pair, axis=0)
    for new_shape in ((2, 16), (2, 2, 8)):
      self.assertEqual(sharded.reshape(*new_shape).uop.axis, 0)

  def test_empty_like_sharded(self):
    src = Tensor.ones(4, 8).shard(("NULL:0", "NULL:1"), axis=0)
    clone = src.empty_like()
    # empty_like must carry over shape, placement, shard axis, and buffer identity
    self.assertEqual(clone.shape, src.shape)
    self.assertEqual(clone.device, src.device)
    self.assertEqual(clone.uop.axis, 0)
    self.assertTrue(clone.uop.has_buffer_identity())
|
|
|
|
# run these tests directly with the standard unittest runner
if __name__ == '__main__': unittest.main()
|