diff --git a/test/test_const_folding.py b/test/test_const_folding.py
index 2ffce0941e..fd1900106d 100644
--- a/test/test_const_folding.py
+++ b/test/test_const_folding.py
@@ -104,14 +104,12 @@ class TestMultiConstFolding(unittest.TestCase):
     np.testing.assert_equal((t * 0).numpy(), [0] * 16)
     np.testing.assert_equal((t * 1).numpy(), np.arange(16))
 
-  @unittest.expectedFailure
   def test_multi_const_folding_tensor(self):
     ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
     t = Tensor.arange(16).float().realize().to(ds)
     zero = Tensor.zeros(16).realize().to(ds)
     one = Tensor.ones(16).realize().to(ds)
 
-    # TODO: fix const to multi and const folding multi
     # const folded
     _check_ast_count(0, t + zero)
     _check_ast_count(0, zero + t)
@@ -132,3 +130,6 @@ class TestMultiConstFolding(unittest.TestCase):
     _check_ast_count(0, t ** 0)
     _check_ast_count(0, t ** 1)
     _check_ast_count(0, 1 ** t)
+
+if __name__ == '__main__':
+  unittest.main()
\ No newline at end of file
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index f7c0063cf2..cfa0964608 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -40,6 +40,29 @@ class TestMultiTensor(unittest.TestCase):
     assert lb.shape == (128,)
     (X + X).realize()
 
+  def test_sharded_memory(self):
+    mem_base = GlobalCounters.mem_used
+
+    X = Tensor.ones(256).contiguous().realize()
+    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
+    X.shard_((d0, d1, d2, d3)).realize()
+    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256 * 4, GlobalCounters.mem_used-mem_base
+
+    X = Tensor.ones(256).contiguous().realize()
+    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
+    X.shard_((d0, d1, d2, d3), axis=0).realize()
+    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
+
+    X = Tensor.ones(256).realize()
+    assert GlobalCounters.mem_used-mem_base == 0
+    X.shard_((d0, d1, d2, d3)).realize()
+    assert GlobalCounters.mem_used-mem_base == 0
+
+    X = Tensor.ones(256).realize()
+    assert GlobalCounters.mem_used-mem_base == 0
+    X.shard_((d0, d1, d2, d3), axis=0).realize()
+    assert GlobalCounters.mem_used-mem_base == 0
+
   def test_shard_same_device(self):
     X = Tensor.ones(256).contiguous().realize()
     X.shard_((d0, X.device), 0)
diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py
index 4b446d670a..dbae4882e6 100644
--- a/tinygrad/features/multi.py
+++ b/tinygrad/features/multi.py
@@ -70,8 +70,9 @@ class MultiLazyBuffer:
 
   @staticmethod
   def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
-    lbs = [lb.contiguous() if lb.base != lb else lb] * len(devices)
-    return MultiLazyBuffer([lb.copy_to_device(d).contiguous() for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)], axis)
+    lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unpadded_const() else lb] * len(devices)
+    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
+    return MultiLazyBuffer([lb if lb.is_unrealized_unpadded_const() else lb.contiguous() for lb in sharded_lbs], axis)
 
   def copy_to_device(self, device:str) -> LazyBuffer:
     if self.axis is None: return self.lbs[self.real.index(True)].copy_to_device(device)
@@ -82,9 +83,6 @@ class MultiLazyBuffer:
       llbs.append(lb.pad(pad_arg))
     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
 
-  # TODO: fix this
-  def is_unrealized_unpadded_const(self): return False
-
   # passthroughs
   def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True])
   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
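
For context, a minimal standalone sketch of what the from_sharded change enables, separate from the diff itself. It assumes a backend that can open four instances of the default device (mirroring the `f"{Device.DEFAULT}:{i}"` pattern in the tests above) and that Tensor, Device, and GlobalCounters are importable from the top-level tinygrad package; expected values follow the assertions in test_sharded_memory and test_multi_const_folding_tensor, not independent guarantees.

# Illustrative sketch, not part of this diff. With from_sharded no longer
# forcing .contiguous() on unrealized unpadded consts, sharding a const tensor
# should not allocate per-device buffers, and elementwise ops against sharded
# consts can const-fold (the re-enabled test above checks the kernel count).
from tinygrad import Tensor, Device, GlobalCounters  # assumes these top-level exports

ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))  # assumes 4 device instances

mem_base = GlobalCounters.mem_used
X = Tensor.ones(256).realize()             # const tensor, no .contiguous(): no buffer expected
X.shard_(ds, axis=0).realize()             # shard the const across the four devices
print(GlobalCounters.mem_used - mem_base)  # expected 0 after this change

t = Tensor.arange(16).float().realize().to(ds)
zero = Tensor.zeros(16).realize().to(ds)
print((t + zero).numpy())                  # expected to fold to t without extra kernels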