mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
multi doesn't need tuple arg anymore [pr] (#10307)
This commit is contained in:
@@ -436,7 +436,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def multi(self, *more:UOp, axis:int|None):
|
||||
parents = (self,)+more
|
||||
assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
|
||||
return UOp(Ops.MULTI, self.dtype, parents, (axis,))
|
||||
return UOp(Ops.MULTI, self.dtype, parents, axis)
|
||||
|
||||
@property
|
||||
def bounds(self):
|
||||
@@ -445,7 +445,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
@functools.cached_property
|
||||
def axis(self) -> Optional[int]:
|
||||
if self.op is Ops.MULTI: return self.arg[0]
|
||||
if self.op is Ops.MULTI: return self.arg
|
||||
# NOTE: they all have to share an axis, we always choose [-1]
|
||||
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
|
||||
src_axis = self.src[0].axis
|
||||
|
||||
Reference in New Issue
Block a user