diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index d001c657dc..8c50b23a9f 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
-from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
+from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
@@ -46,7 +46,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
 
 # **** Tensor helper functions ****
 
-def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
+def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None):
   if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
   return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
@@ -63,7 +63,7 @@ def get_shape(x) -> tuple[int, ...]:
   if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
   return (len(subs),) + (subs[0] if subs else ())
 
-def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
+def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp:
   if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
   else:
     ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
@@ -74,7 +74,7 @@ def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
   ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
   return ret
 
-def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:Union[str, tuple[str, ...]], dtype:DType) -> list[list[Tensor]]:
+def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...], dtype:DType) -> list[list[Tensor]]:
   return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
           for k in range(len(mat[0]))] for dim in range(dims)]
@@ -126,18 +126,18 @@ class Tensor(SimpleMathTrait):
   training: ClassVar[bool] = False
   no_grad: ClassVar[bool] = False
 
-  def __init__(self, data:Union[None, ConstType, bytes, list, tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
-               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
+  def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
+               device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None):
     if dtype is not None: dtype = to_dtype(dtype)
     if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
 
     # tensors can have gradients if you have called .backward
-    self.grad: Optional[Tensor] = None
+    self.grad:Tensor|None = None
 
     # NOTE: this can be in three states. False and None: no gradient, True: gradient
     # None (the default) will be updated to True if it's put in an optimizer
-    self.requires_grad: Optional[bool] = requires_grad
+    self.requires_grad:bool|None = requires_grad
 
     # create a LazyBuffer from the different types of inputs
     if isinstance(data, UOp):
@@ -182,7 +182,7 @@ class Tensor(SimpleMathTrait):
     needs_input_grad = [t.requires_grad for t in (self,)+x]
     return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
 
-  def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+  def _apply_broadcasted_uop(self, fxn:Callable, x:Tensor|ConstType, reverse=False) -> Tensor:
     lhs,rhs = self._broadcasted(x, reverse)
     return lhs._apply_uop(fxn, rhs)
@@ -215,7 +215,7 @@ class Tensor(SimpleMathTrait):
     return self.shape[0]
 
   @property
-  def device(self) -> Union[str, tuple[str, ...]]: return self.lazydata.device
+  def device(self) -> str|tuple[str, ...]: return self.lazydata.device
 
   @property
   def shape(self) -> tuple[sint, ...]: return self.lazydata.shape
@@ -326,7 +326,7 @@ class Tensor(SimpleMathTrait):
 
   # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
   # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
-  def tolist(self) -> Union[Sequence[ConstType], ConstType]:
+  def tolist(self) -> Sequence[ConstType]|ConstType:
     """
     Returns the value of this tensor as a nested list. Returns single value for const tensor.
 
@@ -365,7 +365,7 @@ class Tensor(SimpleMathTrait):
     if self.grad is not None: ret.grad = self.grad.clone()
     return ret
 
-  def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
+  def to(self, device:str|tuple[str, ...]|None) -> Tensor:
     """
     Moves the tensor to the given device.
     """
@@ -376,7 +376,7 @@ class Tensor(SimpleMathTrait):
     if self.grad is not None: ret.grad = self.grad.to(device)
     return ret
 
-  def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
+  def to_(self, device:str|tuple[str, ...]|None):
     """
     Moves the tensor to the given device in place.
     """
@@ -384,7 +384,7 @@ class Tensor(SimpleMathTrait):
     if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
     return self.replace(real)
 
-  def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
+  def shard(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
     """
     Shards the tensor across the given devices. Optionally specify which axis to shard on.
 
@@ -398,7 +398,7 @@ class Tensor(SimpleMathTrait):
     mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
     return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
 
-  def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
+  def shard_(self, devices:tuple[str, ...], axis:int|None=None):
     """
     Shards the tensor across the given devices in place.
     """
@@ -415,7 +415,7 @@ class Tensor(SimpleMathTrait):
   # ***** creation entrypoint *****
 
   @staticmethod
-  def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
+  def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs):
     dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
     if isinstance(device, tuple):
       return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
@@ -493,7 +493,7 @@ class Tensor(SimpleMathTrait):
     return counts0.cat(counts1)
 
   @staticmethod
-  def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor:
+  def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
 
@@ -633,7 +633,7 @@ class Tensor(SimpleMathTrait):
     return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
 
   @staticmethod
-  def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
+  def linspace(start:int|float, stop:int|float, steps:int, **kwargs) -> Tensor:
     """
     Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.
 
@@ -653,7 +653,7 @@ class Tensor(SimpleMathTrait):
     return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
 
   @staticmethod
-  def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
+  def eye(n:int, m:int|None=None, **kwargs) -> Tensor:
     """
     Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.
 
@@ -740,7 +740,7 @@ class Tensor(SimpleMathTrait):
   # ***** rng hlops *****
 
   @staticmethod
-  def randn(*shape, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
+  def randn(*shape, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
     If `dtype` is not specified, the default type is used.
@@ -777,7 +777,7 @@ class Tensor(SimpleMathTrait):
     return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
 
   @staticmethod
-  def normal(*shape, mean=0.0, std=1.0, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
+  def normal(*shape, mean=0.0, std=1.0, requires_grad:bool|None=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
 
@@ -792,7 +792,7 @@ class Tensor(SimpleMathTrait):
     return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)
 
   @staticmethod
-  def uniform(*shape, low=0.0, high=1.0, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
+  def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
 
@@ -883,7 +883,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** toposort and backward pass *****
 
-  def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
+  def gradient(self, *targets:Tensor, gradient:Tensor|None=None, materialize_grads=False) -> list[Tensor]:
     """
     Compute the gradient of the targets with respect to self.
 
@@ -912,7 +912,7 @@ class Tensor(SimpleMathTrait):
     # create returned Tensors
     return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
 
-  def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
+  def backward(self, gradient:Tensor|None=None) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
     If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
@@ -1007,7 +1007,7 @@ class Tensor(SimpleMathTrait):
     if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
     return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))
 
-  def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
+  def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Tensor:
     """
     Returns a tensor that shrinks the each axis based on input arg. `arg` must have the same length as `self.ndim`.
 
@@ -1027,7 +1027,7 @@ class Tensor(SimpleMathTrait):
     if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
     return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
 
-  def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
+  def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
     """
     Returns a tensor with padding applied based on the input `padding`.
 
@@ -1065,7 +1065,7 @@ class Tensor(SimpleMathTrait):
       if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
       pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2)) # group padding
-    else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[Optional[tuple[sint, sint]]], padding))
+    else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding))
     if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
     X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
     if mode == "constant":
@@ -1092,7 +1092,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** movement high level ops *****
 
-  def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
+  def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
     # wrap single index into a list
     if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
     x, indices = self, list(indices)
@@ -1221,7 +1221,7 @@ class Tensor(SimpleMathTrait):
     """
     return self._getitem(indices)
 
-  def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
+  def __setitem__(self, indices, v:Tensor|ConstType) -> None:
     if isinstance(self.device, str) and self.device.startswith("DISK"):
       self._getitem(indices).assign(v)
       return
@@ -1293,7 +1293,7 @@ class Tensor(SimpleMathTrait):
     # checks for shapes and number of dimensions delegated to cat
     return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim)
 
-  def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
+  def repeat_interleave(self, repeats:int, dim:int|None=None) -> Tensor:
     """
     Repeat elements of a tensor.
 
@@ -1331,7 +1331,7 @@ class Tensor(SimpleMathTrait):
     if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
     return dim + total if dim < 0 else dim
 
-  def split(self, sizes:Union[int, list[int]], dim:int=0) -> tuple[Tensor, ...]:
+  def split(self, sizes:int|Sequence[int], dim:int=0) -> tuple[Tensor, ...]:
     """
     Splits the tensor into chunks along the dimension specified by `dim`.
     If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
@@ -1380,7 +1380,7 @@ class Tensor(SimpleMathTrait):
     dim = self._resolve_dim(dim)
     return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
 
-  def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> tuple[Tensor, ...]:
+  def meshgrid(self:Tensor, *args:Tensor, indexing:Literal["ij", "xy"]="ij") -> tuple[Tensor, ...]:
     """
     Generates coordinate matrices from coordinate vectors.
     Input tensors can be scalars or 1D tensors.
@@ -1407,7 +1407,7 @@ class Tensor(SimpleMathTrait):
     output_shape = _broadcast_shape(*(t.shape for t in tensors))
     return tuple(t._broadcast_to(output_shape) for t in tensors)
 
-  def squeeze(self, dim:Optional[int]=None) -> Tensor:
+  def squeeze(self, dim:int|None=None) -> Tensor:
     """
     Returns a tensor with specified dimensions of input of size 1 removed.
     If `dim` is not specified, all dimensions with size 1 are removed.
@@ -1497,7 +1497,7 @@ class Tensor(SimpleMathTrait):
     dim = self._resolve_dim(dim)
     return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
 
-  def roll(self, shifts:Union[int, tuple[int, ...]], dims:Union[int, tuple[int, ...]]) -> Tensor:
+  def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]) -> Tensor:
     """
     Rolls the tensor along specified dimension(s).
     The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
@@ -1559,13 +1559,13 @@ class Tensor(SimpleMathTrait):
 
   # ***** reduce ops *****
 
-  def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
+  def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
     axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
     if self.ndim == 0: axis = ()
     ret = self._apply_uop(UOp.r, op=op, axis=axis)
     return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
 
-  def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
+  def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
     """
     Returns the sum of the elements of the tensor along the specified axis or axes.
 
@@ -1592,7 +1592,7 @@ class Tensor(SimpleMathTrait):
     ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
     return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
 
-  def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
+  def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
     """
     Returns the product of the elements of the tensor along the specified axis or axes.
 
@@ -1618,7 +1618,7 @@ class Tensor(SimpleMathTrait):
     """
     return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
 
-  def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+  def max(self, axis:int|Sequence[int]|None=None, keepdim=False):
     """
     Returns the maximum value of the tensor along the specified axis or axes.
 
@@ -1643,7 +1643,7 @@ class Tensor(SimpleMathTrait):
 
   def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
 
-  def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+  def min(self, axis:int|Sequence[int]|None=None, keepdim=False):
     """
     Returns the minimum value of the tensor along the specified axis or axes.
 
@@ -1666,7 +1666,7 @@ class Tensor(SimpleMathTrait):
     """
     return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()
 
-  def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+  def any(self, axis:int|Sequence[int]|None=None, keepdim=False):
     """
     Tests if any element evaluates to `True` along the specified axis or axes.
 
@@ -1688,7 +1688,7 @@ class Tensor(SimpleMathTrait):
     """
     return self.bool().max(axis, keepdim)
 
-  def all(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+  def all(self, axis:int|Sequence[int]|None=None, keepdim=False):
     """
     Tests if all element evaluates to `True` along the specified axis or axes.
 
@@ -1730,7 +1730,7 @@ class Tensor(SimpleMathTrait):
     is_nan_close = (self.isnan() & other.isnan()) & equal_nan
     return is_finite_close | is_infinite_close | is_nan_close
 
-  def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+  def mean(self, axis:int|Sequence[int]|None=None, keepdim=False):
     """
     Returns the mean value of the tensor along the specified axis or axes.
 
@@ -1756,7 +1756,7 @@ class Tensor(SimpleMathTrait):
     numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
     return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
 
-  def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
+  def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
     """
     Returns the variance of the tensor along the specified axis or axes.
 
@@ -1782,7 +1782,7 @@ class Tensor(SimpleMathTrait):
     n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
     return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
 
-  def var_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
+  def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
     """
     Calculates the variance and mean over the dimensions specified by dim.
     Syntactic sugar around `Tensor.var` and `Tensor.mean` to match `torch.var_mean`.
@@ -1799,7 +1799,7 @@ class Tensor(SimpleMathTrait):
     """
     return self.var(axis, keepdim, correction), self.mean(axis, keepdim)
 
-  def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
+  def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
     """
     Returns the standard deviation of the tensor along the specified axis or axes.
 
@@ -1823,7 +1823,7 @@ class Tensor(SimpleMathTrait):
     """
     return self.var(axis, keepdim, correction).sqrt()
 
-  def std_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
+  def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
     """
     Calculates the standard deviation and mean over the dimensions specified by dim.
     Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
@@ -1840,13 +1840,13 @@ class Tensor(SimpleMathTrait):
     """
     return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
 
-  def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
+  def _softmax(self, axis, dtype:DTypeLike|None=None):
     m = self - self.max(axis=axis, keepdim=True).detach()
     if dtype is not None: m = m.cast(dtype)
     e = m.exp()
     return m, e, e.sum(axis=axis, keepdim=True)
 
-  def softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
+  def softmax(self, axis=-1, dtype:DTypeLike|None=None):
     """
     Applies the softmax function to the tensor along the specified axis.
 
@@ -1869,7 +1869,7 @@ class Tensor(SimpleMathTrait):
     _, e, ss = self._softmax(axis, dtype)
     return e.div(ss)
 
-  def log_softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
+  def log_softmax(self, axis=-1, dtype:DTypeLike|None=None):
     """
     Applies the log-softmax function to the tensor along the specified axis.
 
@@ -2005,7 +2005,7 @@ class Tensor(SimpleMathTrait):
     return self._inverse().argmax(axis=axis, keepdim=keepdim)
 
   @staticmethod
-  def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+  def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:DTypeLike|None=None) -> Tensor:
     """
     Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
 
@@ -2051,7 +2051,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** processing ops *****
 
-  def _pool(self, k_:tuple[sint, ...], stride:Union[tuple[int, ...], int]=1, dilation:Union[tuple[int, ...], int]=1) -> Tensor:
+  def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Tensor:
     assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
     s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
     assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
@@ -2076,12 +2076,12 @@ class Tensor(SimpleMathTrait):
     x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
     return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
 
-  def _resolve_pool_pads(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
+  def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
     if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
       raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
     return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
 
-  def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:Union[tuple[int, ...], int], d_:Union[tuple[int, ...], int]) -> list[int]:
+  def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:int|tuple[int, ...], d_:int|tuple[int, ...]) -> list[int]:
     (d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
     pads, grouped_pads = list(pads), _flat_to_grouped(pads)
     # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
@@ -2181,8 +2181,8 @@ class Tensor(SimpleMathTrait):
     if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
     return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
 
-  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
-             acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+  def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
+             acc_dtype:DTypeLike|None=None) -> Tensor:
     """
     Applies a convolution over a tensor with a given `weight` and optional `bias`.
 
@@ -2255,7 +2255,7 @@ class Tensor(SimpleMathTrait):
 
     return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
 
-  def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
+  def conv_transpose2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
     """
     Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
 
@@ -2294,7 +2294,7 @@ class Tensor(SimpleMathTrait):
     padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
     return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
 
-  def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+  def dot(self, w:Tensor, acc_dtype:DTypeLike|None=None) -> Tensor:
     """
     Performs dot product between two tensors.
 
@@ -2322,7 +2322,7 @@ class Tensor(SimpleMathTrait):
     w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w)
     return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
 
-  def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+  def matmul(self, x:Tensor, reverse=False, acc_dtype:DTypeLike|None=None) -> Tensor:
     """
     Performs matrix multiplication between two tensors.
 
@@ -2487,7 +2487,7 @@ class Tensor(SimpleMathTrait):
     src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
     return src, mask
 
-  def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
+  def scatter(self, dim:int, index:Tensor, src:Tensor|ConstType, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
     """
     Scatters `src` values along an axis specified by `dim`.
     Apply `add` or `multiply` reduction operation with `reduce`.
@@ -2552,7 +2552,7 @@ class Tensor(SimpleMathTrait):
     ```
     """
     src, mask = self._pre_scatter(dim, index, src)
-    def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
+    def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
     # TODO: should not overwrite acc_dtype here?
     if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
     if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
@@ -2821,7 +2821,7 @@ class Tensor(SimpleMathTrait):
     """
     return (self.isinf()|self.isnan()).logical_not()
 
-  def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
+  def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor:
     """
     Linearly interpolates between `self` and `end` by `weight`.
 
@@ -3167,7 +3167,7 @@ class Tensor(SimpleMathTrait):
       raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
     return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
 
-  def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
+  def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
     x: Tensor = self
     if not isinstance(y, Tensor):
       # make y a Tensor
@@ -3186,7 +3186,7 @@ class Tensor(SimpleMathTrait):
     # broadcast
     return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
 
-  def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+  def add(self, x:Tensor|ConstType, reverse=False) -> Tensor:
     """
     Adds `self` and `x`.
     Equivalent to `self + x`.
@@ -3206,7 +3206,7 @@ class Tensor(SimpleMathTrait):
     """
     return self._apply_broadcasted_uop(UOp.add, x, reverse)
 
-  def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+  def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor:
     """
     Subtracts `x` from `self`.
     Equivalent to `self - x`.
@@ -3227,7 +3227,7 @@ class Tensor(SimpleMathTrait):
     a, b = self._broadcasted(x, reverse)
     return a + (-b)
 
-  def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+  def mul(self, x:Tensor|ConstType, reverse=False) -> Tensor:
     """
     Multiplies `self` and `x`.
     Equivalent to `self * x`.
@@ -3247,7 +3247,7 @@ class Tensor(SimpleMathTrait):
     """
     return self._apply_broadcasted_uop(UOp.mul, x, reverse)
 
-  def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+  def idiv(self, x:Tensor|ConstType, reverse=False) -> Tensor:
     """
     Divides `self` by `x`.
     Equivalent to `self // x`.
@@ -3260,7 +3260,7 @@ class Tensor(SimpleMathTrait):
     """
     return self._apply_broadcasted_uop(UOp.idiv, x, reverse)
 
-  def div(self, x:Union[Tensor, ConstType], reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
+  def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
     """
     Divides `self` by `x`.
     Equivalent to `self / x`.
@@ -3287,7 +3287,7 @@ class Tensor(SimpleMathTrait):
     if rounding_mode is not None: raise RuntimeError(f"{rounding_mode=} is not supported")
     return d
 
-  def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+  def mod(self, x:Tensor|ConstType, reverse=False) -> Tensor:
     """
     Mod `self` by `x`.
     Equivalent to `self % x`.
@@ -3300,7 +3300,7 @@ class Tensor(SimpleMathTrait):
     a, b = self._broadcasted(x, reverse)
     return a - a.div(b, rounding_mode="floor") * b
 
-  def bitwise_xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+  def bitwise_xor(self, x:Tensor|ConstType, reverse=False) -> Tensor:
     """
     Computes bitwise xor of `self` and `x`.
     Equivalent to `self ^ x`.
@@ -3316,7 +3316,7 @@ class Tensor(SimpleMathTrait):
     if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
     return self._apply_broadcasted_uop(UOp.bitwise_xor, x, reverse)
 
-  def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+  def bitwise_and(self, x:Tensor|ConstType, reverse=False) -> Tensor:
     """
     Compute the bitwise AND of `self` and `x`.
     Equivalent to `self & x`.
@@ -3331,7 +3331,7 @@ class Tensor(SimpleMathTrait):
     if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
     return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)
 
-  def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+  def bitwise_or(self, x:Tensor|ConstType, reverse=False) -> Tensor:
     """
     Compute the bitwise OR of `self` and `x`.
     Equivalent to `self | x`.
@@ -3384,7 +3384,7 @@ class Tensor(SimpleMathTrait):
     assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
     return self.idiv(2 ** x)
 
-  def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+  def pow(self, x:Tensor|ConstType, reverse=False) -> Tensor:
     """
     Computes power of `self` with `x`.
     Equivalent to `self ** x`.
@@ -3407,7 +3407,7 @@ class Tensor(SimpleMathTrait):
     ret = base._apply_uop(UOp.pow, exponent)
     return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
 
-  def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
+  def maximum(self, x:Tensor|ConstType) -> Tensor:
     """
     Computes element-wise maximum of `self` and `x`.
 
@@ -3420,7 +3420,7 @@ class Tensor(SimpleMathTrait):
     """
     return self._apply_broadcasted_uop(UOp.maximum, x)
 
-  def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
+  def minimum(self, x:Tensor|ConstType) -> Tensor:
     """
     Computes element-wise minimum of `self` and `x`.
 
@@ -3434,7 +3434,7 @@ class Tensor(SimpleMathTrait):
     t, x = self._broadcasted(x)
     return t._inverse().maximum(x._inverse())._inverse()
 
-  def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
+  def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint):
     """
     Return a tensor of elements selected from either `x` or `y`, depending on `self`.
     `output_i = x_i if self_i else y_i`.
@@ -3458,7 +3458,7 @@ class Tensor(SimpleMathTrait):
     cond, y = cond._broadcasted(y, match_dtype=False)
     return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
 
-  def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
+  def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType): return mask.where(value, self)
 
   def copysign(self, other) -> Tensor:
     """
@@ -3503,7 +3503,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** functional nn ops *****
 
-  def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
+  def linear(self, weight:Tensor, bias:Tensor|None=None):
     """
     Applies a linear transformation to `self` using `weight` and `bias`.
 
@@ -3530,7 +3530,7 @@ class Tensor(SimpleMathTrait):
     """
     return functools.reduce(lambda x,f: f(x), ll, self)
 
-  def layernorm(self, axis:Union[int,tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
+  def layernorm(self, axis:int|tuple[int,...]=-1, eps:float=1e-5) -> Tensor:
     """
     Applies Layer Normalization over a mini-batch of inputs.
 
@@ -3549,7 +3549,7 @@ class Tensor(SimpleMathTrait):
     y = (self - self.mean(axis, keepdim=True))
     return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
 
-  def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,tuple[int,...]]=1) -> Tensor:
+  def batchnorm(self, weight:Tensor|None, bias:Tensor|None, mean:Tensor, invstd:Tensor, axis:int|tuple[int, ...]=1) -> Tensor:
     """
     Applies Batch Normalization over a mini-batch of inputs.
 
@@ -3722,7 +3722,7 @@ class Tensor(SimpleMathTrait):
     ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1)
     return ret._do_reduction(reduction)
 
-  def nll_loss(self, Y:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:ReductionStr="mean") -> Tensor:
+  def nll_loss(self, Y:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean") -> Tensor:
     """
     Compute the negative log likelihood loss between log-probabilities and target labels.
 
@@ -3805,7 +3805,7 @@ class Tensor(SimpleMathTrait):
     """
     return dtypes.is_float(self.dtype)
 
-  def size(self, dim:Optional[int]=None) -> Union[sint, tuple[sint, ...]]:
+  def size(self, dim:int|None=None) -> sint|tuple[sint, ...]:
     """
     Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
 
@@ -3935,7 +3935,7 @@ class Tensor(SimpleMathTrait):
 
   # *** image Tensor function replacements ***
 
-  def image_dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+  def image_dot(self, w:Tensor, acc_dtype:DTypeLike|None=None) -> Tensor:
     # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
     x, dx, dw = self, self.ndim, w.ndim
     if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
@@ -3951,7 +3951,7 @@ class Tensor(SimpleMathTrait):
     cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
     return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
 
-  def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
+  def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
     base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
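
The patch above is a mechanical type-annotation migration: every `Optional[X]` becomes `X|None` and every `Union[A, B]` becomes `A|B` (PEP 604), after which `Optional` and `Union` can be dropped from the `typing` import. Annotations in this file are deferred strings thanks to `from __future__ import annotations`, but the patch also rewrites a union in a runtime position (the `cast(Sequence[tuple[sint, sint]|None], padding)` call in `pad`), where the `|` expression is actually evaluated, so the new spelling assumes Python 3.10+. A minimal standalone sketch of the equivalence follows; the `describe_device_*` helpers are hypothetical, not tinygrad APIs:

```python
# Hypothetical helpers illustrating the PEP 604 spelling this diff adopts.
# Requires Python 3.10+ wherever the union is evaluated at runtime.
from __future__ import annotations
from typing import Optional, Union

# old spelling, as removed by the patch:
def describe_device_old(device: Optional[Union[str, tuple[str, ...]]]) -> str:
  return "default" if device is None else device if isinstance(device, str) else ",".join(device)

# new spelling, as introduced by the patch: X|None == Optional[X], A|B == Union[A, B]
def describe_device_new(device: str|tuple[str, ...]|None) -> str:
  return "default" if device is None else device if isinstance(device, str) else ",".join(device)

assert describe_device_old(None) == describe_device_new(None) == "default"
assert describe_device_old(("CPU", "GPU")) == describe_device_new(("CPU", "GPU")) == "CPU,GPU"
# PEP 604 guarantees the two spellings compare equal at runtime
assert (str | None) == Optional[str] and (int | str) == Union[int, str]
```

Since the diff itself uses `|` inside a runtime `cast(...)` call, the codebase evidently already targets Python 3.10+, so dropping the `typing` spellings loses nothing while shortening every signature it touches.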