Add consistent typing for python 3.10 in tensor.py (#9326)

* add consistent typing for python 3.10 in tensor.py

* pull

* Tensor.copysign (#9329)

* fast amd gemm (#9318)

* 50 TFLOP AMD gemm

* add lds tiling

* register tiling

* flip locals

* work

* comment

* remove those

* fix Tensor.view with a tuple arg (#9330)

* reorder binops (#9328)

* reorder binops

* test improvements + fix string tests

* ugh, okay this

* Make const moving not depend on the order (#9245)

Since floats are no longer being flipped, this should help with const folding for floats (see the illustrative sketch below, after the co-author list).

* use empty for test instead of rand (#9332)

* linter

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com>
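As an aside on the const-folding bullet above: the sketch below is a generic, hypothetical illustration of what constant folding means (collapsing an operation whose operands are all constants). It is not tinygrad's rewrite machinery; the `fold` helper and the tuple-based expression encoding are invented for this example only.

```python
# Hypothetical mini constant folder over tuple-encoded expressions; NOT tinygrad code.
def fold(expr):
    # expr is either a leaf (number or variable name) or ("add" | "mul", lhs, rhs)
    if not isinstance(expr, tuple): return expr
    op, lhs, rhs = expr[0], fold(expr[1]), fold(expr[2])
    if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)):
        return lhs + rhs if op == "add" else lhs * rhs  # both operands constant: fold
    return (op, lhs, rhs)

print(fold(("add", ("mul", 2.0, 3.0), "x")))  # expected: ('add', 6.0, 'x')
```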
Author: Friedrich Carl Eichenroth
Date: 2025-03-03 21:00:17 +00:00
Committed by: GitHub
Parent: 019417743c
Commit: 94db8426cb


@@ -2,7 +2,7 @@
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
from contextlib import ContextDecorator
from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
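The hunk above drops `Optional` and `Union` from the `typing` import. As a quick illustration (not part of the diff, assuming Python 3.10+), the two spellings the rest of the patch swaps between look like this:

```python
# Pre-3.10 style, as on the removed lines:
from typing import Optional, Union
def to_old(device: Optional[Union[str, tuple[str, ...]]]) -> None: ...

# PEP 604 style used on the added lines (built-in | unions, no typing imports needed):
def to_new(device: str | tuple[str, ...] | None) -> None: ...
```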
@@ -46,7 +46,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
# **** Tensor helper functions ****
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None):
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
@@ -63,7 +63,7 @@ def get_shape(x) -> tuple[int, ...]:
if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
return (len(subs),) + (subs[0] if subs else ())
def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp:
if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
else:
ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
@@ -74,7 +74,7 @@ def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
return ret
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:Union[str, tuple[str, ...]], dtype:DType) -> list[list[Tensor]]:
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...], dtype:DType) -> list[list[Tensor]]:
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
for k in range(len(mat[0]))] for dim in range(dims)]
@@ -126,18 +126,18 @@ class Tensor(SimpleMathTrait):
training: ClassVar[bool] = False
no_grad: ClassVar[bool] = False
def __init__(self, data:Union[None, ConstType, bytes, list, tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None):
if dtype is not None: dtype = to_dtype(dtype)
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
# tensors can have gradients if you have called .backward
self.grad: Optional[Tensor] = None
self.grad:Tensor|None = None
# NOTE: this can be in three states. False and None: no gradient, True: gradient
# None (the default) will be updated to True if it's put in an optimizer
self.requires_grad: Optional[bool] = requires_grad
self.requires_grad:bool|None = requires_grad
# create a LazyBuffer from the different types of inputs
if isinstance(data, UOp):
@@ -182,7 +182,7 @@ class Tensor(SimpleMathTrait):
needs_input_grad = [t.requires_grad for t in (self,)+x]
return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
def _apply_broadcasted_uop(self, fxn:Callable, x:Tensor|ConstType, reverse=False) -> Tensor:
lhs,rhs = self._broadcasted(x, reverse)
return lhs._apply_uop(fxn, rhs)
@@ -215,7 +215,7 @@ class Tensor(SimpleMathTrait):
return self.shape[0]
@property
def device(self) -> Union[str, tuple[str, ...]]: return self.lazydata.device
def device(self) -> str|tuple[str, ...]: return self.lazydata.device
@property
def shape(self) -> tuple[sint, ...]: return self.lazydata.shape
@@ -326,7 +326,7 @@ class Tensor(SimpleMathTrait):
# TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
def tolist(self) -> Union[Sequence[ConstType], ConstType]:
def tolist(self) -> Sequence[ConstType]|ConstType:
"""
Returns the value of this tensor as a nested list.
Returns single value for const tensor.
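A small usage sketch for `tolist` (not part of the diff); the printed values are what I'd expect from the documented behavior:

```python
from tinygrad import Tensor
print(Tensor([[1, 2], [3, 4]]).tolist())  # expected: [[1, 2], [3, 4]]
print(Tensor(3).tolist())                 # expected: 3 (single value for a const tensor)
```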
@@ -365,7 +365,7 @@ class Tensor(SimpleMathTrait):
if self.grad is not None: ret.grad = self.grad.clone()
return ret
def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
def to(self, device:str|tuple[str, ...]|None) -> Tensor:
"""
Moves the tensor to the given device.
"""
@@ -376,7 +376,7 @@ class Tensor(SimpleMathTrait):
if self.grad is not None: ret.grad = self.grad.to(device)
return ret
def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
def to_(self, device:str|tuple[str, ...]|None):
"""
Moves the tensor to the given device in place.
"""
@@ -384,7 +384,7 @@ class Tensor(SimpleMathTrait):
if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
return self.replace(real)
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
def shard(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
"""
Shards the tensor across the given devices. Optionally specify which axis to shard on.
@@ -398,7 +398,7 @@ class Tensor(SimpleMathTrait):
mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
def shard_(self, devices:tuple[str, ...], axis:int|None=None):
"""
Shards the tensor across the given devices in place.
"""
@@ -415,7 +415,7 @@ class Tensor(SimpleMathTrait):
# ***** creation entrypoint *****
@staticmethod
def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs):
dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
if isinstance(device, tuple):
return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
@@ -493,7 +493,7 @@ class Tensor(SimpleMathTrait):
return counts0.cat(counts1)
@staticmethod
def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor:
def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
@@ -633,7 +633,7 @@ class Tensor(SimpleMathTrait):
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
@staticmethod
def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
def linspace(start:int|float, stop:int|float, steps:int, **kwargs) -> Tensor:
"""
Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.
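Illustrative use of `linspace` (not in the diff), matching the docstring's inclusive endpoints:

```python
from tinygrad import Tensor
# 5 evenly spaced values from 0 to 1, endpoints included
print(Tensor.linspace(0, 1, 5).tolist())  # expected: [0.0, 0.25, 0.5, 0.75, 1.0]
```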
@@ -653,7 +653,7 @@ class Tensor(SimpleMathTrait):
return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
@staticmethod
def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
def eye(n:int, m:int|None=None, **kwargs) -> Tensor:
"""
Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.
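A quick sketch of `eye` with the optional `m` column count (not in the diff); the outputs are my expectation from the docstring:

```python
from tinygrad import Tensor
print(Tensor.eye(2).tolist())     # expected: [[1.0, 0.0], [0.0, 1.0]]
print(Tensor.eye(2, 3).tolist())  # expected: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```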
@@ -740,7 +740,7 @@ class Tensor(SimpleMathTrait):
# ***** rng hlops *****
@staticmethod
def randn(*shape, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
def randn(*shape, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
If `dtype` is not specified, the default type is used.
@@ -777,7 +777,7 @@ class Tensor(SimpleMathTrait):
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
def normal(*shape, mean=0.0, std=1.0, requires_grad:bool|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
@@ -792,7 +792,7 @@ class Tensor(SimpleMathTrait):
return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)
@staticmethod
def uniform(*shape, low=0.0, high=1.0, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
@@ -883,7 +883,7 @@ class Tensor(SimpleMathTrait):
# ***** toposort and backward pass *****
def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
def gradient(self, *targets:Tensor, gradient:Tensor|None=None, materialize_grads=False) -> list[Tensor]:
"""
Compute the gradient of the targets with respect to self.
@@ -912,7 +912,7 @@ class Tensor(SimpleMathTrait):
# create returned Tensors
return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
def backward(self, gradient:Tensor|None=None) -> Tensor:
"""
Propagates the gradient of a tensor backwards through the computation graph.
If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
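An illustrative autograd round trip for `backward` (not part of the diff); since the loss is a scalar, no explicit `gradient` argument is needed:

```python
from tinygrad import Tensor
x = Tensor([2.0, 3.0], requires_grad=True)
loss = (x * x).sum()    # scalar loss
loss.backward()         # implicit gradient of 1.0
print(x.grad.tolist())  # expected: [4.0, 6.0]
```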
@@ -1007,7 +1007,7 @@ class Tensor(SimpleMathTrait):
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))
def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Tensor:
"""
Returns a tensor that shrinks each axis based on the input arg.
`arg` must have the same length as `self.ndim`.
@@ -1027,7 +1027,7 @@ class Tensor(SimpleMathTrait):
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
"""
Returns a tensor with padding applied based on the input `padding`.
@@ -1065,7 +1065,7 @@ class Tensor(SimpleMathTrait):
if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
# group padding
else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[Optional[tuple[sint, sint]]], padding))
else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding))
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
if mode == "constant":
@@ -1092,7 +1092,7 @@ class Tensor(SimpleMathTrait):
# ***** movement high level ops *****
def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
# wrap single index into a list
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
x, indices = self, list(indices)
@@ -1221,7 +1221,7 @@ class Tensor(SimpleMathTrait):
"""
return self._getitem(indices)
def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
def __setitem__(self, indices, v:Tensor|ConstType) -> None:
if isinstance(self.device, str) and self.device.startswith("DISK"):
self._getitem(indices).assign(v)
return
@@ -1293,7 +1293,7 @@ class Tensor(SimpleMathTrait):
# checks for shapes and number of dimensions delegated to cat
return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim)
def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
def repeat_interleave(self, repeats:int, dim:int|None=None) -> Tensor:
"""
Repeat elements of a tensor.
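A usage sketch for `repeat_interleave` (not in the diff); with `dim` left as `None` the tensor is flattened first, and the output below is what I'd expect:

```python
from tinygrad import Tensor
print(Tensor([1, 2, 3]).repeat_interleave(2).tolist())  # expected: [1, 1, 2, 2, 3, 3]
```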
@@ -1331,7 +1331,7 @@ class Tensor(SimpleMathTrait):
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
return dim + total if dim < 0 else dim
def split(self, sizes:Union[int, list[int]], dim:int=0) -> tuple[Tensor, ...]:
def split(self, sizes:int|Sequence[int], dim:int=0) -> tuple[Tensor, ...]:
"""
Splits the tensor into chunks along the dimension specified by `dim`.
If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
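An illustrative `split` call (not in the diff), showing the smaller final chunk the docstring mentions:

```python
from tinygrad import Tensor
a, b, c = Tensor.arange(10).split(4)
print(a.shape, b.shape, c.shape)  # expected: (4,) (4,) (2,)
```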
@@ -1380,7 +1380,7 @@ class Tensor(SimpleMathTrait):
dim = self._resolve_dim(dim)
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> tuple[Tensor, ...]:
def meshgrid(self:Tensor, *args:Tensor, indexing:Literal["ij", "xy"]="ij") -> tuple[Tensor, ...]:
"""
Generates coordinate matrices from coordinate vectors.
Input tensors can be scalars or 1D tensors.
@@ -1407,7 +1407,7 @@ class Tensor(SimpleMathTrait):
output_shape = _broadcast_shape(*(t.shape for t in tensors))
return tuple(t._broadcast_to(output_shape) for t in tensors)
def squeeze(self, dim:Optional[int]=None) -> Tensor:
def squeeze(self, dim:int|None=None) -> Tensor:
"""
Returns a tensor with specified dimensions of input of size 1 removed.
If `dim` is not specified, all dimensions with size 1 are removed.
@@ -1497,7 +1497,7 @@ class Tensor(SimpleMathTrait):
dim = self._resolve_dim(dim)
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
def roll(self, shifts:Union[int, tuple[int, ...]], dims:Union[int, tuple[int, ...]]) -> Tensor:
def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]) -> Tensor:
"""
Rolls the tensor along specified dimension(s).
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
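A small sketch of the circular `roll` behavior (not in the diff):

```python
from tinygrad import Tensor
# elements pushed past the end wrap around to the front
print(Tensor([1, 2, 3, 4]).roll(shifts=1, dims=0).tolist())  # expected: [4, 1, 2, 3]
```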
@@ -1559,13 +1559,13 @@ class Tensor(SimpleMathTrait):
# ***** reduce ops *****
def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
if self.ndim == 0: axis = ()
ret = self._apply_uop(UOp.r, op=op, axis=axis)
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
"""
Returns the sum of the elements of the tensor along the specified axis or axes.
@@ -1592,7 +1592,7 @@ class Tensor(SimpleMathTrait):
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
"""
Returns the product of the elements of the tensor along the specified axis or axes.
@@ -1618,7 +1618,7 @@ class Tensor(SimpleMathTrait):
"""
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
def max(self, axis:int|Sequence[int]|None=None, keepdim=False):
"""
Returns the maximum value of the tensor along the specified axis or axes.
@@ -1643,7 +1643,7 @@ class Tensor(SimpleMathTrait):
def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
def min(self, axis:int|Sequence[int]|None=None, keepdim=False):
"""
Returns the minimum value of the tensor along the specified axis or axes.
@@ -1666,7 +1666,7 @@ class Tensor(SimpleMathTrait):
"""
return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()
def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
def any(self, axis:int|Sequence[int]|None=None, keepdim=False):
"""
Tests if any element evaluates to `True` along the specified axis or axes.
@@ -1688,7 +1688,7 @@ class Tensor(SimpleMathTrait):
"""
return self.bool().max(axis, keepdim)
def all(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
def all(self, axis:int|Sequence[int]|None=None, keepdim=False):
"""
Tests if all elements evaluate to `True` along the specified axis or axes.
@@ -1730,7 +1730,7 @@ class Tensor(SimpleMathTrait):
is_nan_close = (self.isnan() & other.isnan()) & equal_nan
return is_finite_close | is_infinite_close | is_nan_close
def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
def mean(self, axis:int|Sequence[int]|None=None, keepdim=False):
"""
Returns the mean value of the tensor along the specified axis or axes.
@@ -1756,7 +1756,7 @@ class Tensor(SimpleMathTrait):
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
"""
Returns the variance of the tensor along the specified axis or axes.
@@ -1782,7 +1782,7 @@ class Tensor(SimpleMathTrait):
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
def var_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
"""
Calculates the variance and mean over the dimensions specified by dim.
Syntactic sugar around `Tensor.var` and `Tensor.mean` to match `torch.var_mean`.
@@ -1799,7 +1799,7 @@ class Tensor(SimpleMathTrait):
"""
return self.var(axis, keepdim, correction), self.mean(axis, keepdim)
def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
"""
Returns the standard deviation of the tensor along the specified axis or axes.
@@ -1823,7 +1823,7 @@ class Tensor(SimpleMathTrait):
"""
return self.var(axis, keepdim, correction).sqrt()
def std_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
"""
Calculates the standard deviation and mean over the dimensions specified by dim.
Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
@@ -1840,13 +1840,13 @@ class Tensor(SimpleMathTrait):
"""
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
def _softmax(self, axis, dtype:DTypeLike|None=None):
m = self - self.max(axis=axis, keepdim=True).detach()
if dtype is not None: m = m.cast(dtype)
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)
def softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
def softmax(self, axis=-1, dtype:DTypeLike|None=None):
"""
Applies the softmax function to the tensor along the specified axis.
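An illustrative `softmax` call (not in the diff); each row should sum to roughly 1:

```python
from tinygrad import Tensor
p = Tensor([[1.0, 2.0, 3.0]]).softmax(axis=-1)
print(p.sum(axis=-1).tolist())  # expected: approximately [1.0]
```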
@@ -1869,7 +1869,7 @@ class Tensor(SimpleMathTrait):
_, e, ss = self._softmax(axis, dtype)
return e.div(ss)
def log_softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
def log_softmax(self, axis=-1, dtype:DTypeLike|None=None):
"""
Applies the log-softmax function to the tensor along the specified axis.
@@ -2005,7 +2005,7 @@ class Tensor(SimpleMathTrait):
return self._inverse().argmax(axis=axis, keepdim=keepdim)
@staticmethod
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
@@ -2051,7 +2051,7 @@ class Tensor(SimpleMathTrait):
# ***** processing ops *****
def _pool(self, k_:tuple[sint, ...], stride:Union[tuple[int, ...], int]=1, dilation:Union[tuple[int, ...], int]=1) -> Tensor:
def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Tensor:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
@@ -2076,12 +2076,12 @@ class Tensor(SimpleMathTrait):
x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
def _resolve_pool_pads(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:Union[tuple[int, ...], int], d_:Union[tuple[int, ...], int]) -> list[int]:
def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:int|tuple[int, ...], d_:int|tuple[int, ...]) -> list[int]:
(d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
pads, grouped_pads = list(pads), _flat_to_grouped(pads)
# https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
@@ -2181,8 +2181,8 @@ class Tensor(SimpleMathTrait):
if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
acc_dtype:Optional[DTypeLike]=None) -> Tensor:
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Applies a convolution over a tensor with a given `weight` and optional `bias`.
@@ -2255,7 +2255,7 @@ class Tensor(SimpleMathTrait):
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
def conv_transpose2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
"""
Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
@@ -2294,7 +2294,7 @@ class Tensor(SimpleMathTrait):
padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
def dot(self, w:Tensor, acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Performs dot product between two tensors.
@@ -2322,7 +2322,7 @@ class Tensor(SimpleMathTrait):
w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w)
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
def matmul(self, x:Tensor, reverse=False, acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Performs matrix multiplication between two tensors.
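A minimal `matmul` sketch (not in the diff); the values are my hand-computed expectation:

```python
from tinygrad import Tensor
a = Tensor([[1.0, 2.0], [3.0, 4.0]])
b = Tensor([[5.0], [6.0]])
print(a.matmul(b).tolist())  # expected: [[17.0], [39.0]]
```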
@@ -2487,7 +2487,7 @@ class Tensor(SimpleMathTrait):
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
return src, mask
def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
def scatter(self, dim:int, index:Tensor, src:Tensor|ConstType, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `add` or `multiply` reduction operation with `reduce`.
@@ -2552,7 +2552,7 @@ class Tensor(SimpleMathTrait):
```
"""
src, mask = self._pre_scatter(dim, index, src)
def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
# TODO: should not overwrite acc_dtype here?
if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
@@ -2821,7 +2821,7 @@ class Tensor(SimpleMathTrait):
"""
return (self.isinf()|self.isnan()).logical_not()
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor:
"""
Linearly interpolates between `self` and `end` by `weight`.
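Illustrative `lerp` usage (not in the diff), assuming the usual start + weight * (end - start) definition:

```python
from tinygrad import Tensor
start, end = Tensor([0.0, 10.0]), Tensor([10.0, 20.0])
print(start.lerp(end, 0.5).tolist())  # expected: [5.0, 15.0]
```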
@@ -3167,7 +3167,7 @@ class Tensor(SimpleMathTrait):
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
x: Tensor = self
if not isinstance(y, Tensor):
# make y a Tensor
@@ -3186,7 +3186,7 @@ class Tensor(SimpleMathTrait):
# broadcast
return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
def add(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Adds `self` and `x`.
Equivalent to `self + x`.
@@ -3206,7 +3206,7 @@ class Tensor(SimpleMathTrait):
"""
return self._apply_broadcasted_uop(UOp.add, x, reverse)
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Subtracts `x` from `self`.
Equivalent to `self - x`.
@@ -3227,7 +3227,7 @@ class Tensor(SimpleMathTrait):
a, b = self._broadcasted(x, reverse)
return a + (-b)
def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
def mul(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Multiplies `self` and `x`.
Equivalent to `self * x`.
@@ -3247,7 +3247,7 @@ class Tensor(SimpleMathTrait):
"""
return self._apply_broadcasted_uop(UOp.mul, x, reverse)
def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
def idiv(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Divides `self` by `x`.
Equivalent to `self // x`.
@@ -3260,7 +3260,7 @@ class Tensor(SimpleMathTrait):
"""
return self._apply_broadcasted_uop(UOp.idiv, x, reverse)
def div(self, x:Union[Tensor, ConstType], reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
"""
Divides `self` by `x`.
Equivalent to `self / x`.
@@ -3287,7 +3287,7 @@ class Tensor(SimpleMathTrait):
if rounding_mode is not None: raise RuntimeError(f"{rounding_mode=} is not supported")
return d
def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
def mod(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Mod `self` by `x`.
Equivalent to `self % x`.
@@ -3300,7 +3300,7 @@ class Tensor(SimpleMathTrait):
a, b = self._broadcasted(x, reverse)
return a - a.div(b, rounding_mode="floor") * b
def bitwise_xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
def bitwise_xor(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Computes bitwise xor of `self` and `x`.
Equivalent to `self ^ x`.
@@ -3316,7 +3316,7 @@ class Tensor(SimpleMathTrait):
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self._apply_broadcasted_uop(UOp.bitwise_xor, x, reverse)
def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
def bitwise_and(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Compute the bitwise AND of `self` and `x`.
Equivalent to `self & x`.
@@ -3331,7 +3331,7 @@ class Tensor(SimpleMathTrait):
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)
def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
def bitwise_or(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Compute the bitwise OR of `self` and `x`.
Equivalent to `self | x`.
@@ -3384,7 +3384,7 @@ class Tensor(SimpleMathTrait):
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
return self.idiv(2 ** x)
def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
def pow(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Computes power of `self` with `x`.
Equivalent to `self ** x`.
@@ -3407,7 +3407,7 @@ class Tensor(SimpleMathTrait):
ret = base._apply_uop(UOp.pow, exponent)
return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
def maximum(self, x:Tensor|ConstType) -> Tensor:
"""
Computes element-wise maximum of `self` and `x`.
@@ -3420,7 +3420,7 @@ class Tensor(SimpleMathTrait):
"""
return self._apply_broadcasted_uop(UOp.maximum, x)
def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
def minimum(self, x:Tensor|ConstType) -> Tensor:
"""
Computes element-wise minimum of `self` and `x`.
@@ -3434,7 +3434,7 @@ class Tensor(SimpleMathTrait):
t, x = self._broadcasted(x)
return t._inverse().maximum(x._inverse())._inverse()
def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint):
"""
Return a tensor of elements selected from either `x` or `y`, depending on `self`.
`output_i = x_i if self_i else y_i`.
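A short `where` sketch (not in the diff), following the `output_i = x_i if self_i else y_i` rule above:

```python
from tinygrad import Tensor
cond = Tensor([True, False, True])
print(cond.where(Tensor([1, 2, 3]), Tensor([10, 20, 30])).tolist())  # expected: [1, 20, 3]
```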
@@ -3458,7 +3458,7 @@ class Tensor(SimpleMathTrait):
cond, y = cond._broadcasted(y, match_dtype=False)
return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType): return mask.where(value, self)
def copysign(self, other) -> Tensor:
"""
@@ -3503,7 +3503,7 @@ class Tensor(SimpleMathTrait):
# ***** functional nn ops *****
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
def linear(self, weight:Tensor, bias:Tensor|None=None):
"""
Applies a linear transformation to `self` using `weight` and `bias`.
@@ -3530,7 +3530,7 @@ class Tensor(SimpleMathTrait):
"""
return functools.reduce(lambda x,f: f(x), ll, self)
def layernorm(self, axis:Union[int,tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
def layernorm(self, axis:int|tuple[int,...]=-1, eps:float=1e-5) -> Tensor:
"""
Applies Layer Normalization over a mini-batch of inputs.
@@ -3549,7 +3549,7 @@ class Tensor(SimpleMathTrait):
y = (self - self.mean(axis, keepdim=True))
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,tuple[int,...]]=1) -> Tensor:
def batchnorm(self, weight:Tensor|None, bias:Tensor|None, mean:Tensor, invstd:Tensor, axis:int|tuple[int, ...]=1) -> Tensor:
"""
Applies Batch Normalization over a mini-batch of inputs.
@@ -3722,7 +3722,7 @@ class Tensor(SimpleMathTrait):
ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1)
return ret._do_reduction(reduction)
def nll_loss(self, Y:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:ReductionStr="mean") -> Tensor:
def nll_loss(self, Y:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean") -> Tensor:
"""
Compute the negative log likelihood loss between log-probabilities and target labels.
@@ -3805,7 +3805,7 @@ class Tensor(SimpleMathTrait):
"""
return dtypes.is_float(self.dtype)
def size(self, dim:Optional[int]=None) -> Union[sint, tuple[sint, ...]]:
def size(self, dim:int|None=None) -> sint|tuple[sint, ...]:
"""
Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
@@ -3935,7 +3935,7 @@ class Tensor(SimpleMathTrait):
# *** image Tensor function replacements ***
def image_dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
def image_dot(self, w:Tensor, acc_dtype:DTypeLike|None=None) -> Tensor:
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
x, dx, dw = self, self.ndim, w.ndim
if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
@@ -3951,7 +3951,7 @@ class Tensor(SimpleMathTrait):
cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape