multi doesn't need tuple arg anymore [pr] (#10307)

This commit is contained in:
George Hotz
2025-05-14 15:16:40 -07:00
committed by GitHub
parent 7a3d4de59a
commit fc8ef63194

View File

@@ -436,7 +436,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def multi(self, *more:UOp, axis:int|None):
parents = (self,)+more
assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
return UOp(Ops.MULTI, self.dtype, parents, (axis,))
return UOp(Ops.MULTI, self.dtype, parents, axis)
@property
def bounds(self):
@@ -445,7 +445,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def axis(self) -> Optional[int]:
if self.op is Ops.MULTI: return self.arg[0]
if self.op is Ops.MULTI: return self.arg
# NOTE: they all have to share an axis, we always choose [-1]
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
src_axis = self.src[0].axis