make Tensor init the only caller of Tensor.from_uop (#15421)

* make Tensor init the only caller of Tensor.from_uop

prep broadcast cleanups

* type
This commit is contained in:
chenyu
2026-03-23 00:29:08 -04:00
committed by GitHub
parent 67dcc79fdd
commit 248cd9b39f

View File

@@ -134,7 +134,7 @@ class Tensor(OpMixin):
if isinstance(data, UOp):
assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.weakint, f"dtype mismatch: {_dtype} vs {data.dtype}"
# if data is dtype.weakint that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
if data.dtype == dtypes.weakint: data = Tensor.from_uop(data).uop
if data.dtype == dtypes.weakint: data = Tensor.from_uop(data, device=_device).uop
elif data is None:
data = UOp.const(_dtype or dtypes.default_float, 0, _device)
elif isinstance(data, get_args(ConstType)):
@@ -1822,7 +1822,7 @@ class Tensor(OpMixin):
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
denominator = prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
return numerator.div(Tensor.from_uop(denominator, device=numerator.device) if isinstance(denominator, UOp) else denominator).cast(output_dtype)
return numerator.div(denominator).cast(output_dtype)
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
"""
@@ -1848,7 +1848,8 @@ class Tensor(OpMixin):
"""
squares = (self - self.mean(axis=axis, keepdim=True)).square()
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
denominator = (Tensor.from_uop(n, device=self.device) if isinstance(n, UOp) else Tensor(n, device=self.device)) - correction
denominator = Tensor(n, device=self.device) - correction
# TODO: infer device and remove relu
return squares.sum(axis=axis, keepdim=keepdim).div(denominator.relu())
def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
@@ -2954,9 +2955,8 @@ class Tensor(OpMixin):
# make y a Tensor
assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
if y is Invalid or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
else: y = Tensor(y_dtype.const(y), x.device, y_dtype, requires_grad=False)
else: y_dtype = y.dtype if isinstance(y, UOp) else dtypes.from_py(y)
y = Tensor(y, x.device, y_dtype, requires_grad=False)
if match_dtype and x.dtype != y.dtype:
output_dtype = least_upper_dtype(x.dtype, y.dtype)
@@ -2993,7 +2993,7 @@ class Tensor(OpMixin):
a, b = self._broadcasted(x, reverse)
return a + (-b)
def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
def div(self, x:Tensor|ConstType|UOp, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
"""
Divides `self` by `x`.
Equivalent to `self / x`.