mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
don't use UOp.multi in Tensor.rand (#10362)
This commit is contained in:
@@ -449,6 +449,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.MULTI: return self.arg
|
||||
# NOTE: they all have to share an axis, we always choose [-1]
|
||||
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
|
||||
if len(self.src) == 0: return None
|
||||
src_axis = self.src[0].axis
|
||||
if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
|
||||
if self.op is Ops.RESHAPE:
|
||||
|
||||
@@ -8,7 +8,7 @@ try:
|
||||
# IDIV is truncated division but z3 does floored division; mod by power of two sometimes uses Ops.AND
|
||||
def z3_cdiv(a,b): return z3.If(a<0, (a+(b-1))/b, a/b)
|
||||
z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
|
||||
Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b}
|
||||
Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If}
|
||||
def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
|
||||
s = z3.Int(name, ctx=solver.ctx)
|
||||
solver.add(vmin <= s, s <= vmax)
|
||||
|
||||
@@ -723,11 +723,7 @@ class Tensor(SimpleMathTrait):
|
||||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
if isinstance(self.device, tuple):
|
||||
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
|
||||
if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
|
||||
contiguous = kwargs.pop("contiguous", True)
|
||||
sharded_shape = tuple(s//len(self.device) if a==self.lazydata.axis else s for a,s in enumerate(self.shape))
|
||||
rands = [Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for d in self.device]
|
||||
return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
|
||||
return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device, self.lazydata.axis)
|
||||
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
|
||||
|
||||
# ***** rng hlops *****
|
||||
|
||||
Reference in New Issue
Block a user