From 7cc35a031b2e06d6e6127cfc3f04ec7e870af5f4 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 16 May 2025 16:09:36 -0700 Subject: [PATCH] don't use UOp.multi in Tensor.rand (#10362) --- tinygrad/ops.py | 1 + tinygrad/spec.py | 2 +- tinygrad/tensor.py | 6 +----- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d7fdb84a3f..0bd1ffccd3 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -449,6 +449,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.MULTI: return self.arg # NOTE: they all have to share an axis, we always choose [-1] if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None + if len(self.src) == 0: return None src_axis = self.src[0].axis if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis if self.op is Ops.RESHAPE: diff --git a/tinygrad/spec.py b/tinygrad/spec.py index 0369651690..e13757678a 100644 --- a/tinygrad/spec.py +++ b/tinygrad/spec.py @@ -8,7 +8,7 @@ try: # IDIV is truncated division but z3 does floored division; mod by power of two sometimes uses Ops.AND def z3_cdiv(a,b): return z3.If(a<0, (a+(b-1))/b, a/b) z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()), - Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b} + Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If} def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef: s = z3.Int(name, ctx=solver.ctx) solver.add(vmin <= s, s <= vmax) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 4aa2b3a15a..fd4e32f98e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -723,11 +723,7 @@ class Tensor(SimpleMathTrait): dtype = kwargs.pop("dtype", self.dtype) if isinstance(self.device, tuple): if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor") - if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device) - contiguous = kwargs.pop("contiguous", True) - sharded_shape = tuple(s//len(self.device) if a==self.lazydata.axis else s for a,s in enumerate(self.shape)) - rands = [Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for d in self.device] - return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) + return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device, self.lazydata.axis) return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs) # ***** rng hlops *****