diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index e1ab625977..54233b0f2c 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -772,6 +772,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
 
     a = t.shrink(((2, 4), None))
     b = t.shrink(((6, 8), None))
+    self.assertEqual(a.lazydata.real, [False, True, False, False])
+    self.assertEqual(b.lazydata.real, [False, False, False, True])
     na = t.numpy()[2:4]
     nb = t.numpy()[6:8]
     np.testing.assert_equal(a.numpy(), na)
@@ -781,6 +783,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
        c = a + b
 
    c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
+    self.assertEqual(c.lazydata.real, [True, True, True, True])
    expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
    np.testing.assert_equal(c.numpy(), expected)
 
diff --git a/tinygrad/multi.py b/tinygrad/multi.py
index 975c6f7fb3..766d757f57 100644
--- a/tinygrad/multi.py
+++ b/tinygrad/multi.py
@@ -87,8 +87,6 @@ class MultiLazyBuffer(MathTrait):
       llbs.append(lb.copy_to_device(device).pad(pad_arg))
     return functools.reduce(operator.add, llbs)
 
-  def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
-
   # passthroughs
   def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
   def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
@@ -96,6 +94,7 @@ class MultiLazyBuffer(MathTrait):
   def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
+  def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
 
   # elementwise is simple
   def alu(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
@@ -122,7 +121,9 @@ class MultiLazyBuffer(MathTrait):
     if self.axis is not None and self.axis in axis:
       # all-reduce on sharded axes
       reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
+      # if all partitions are real, do all_reduce
       if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
+      # only one partition is real, keep it
       return MultiLazyBuffer(reduced_parts, None, self.real)
     # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
     return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
@@ -169,6 +170,7 @@ class MultiLazyBuffer(MathTrait):
     assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
     if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
       assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
+      # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
       idx = self.bounds.index(arg[self.axis])
       # zero out other lbs to not create lb reference
       return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])