diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 896d24b80b..9913a04605 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -554,24 +554,24 @@ class TestMultiTensor(unittest.TestCase): t4 = t2.reshape((26, 105,)) for t in [t0, t1, t2, t3, t4]: - np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten()) assert t.lazydata.axis == 1 + np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten()) # test shape-one axis t5 = t4.reshape((26, 1, 105)) - np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten()) assert t5.lazydata.axis == 2 + np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten()) # test split and rejoin to the right and reshape to the left t5 = t0.reshape((2, 13, 3, 5, 7)) t6 = t0.reshape((13, 2, 3, 7, 5)) t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5)) - np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten()) assert t5.lazydata.axis == 2 - np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten()) assert t6.lazydata.axis == 2 - np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten()) assert t7.lazydata.axis == 3 + np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten()) + np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten()) + np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten()) # test no left join with self.assertRaises((AssertionError, ValueError)): @@ -580,8 +580,8 @@ class TestMultiTensor(unittest.TestCase): @unittest.skip("no longer supports uneven shard") def test_reshape_on_axis_uneven(self): def reshape_helper(t0, t, t_axis): - np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy()) assert t.lazydata.axis == t_axis + np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy()) t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21]) @@ -653,11 +653,11 @@ class TestMultiTensor(unittest.TestCase): def test_rand_like_from_alu(self): # TODO: fix this, which will also fix multi device dropout a = Tensor.ones(4, 4).shard(devices_2, axis=0) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): (a + a).rand_like() b = Tensor.empty(4, 4).shard(devices_2, axis=None) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): (a + b).rand_like() @unittest.skip("no longer supports uneven shard") diff --git a/tinygrad/engine/multi.py b/tinygrad/engine/multi.py index 17a3259b04..014f51db9e 100644 --- a/tinygrad/engine/multi.py +++ b/tinygrad/engine/multi.py @@ -43,27 +43,12 @@ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites -def get_axis(root:UOp): - if root.op is Ops.MULTI: return root.arg[0] - # NOTE: they all have to share an axis, we always choose [-1] - if root.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in root.src if x.axis is not None])) else None - src_axis = get_axis(root.src[0]) - if root.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in root.arg[1] else src_axis - if root.op is Ops.RESHAPE: - if src_axis is None: return None - arg_acc:list[sint] = list(itertools.accumulate(root.arg, operator.mul, initial=1)) - # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards - # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1? - return len(arg_acc) - arg_acc[::-1].index(prod(root.src[0].shape[:src_axis])) - 1 - if root.op is Ops.PERMUTE: return root.arg.index(src_axis) if src_axis is not None else None - raise NotImplementedError("rest should be passthrough") - def alu_multi(root:UOp): msrcs = root.src assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}" assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}" - axis = get_axis(root) + axis = root.axis bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None srcs:list[list[UOp]] = [] not_all_real = not all(all(mlb.real) for mlb in msrcs) @@ -79,23 +64,23 @@ def alu_multi(root:UOp): return UOp.multi(*new_lbs, axis=axis, real=new_real) def reduce_multi(root:UOp, multi:UOp): - (op, axis), new_axis = root.arg, get_axis(root) + op, axis = root.arg if multi.axis is not None and multi.axis in axis: # all-reduce on sharded axes reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)] # if all partitions are real, do all_reduce - if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=new_axis) + if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=root.axis) # only one partition is real, keep it - return UOp.multi(*reduced_parts, axis=new_axis, real=multi.real) + return UOp.multi(*reduced_parts, axis=root.axis, real=multi.real) # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct - return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=new_axis, real=multi.real) + return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=root.axis, real=multi.real) def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]: return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape)) def reshape_multi(root:UOp, multi:UOp): - arg, new_axis = root.arg, get_axis(root) - if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real) + arg = root.arg + if (new_axis:=root.axis) is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real) assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)" assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \ f"reshape cannot move items between shards {multi.shape} -> {root.arg=}" @@ -120,7 +105,7 @@ def pad_multi(root:UOp, multi:UOp): def permute_multi(root:UOp, multi:UOp): # all permutes supported! - return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=get_axis(root), real=multi.real) + return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.axis, real=multi.real) def shrink_multi(root:UOp, multi:UOp): assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \ diff --git a/tinygrad/ops.py b/tinygrad/ops.py index f193bd9bfd..7e7b9b6fa3 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from collections import defaultdict from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten -from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG +from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup if TYPE_CHECKING: from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.device import Buffer @@ -436,10 +436,21 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.axis is None: raise RuntimeError("bounds is not defined when axis is None") return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0))) - @property - def axis(self): - assert self.op is Ops.MULTI - return self.arg[0] + @functools.cached_property + def axis(self) -> Optional[int]: + if self.op is Ops.MULTI: return self.arg[0] + # NOTE: they all have to share an axis, we always choose [-1] + if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None + src_axis = self.src[0].axis + if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis + if self.op is Ops.RESHAPE: + if src_axis is None: return None + arg_acc:list[sint] = list(itertools.accumulate(self.arg, operator.mul, initial=1)) + # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards + # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1? + return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1 + if self.op is Ops.PERMUTE: return self.arg.index(src_axis) if src_axis is not None else None + return src_axis @property def real(self):