diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py
index 384440d8f9..5eb22b8cc1 100644
--- a/tinygrad/features/multi.py
+++ b/tinygrad/features/multi.py
@@ -83,14 +83,15 @@ class MultiLazyBuffer:
   def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
     return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))

-  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
-    if self.axis is not None and new_shape[self.axis] == 1:
+  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
+    if self.axis is not None and self.axis in axis:
       # all-reduce on sharded axes
-      reduced_parts = [x.r(op, new_shape) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)]
+      new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+      reduced_parts = [x.r(op, axis) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)]
       if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
       return MultiLazyBuffer(reduced_parts, None, self.real)
     # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-    return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape, x)) for x in self.lbs], self.axis, self.real)
+    return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)

   # *** movement ops ***

diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 0c905da755..676e34a218 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -121,7 +121,8 @@ class LazyBuffer:
     unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))

-  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
+  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+    new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
     # TODO: this logic should move to the scheduler
     if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
     assert len(self.shape)==len(new_shape) and all(ns in (1,s) for s,ns in zip(self.shape,new_shape)), f"not a contraction {self.shape=} {new_shape=}"
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 517b17e3ae..d424330430 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -142,21 +142,21 @@ class Where(Function):
 # ************* reduce ops *************

 class Sum(Function):
-  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.r(ReduceOps.SUM, new_shape)
+    return x.r(ReduceOps.SUM, axis)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)

 class Max(Function):
-  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+    self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
     max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(self.x.dtype)
-    div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
+    div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
     return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))

 # ************* movement ops *************
@@ -164,10 +164,10 @@ class Max(Function):
 # NOTE: this is sum in reverse
 class Expand(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-    self.input_shape = x.shape
+    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
     return x.expand(shape)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.expanded_axis)

 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 99c0964a9a..94fed035f6 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -546,7 +546,7 @@ class Tensor:
     axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
     axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
     shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
-    ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
+    ret = fxn.apply(self, axis=axis_)
     return ret if keepdim else ret.reshape(shape=shape)

   def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
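
Note on the API change (not part of the diff): the old new_shape convention and the new axis convention carry the same information, related by writing 1s at the reduced positions. A minimal sketch of the round trip, using hypothetical helper names axis_to_new_shape and new_shape_to_axis:

from typing import Tuple

def axis_to_new_shape(shape:Tuple[int, ...], axis:Tuple[int, ...]) -> Tuple[int, ...]:
  # the expression now inlined in LazyBuffer.r and MultiLazyBuffer.r
  return tuple(1 if i in axis else s for i,s in enumerate(shape))

def new_shape_to_axis(shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
  # recovers reduced axes by comparing shapes, as Expand.forward now does for expanded_axis;
  # ambiguous for size-1 dims, but reducing an axis of size 1 is a no-op anyway
  return tuple(i for i,(s,ns) in enumerate(zip(shape, new_shape)) if s != ns)

assert axis_to_new_shape((2,3,4), (1,2)) == (2,1,1)
assert new_shape_to_axis((2,3,4), (2,1,1)) == (1,2)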