Mirror of https://github.com/tinygrad/tinygrad.git

comment about multi real and more tests [pr] (#7467)

@@ -772,6 +772,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
    a = t.shrink(((2, 4), None))
    b = t.shrink(((6, 8), None))
    self.assertEqual(a.lazydata.real, [False, True, False, False])
    self.assertEqual(b.lazydata.real, [False, False, False, True])
    na = t.numpy()[2:4]
    nb = t.numpy()[6:8]
    np.testing.assert_equal(a.numpy(), na)

@@ -781,6 +783,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
      c = a + b

    c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
    self.assertEqual(c.lazydata.real, [True, True, True, True])
    expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
    np.testing.assert_equal(c.numpy(), expected)
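
For context on what this test exercises: shrinking a tensor sharded on axis 0 down to one partition's bounds marks only that partition real, and padding the pieces back to the full shape lets `+` reassemble them, since the regions outside each partition are zero. A minimal numpy sketch of the same arithmetic (the 8x8 shape and 4-way split are assumptions chosen to match the indices above, not taken from the repo):

    import numpy as np

    t = np.arange(64).reshape(8, 8)          # full tensor, sharded 2 rows per device
    na, nb = t[2:4], t[6:8]                  # the two "real" partitions after shrink

    # pad each partition back out to the full 8 rows, zeros elsewhere
    a_full = np.pad(na, ((2, 4), (0, 0)))    # rows 2:4 live, rest zero
    b_full = np.pad(nb, ((6, 0), (0, 0)))    # rows 6:8 live, rest zero

    c = a_full + b_full                      # partitions are disjoint, so add recombines them
    expected = np.concatenate([np.zeros_like(t[0:2]), na, np.zeros_like(t[4:6]), nb])
    np.testing.assert_equal(c, expected)
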
@@ -87,8 +87,6 @@ class MultiLazyBuffer(MathTrait):
      llbs.append(lb.copy_to_device(device).pad(pad_arg))
    return functools.reduce(operator.add, llbs)

  def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)

  # passthroughs
  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
  def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
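
The two context lines at the top of this hunk are the gather path of copy_to_device: each shard is copied to one device, zero-padded along the shard axis out to the full extent, and the padded pieces are summed into a single buffer (non-real shards are skipped earlier in the method). A numpy sketch of that pad-and-sum, with assumed shard shapes and all shards real:

    import functools, operator
    import numpy as np

    shards = [np.full((2, 4), i) for i in range(4)]    # 4 shards on axis 0, shapes assumed
    bounds = [(0, 2), (2, 4), (4, 6), (6, 8)]          # start/end of each shard on axis 0

    total = bounds[-1][1]
    llbs = [np.pad(lb, ((start, total - end), (0, 0))) for lb, (start, end) in zip(shards, bounds)]
    gathered = functools.reduce(operator.add, llbs)    # zero regions don't overlap, so sum == concat
    assert np.array_equal(gathered, np.concatenate(shards))
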
@@ -96,6 +94,7 @@ class MultiLazyBuffer(MathTrait):
  def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
  def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
  def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)

  # elementwise is simple
  def alu(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
@@ -122,7 +121,9 @@ class MultiLazyBuffer(MathTrait):
    if self.axis is not None and self.axis in axis:
      # all-reduce on sharded axes
      reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
      # if all partitions are real, do all_reduce
      if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
      # only one partition is real, keep it
      return MultiLazyBuffer(reduced_parts, None, self.real)
    # reduce on non-sharded axes, piecewise is fine. if axis is None this is also correct
    return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
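
Sketch of what the sharded-axis branch computes: each device reduces its own partition locally, then the per-device partials are combined across devices (the all_reduce step); non-real partitions are replaced by zeros first. The shard shapes and the use of sum as the reduce op below are assumptions for illustration:

    import numpy as np

    shards = [np.arange(8).reshape(2, 4) + 8*i for i in range(4)]  # axis 0 sharded, 2 rows each
    real = [True, True, True, True]

    # each partition reduces locally over the sharded axis (non-real ones as zeros)
    reduced_parts = [(x if r else np.zeros_like(x)).sum(axis=0) for x, r in zip(shards, real)]

    # all partitions real -> all_reduce: combine the partials across devices
    result = np.sum(reduced_parts, axis=0)
    assert np.array_equal(result, np.concatenate(shards).sum(axis=0))
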
@@ -169,6 +170,7 @@ class MultiLazyBuffer(MathTrait):
    assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
    if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
      assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
      # NOTE: shrink on the shard axis is only allowed when the result is a single partition, denoted by the new real
      idx = self.bounds.index(arg[self.axis])
      # zero out the other lbs so we don't keep references to them
      return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
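
Plainly stated, this branch keeps the list of lbs at full length, swaps every non-selected shard for a const (so no reference to the original buffer is kept), and records the surviving partition in the new real mask. A toy mirror of just the mask bookkeeping (names and values are illustrative, chosen to match the test above; this is not tinygrad's API):

    bounds = [(0, 2), (2, 4), (4, 6), (6, 8)]   # per-shard extents on the sharded axis
    arg_on_axis = (6, 8)                         # shrink range, must match one shard exactly

    idx = bounds.index(arg_on_axis)              # which partition survives
    real = [i == idx for i in range(len(bounds))]
    assert real == [False, False, False, True]   # matches b.lazydata.real in the test above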