diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 1366853486..495f8db305 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -240,6 +240,8 @@ class TestMultiConstFolding(unittest.TestCase): _check_ast_count(0, t ** 1) _check_ast_count(0, 1 ** t) + # failing because multi calls .contiguous() on every single sharded uop + @unittest.expectedFailure def test_multi_const_folding_tensor(self): ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4)) t = Tensor.arange(16).float().realize().to(ds) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index b2f12987af..2cbcc0ad14 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -5,7 +5,7 @@ from tinygrad.ops import Ops from tinygrad.helpers import CI, getenv, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.engine.schedule import create_schedule -from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner +from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule from tinygrad.multi import all_reduce, MultiLazyBuffer import numpy as np from hypothesis import given, strategies as strat, settings @@ -704,6 +704,19 @@ class TestMultiTensor(unittest.TestCase): t = Tensor.rand(16, 16).shard(devices_2, axis=0) np.testing.assert_allclose(t.numpy(), t.clone().numpy()) + def test_multi_const_folding(self): + with Context(TRACK_MATCH_STATS=0): + a = Tensor.arange(3).realize() + zeros = Tensor.zeros(3).realize() + b = a.to(devices_2)*zeros.to(devices_2) + sched = b.schedule() + self.assertEqual(len(sched), 6) + # notably, only two copies (for the arange) - vs 4 copies if we didn't fold the const copy + self.assertEqual(len([x for x in sched if any(u.op is Ops.COPY for u in x.ast.toposort)]), 2) + # all these kernels are just because multi calls contiguous on every single shard + run_schedule(sched) + self.assertListEqual(b.tolist(), [0, 0, 0]) + @unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI") class TestHandleData(unittest.TestCase): def test_copied_to_device(self): diff --git a/test/test_schedule.py b/test/test_schedule.py index b0219595ae..fbe92bedae 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2163,5 +2163,19 @@ class TestConst(unittest.TestCase): run_schedule(sched, var_vals) self.assertEqual(a.tolist(), 3) +@unittest.skipIf(Device.DEFAULT == "CLANG", "tests copy from another device to clang") +class TestCopyFolding(unittest.TestCase): + def test_const_copy_is_free(self): + b = Tensor(1).to("CLANG") + check_schedule(b, 0, filter_sink=False) + assert b.item() == 1 + + def test_late_const_copy_folding(self): + a = Tensor.arange(3).realize() + zeros = Tensor.zeros(3).realize() + b = (a*zeros).to("CLANG") + run_schedule(check_schedule(b, 0, filter_sink=False)) + self.assertListEqual(b.tolist(), [0, 0, 0]) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/multi.py b/tinygrad/multi.py index 0ef54a9434..4d1649f867 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -66,7 +66,8 @@ class MultiLazyBuffer(MathTrait): assert (axis is None) == (bounds is None), "must specify bounds iff axis is specified" lbs = [lb] * len(devices) sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis, bounds) if axis is not None and bounds is not None else lbs, devices)] - return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis) + # NOTE: this contiguous is making it impossible for the scheduler to do late const folding + return MultiLazyBuffer([lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis) def copy_to_device(self, device:str) -> UOp: if self.axis is None: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index f59b7eaea1..8b9ca14101 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -453,8 +453,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def copy_to_device(self, device:str, force=False, clone:bool=False) -> UOp: # no COPY if self.device == device and not clone: return self - # TODO: hack const metaop early here, fix this in multi - if self.base.op is Ops.CONST: return UOp.metaop(Ops.CONST, (), self.dtype, device, self.const_arg).view(unwrap(self.st)) # if it's a shrink, do the shrink before the copy with CONTIGUOUS if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device) # copy the base and apply the shapetracker on the new device