diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 68c814de23..3fa3115d9d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -134,7 +134,7 @@ class Tensor(OpMixin): if isinstance(data, UOp): assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.weakint, f"dtype mismatch: {_dtype} vs {data.dtype}" # if data is dtype.weakint that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of - if data.dtype == dtypes.weakint: data = Tensor.from_uop(data).uop + if data.dtype == dtypes.weakint: data = Tensor.from_uop(data, device=_device).uop elif data is None: data = UOp.const(_dtype or dtypes.default_float, 0, _device) elif isinstance(data, get_args(ConstType)): @@ -1822,7 +1822,7 @@ class Tensor(OpMixin): output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32 numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim) denominator = prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)]) - return numerator.div(Tensor.from_uop(denominator, device=numerator.device) if isinstance(denominator, UOp) else denominator).cast(output_dtype) + return numerator.div(denominator).cast(output_dtype) def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor: """ @@ -1848,7 +1848,8 @@ class Tensor(OpMixin): """ squares = (self - self.mean(axis=axis, keepdim=True)).square() n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)]) - denominator = (Tensor.from_uop(n, device=self.device) if isinstance(n, UOp) else Tensor(n, device=self.device)) - correction + denominator = Tensor(n, device=self.device) - correction + # TODO: infer device and remove relu return squares.sum(axis=axis, keepdim=keepdim).div(denominator.relu()) def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]: @@ -2954,9 +2955,8 @@ class Tensor(OpMixin): # make y a Tensor assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}" if y is Invalid or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype - elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y) - if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device) - else: y = Tensor(y_dtype.const(y), x.device, y_dtype, requires_grad=False) + else: y_dtype = y.dtype if isinstance(y, UOp) else dtypes.from_py(y) + y = Tensor(y, x.device, y_dtype, requires_grad=False) if match_dtype and x.dtype != y.dtype: output_dtype = least_upper_dtype(x.dtype, y.dtype) @@ -2993,7 +2993,7 @@ class Tensor(OpMixin): a, b = self._broadcasted(x, reverse) return a + (-b) - def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor: + def div(self, x:Tensor|ConstType|UOp, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor: """ Divides `self` by `x`. Equivalent to `self / x`.