comment about multi real and more tests [pr] (#7467)

This commit is contained in:
chenyu
2024-11-01 11:49:11 -04:00
committed by GitHub
parent 1f343aa40e
commit 18e159c9ac
2 changed files with 7 additions and 2 deletions

View File

@@ -772,6 +772,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
a = t.shrink(((2, 4), None))
b = t.shrink(((6, 8), None))
self.assertEqual(a.lazydata.real, [False, True, False, False])
self.assertEqual(b.lazydata.real, [False, False, False, True])
na = t.numpy()[2:4]
nb = t.numpy()[6:8]
np.testing.assert_equal(a.numpy(), na)
@@ -781,6 +783,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
c = a + b
c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
self.assertEqual(c.lazydata.real, [True, True, True, True])
expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
np.testing.assert_equal(c.numpy(), expected)

View File

@@ -87,8 +87,6 @@ class MultiLazyBuffer(MathTrait):
llbs.append(lb.copy_to_device(device).pad(pad_arg))
return functools.reduce(operator.add, llbs)
def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
# passthroughs
def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
@@ -96,6 +94,7 @@ class MultiLazyBuffer(MathTrait):
def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
# elementwise is simple
def alu(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
@@ -122,7 +121,9 @@ class MultiLazyBuffer(MathTrait):
if self.axis is not None and self.axis in axis:
# all-reduce on sharded axes
reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
# if all partitions are real, do all_reduce
if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
# only one partition is real, keep it
return MultiLazyBuffer(reduced_parts, None, self.real)
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
@@ -169,6 +170,7 @@ class MultiLazyBuffer(MathTrait):
assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
idx = self.bounds.index(arg[self.axis])
# zero out other lbs to not create lb reference
return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])