Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00
make Tensor init the only caller of Tensor.from_uop (#15421)
* make Tensor init the only caller of Tensor.from_uop (prep broadcast cleanups)
* type
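At a high level, the patch routes every UOp-to-Tensor conversion through the constructor and lets __init__ handle the weakint lowering and device placement, instead of calling Tensor.from_uop at each call site. A minimal sketch of the construction path this enables, assuming Variable is exported from the package root and "CPU" is an available device on your build:

from tinygrad import Tensor, Variable

v = Variable("n", 1, 10).bind(3)   # a bound symbolic int; its dtype is dtypes.weakint
t = Tensor(v, device="CPU")        # __init__ lowers the weakint and honours the requested device
print(t.device, t.item())          # expected: CPU 3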
@@ -134,7 +134,7 @@ class Tensor(OpMixin):
     if isinstance(data, UOp):
       assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.weakint, f"dtype mismatch: {_dtype} vs {data.dtype}"
       # if data is dtype.weakint that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
-      if data.dtype == dtypes.weakint: data = Tensor.from_uop(data).uop
+      if data.dtype == dtypes.weakint: data = Tensor.from_uop(data, device=_device).uop
     elif data is None:
       data = UOp.const(_dtype or dtypes.default_float, 0, _device)
     elif isinstance(data, get_args(ConstType)):
@@ -1822,7 +1822,7 @@ class Tensor(OpMixin):
     output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
     numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
     denominator = prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
-    return numerator.div(Tensor.from_uop(denominator, device=numerator.device) if isinstance(denominator, UOp) else denominator).cast(output_dtype)
+    return numerator.div(denominator).cast(output_dtype)

   def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
     """
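For context, mean divides the sum by the product of the reduced axis sizes; that product is a plain int for concrete shapes and a UOp for symbolic ones, and div (see the signature change further down) now takes either directly. A quick concrete check, with illustrative values:

from tinygrad import Tensor

x = Tensor.ones(3, 4)
numerator = x.sum(axis=1)                  # sum over the reduced axis, shape (3,)
denominator = 4                            # prod of the reduced sizes; a UOp when symbolic
print((numerator / denominator).numpy())   # expected: [1. 1. 1.]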
@@ -1848,7 +1848,8 @@ class Tensor(OpMixin):
     """
     squares = (self - self.mean(axis=axis, keepdim=True)).square()
     n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
-    denominator = (Tensor.from_uop(n, device=self.device) if isinstance(n, UOp) else Tensor(n, device=self.device)) - correction
+    denominator = Tensor(n, device=self.device) - correction
+    # TODO: infer device and remove relu
     return squares.sum(axis=axis, keepdim=keepdim).div(denominator.relu())

   def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
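For reference, this hunk computes the corrected sample variance: the sum of squared deviations divided by max(n - correction, 0), with the relu providing the clamp (hence the TODO). A small numeric check against numpy's ddof, values illustrative:

from tinygrad import Tensor
import numpy as np

x = Tensor([1.0, 2.0, 3.0, 4.0])
# denominator = relu(n - correction) = 4 - 1 = 3, matching numpy's ddof=1
print(x.var(correction=1).numpy(), np.var(x.numpy(), ddof=1))   # both ~1.6667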
@@ -2954,9 +2955,8 @@ class Tensor(OpMixin):
     # make y a Tensor
     assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
     if y is Invalid or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
-    elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
-    if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
-    else: y = Tensor(y_dtype.const(y), x.device, y_dtype, requires_grad=False)
+    else: y_dtype = y.dtype if isinstance(y, UOp) else dtypes.from_py(y)
+    y = Tensor(y, x.device, y_dtype, requires_grad=False)

     if match_dtype and x.dtype != y.dtype:
       output_dtype = least_upper_dtype(x.dtype, y.dtype)
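With the dedicated from_uop branch gone, a UOp operand is wrapped by the plain Tensor(y, x.device, y_dtype, ...) call like any other constant. A rough usage sketch, assuming Variable is exported from the package root and binds as in the symbolic-shape tests:

from tinygrad import Tensor, Variable

n = Variable("n", 1, 10).bind(4)   # a bound symbolic int operand
t = Tensor([4.0, 8.0, 12.0])
print((t / n).numpy())             # expected roughly: [1. 2. 3.]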
@@ -2993,7 +2993,7 @@ class Tensor(OpMixin):
     a, b = self._broadcasted(x, reverse)
     return a + (-b)

-  def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
+  def div(self, x:Tensor|ConstType|UOp, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
     """
     Divides `self` by `x`.
     Equivalent to `self / x`.
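The annotation change just documents that div also accepts a UOp operand; the runtime behaviour of the rounding modes is untouched. For reference, with concrete ints:

from tinygrad import Tensor

print(Tensor([7, -7]).div(2).numpy())                         # true division: [ 3.5 -3.5]
print(Tensor([7, -7]).div(2, rounding_mode="trunc").numpy())  # [ 3 -3]
print(Tensor([7, -7]).div(2, rounding_mode="floor").numpy())  # [ 3 -4]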