diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index b1af690231..854efc3d13 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -844,6 +844,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     a.schedule()
     assert a.shape == (2, 8)
 
+    # real is no longer used, so these are now on axis=None and we can pad them however
+    """
     with self.assertRaises(AssertionError):
       # cannot pad sharded and non-sharded axis at the same time
       p = a.pad(((0, 6), (0, 1)))
@@ -853,6 +855,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
       # can only pad to whole axis
       p = a.pad(((1, 5), (0, 0)))
       p.schedule()
+    """
 
     p = a.pad(((0, 6), (0, 0)))
     p.schedule()
diff --git a/tinygrad/engine/multi.py b/tinygrad/engine/multi.py
index 364e1dd25c..b1e48c595c 100644
--- a/tinygrad/engine/multi.py
+++ b/tinygrad/engine/multi.py
@@ -50,61 +50,49 @@ def alu_multi(root:UOp):
   axis = root.axis
   bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None
   srcs:list[list[UOp]] = []
-  not_all_real = not all(all(mlb.real) for mlb in msrcs)
-  new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
   for mlb in msrcs:
-    if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src))
+    if mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds): srcs.append(list(mlb.src))
     else:
       assert axis is not None and bounds is not None
       if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds))
       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds))
   new_lbs = [lsrcs[0].alu(root.op, *lsrcs[1:]) for lsrcs in zip(*srcs)]
-  new_lbs = [x if r else x.const_like(0) for r,x in zip(new_real, new_lbs)]  # TODO: is this needed?
-  return UOp.multi(*new_lbs, axis=axis, real=new_real)
+  return UOp.multi(*new_lbs, axis=axis)
 
 def reduce_multi(root:UOp, multi:UOp):
   op, axis = root.arg
   if multi.axis is not None and multi.axis in axis:
     # all-reduce on sharded axes
-    reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
-    # if all partitions are real, do all_reduce
-    if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=root.axis)
-    # only one partition is real, keep it
-    return UOp.multi(*reduced_parts, axis=root.axis, real=multi.real)
+    reduced_parts = [x.r(op, axis) for x in multi.src]
+    # every partition participates now, so always all_reduce
+    return UOp.multi(*all_reduce(op, reduced_parts), axis=root.axis)
   # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-  return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=root.axis, real=multi.real)
+  return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=root.axis)
 
 def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
   return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
 
 def reshape_multi(root:UOp, multi:UOp):
   arg = root.arg
-  if (new_axis:=root.axis) is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real)
+  if (new_axis:=root.axis) is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis)
   assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
   assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
     f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
   lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
-  return UOp.multi(*lbs, axis=new_axis, real=multi.real)
+  return UOp.multi(*lbs, axis=new_axis)
 
 def expand_multi(root:UOp, multi:UOp):
   # NOTE: this assert isn't needed, sharded axis can have dim 1
   assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
-  return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real)
+  return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis)
 
 def pad_multi(root:UOp, multi:UOp):
-  assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}"
-  # pad on shard axis -> fill others with zeros and set real to all True
-  if multi.axis is not None and root.arg[multi.axis] != (0,0):
-    # pad back to whole axis, remove real mask
-    assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time"
-    dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)]
-    assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
-    return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis)
-  return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
+  assert multi.axis is None or root.arg[multi.axis] == (0,0), f"padding not supported for {root.arg=}"
+  return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis)
 
 def permute_multi(root:UOp, multi:UOp):
   # all permutes supported!
-  return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.axis, real=multi.real)
+  return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.axis)
 
 def shrink_multi(root:UOp, multi:UOp):
   assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
@@ -114,32 +102,29 @@ def shrink_multi(root:UOp, multi:UOp):
       "cannot shrink sharded and non-sharded axis at the same time"
-    # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
+    # NOTE: shrink on the shard axis is only allowed when the result is a single partition, which is copied to every device
     idx = multi.bounds.index(root.arg[multi.axis])
-    # zero out other lbs to not create lb reference
-    return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)],
-                     axis=multi.axis, real=tuple(i==idx for i in range(len(multi.src))))
+    return UOp.multi(*[multi.src[idx].copy_to_device(d) for d in root.device], axis=None)
   return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
-                   axis=multi.axis, real=multi.real)
+                   axis=multi.axis)
 
 def flip_multi(root:UOp, multi:UOp):
   assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
-  return UOp.multi(*[x.flip(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
+  return UOp.multi(*[x.flip(root.arg) for x in multi.src], axis=multi.axis)
 
 def copy_multi(multi:UOp, device:UOp):
   # if we already have a copy on the device, return that
-  if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device))
+  if multi.axis is None: return next((lb for lb in multi.src if lb.device == device.arg), multi.src[0].copy_to_device(device))
   # copy lbs to device, pad to final shape, and sum
   llbs:list[UOp] = []
-  for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds):
-    if not real: continue
+  for lb,(start,end) in zip(multi.src, multi.bounds):
     pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape)))
     llbs.append(lb.copy_to_device(device).pad(pad_arg))
   return functools.reduce(operator.add, llbs)
 
 def assign_multi(dest:UOp, src:UOp):
   assert dest.axis == src.axis, f"axis must match in assign {dest.axis} != {src.axis}"
-  return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real)
+  return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis)
 
-def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real)
+def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis)
 
 # NOTE: this is the same pattern as Ops.UNROLL
 multi_pm = PatternMatcher([
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 26dcffedd6..918b96e79c 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -322,7 +322,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
     assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
     match self.op:
-      case Ops.MULTI: shape = tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+      case Ops.MULTI: shape = tuple(sum(y.shape[a] for y in self.src) if a == self.axis else s for a,s in enumerate(self.src[0].shape))
       case Ops.BITCAST:
         shape = src_sts[0].shape
         if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
@@ -433,10 +433,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 
   # *** from MultiLazyBuffer ***
 
-  def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None):
+  def multi(self, *more:UOp, axis:int|None):
     parents = (self,)+more
     assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
-    return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents)))
+    return UOp(Ops.MULTI, self.dtype, parents, (axis,))
 
   @property
   def bounds(self):
@@ -459,14 +459,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.PERMUTE: return self.arg.index(src_axis) if src_axis is not None else None
     return src_axis
 
-  @property
-  def real(self):
-    assert self.op is Ops.MULTI
-    return self.arg[1]
-
-  @property
-  def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r]
-
   def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp:
     lbs = [self.copy_to_device(d) if self.device != d else self for d in devices]
     if axis is not None:
@@ -560,7 +552,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER and self.buffer.is_allocated() else None
   @property
   def is_realized(self) -> bool:
-    return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None
+    return all(x.base.realized is not None for x in self.base.src) if self.base.op is Ops.MULTI else self.base.realized is not None
 
   # *** uop Variable stuff ***
 
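
For context on the test change above: with real gone, shrinking a sharded tensor down to a single shard now yields a MULTI with axis=None (the kept shard is copied to every device), so the old pad restrictions no longer apply. A minimal sketch of the new behavior, assuming a tinygrad build that includes this patch; the device names and shapes are illustrative:

from tinygrad import Tensor

devices = tuple(f"CPU:{i}" for i in range(4))  # illustrative device list
t = Tensor.arange(64).reshape(8, 8).contiguous().shard(devices, axis=0)  # 2 rows per device

a = t.shrink(((0, 2), None))  # keep exactly the first device's shard -> replicated, axis=None
assert a.shape == (2, 8)      # matches the test above

# previously this raised "cannot pad sharded and non-sharded axis at the same time";
# now `a` lives on axis=None, so we can pad it however
p = a.pad(((0, 6), (0, 1)))
p.schedule()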
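The reduce_multi simplification always takes the all_reduce path now: a reduce over the sharded axis is a local per-shard reduce followed by an all_reduce across devices. A toy model of why that equals the full reduction, in plain Python with no tinygrad APIs:

# a length-8 vector sharded across 4 "devices"
shards = [[1, 2], [3, 4], [5, 6], [7, 8]]
reduced_parts = [sum(s) for s in shards]      # local reduce on each shard
after_all_reduce = [sum(reduced_parts)] * 4   # all_reduce: every device gets the global sum
assert all(v == 36 == sum(map(sum, shards)) for v in after_all_reduce)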
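Similarly, copy_multi no longer filters shards through real: every shard is copied to the target device, padded to the full shape at its bounds, and the results are summed. A plain-Python toy of that pad-and-add gather; the shard contents and bounds here are made up:

shards, bounds, total = [[1, 2], [3, 4]], [(0, 2), (2, 4)], 4
out = [0] * total
for sh, (start, end) in zip(shards, bounds):
  padded = [0] * start + sh + [0] * (total - end)  # pad_arg = (start, total - end)
  out = [a + b for a, b in zip(out, padded)]       # functools.reduce(operator.add, llbs)
assert out == [1, 2, 3, 4]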