don't use UOp.multi in Tensor.rand (#10362)

This commit is contained in:
George Hotz
2025-05-16 16:09:36 -07:00
committed by GitHub
parent 7703dbef99
commit 7cc35a031b
3 changed files with 3 additions and 6 deletions

View File

@@ -449,6 +449,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.MULTI: return self.arg
# NOTE: they all have to share an axis, we always choose [-1]
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
if len(self.src) == 0: return None
src_axis = self.src[0].axis
if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
if self.op is Ops.RESHAPE:

View File

@@ -8,7 +8,7 @@ try:
# IDIV is truncated division but z3 does floored division; mod by power of two sometimes uses Ops.AND
# Truncating (round-toward-zero) integer division expressed with z3's `/`, which floors (see note above).
# When the numerator is negative, b-1 is added before dividing so the floored quotient lands on the truncated one.
# NOTE(review): that bias only reproduces truncation when b > 0 -- confirm no caller passes a negative divisor.
def z3_cdiv(a,b): return z3.If(a<0, (a+(b-1))/b, a/b)
z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b}
Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If}
def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
s = z3.Int(name, ctx=solver.ctx)
solver.add(vmin <= s, s <= vmax)

View File

@@ -723,11 +723,7 @@ class Tensor(SimpleMathTrait):
dtype = kwargs.pop("dtype", self.dtype)
if isinstance(self.device, tuple):
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
contiguous = kwargs.pop("contiguous", True)
sharded_shape = tuple(s//len(self.device) if a==self.lazydata.axis else s for a,s in enumerate(self.shape))
rands = [Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for d in self.device]
return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device, self.lazydata.axis)
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
# ***** rng hlops *****