diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 918b96e79c..cb976960a9 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -436,7 +436,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def multi(self, *more:UOp, axis:int|None): parents = (self,)+more assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype" - return UOp(Ops.MULTI, self.dtype, parents, (axis,)) + return UOp(Ops.MULTI, self.dtype, parents, axis) @property def bounds(self): @@ -445,7 +445,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def axis(self) -> Optional[int]: - if self.op is Ops.MULTI: return self.arg[0] + if self.op is Ops.MULTI: return self.arg # NOTE: they all have to share an axis, we always choose [-1] if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None src_axis = self.src[0].axis