Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
Add consistent typing for python 3.10 in tensor.py (#9326)
* add consistent typing for python 3.10 in tensor.py
* pull
* Tensor.copysign (#9329)
* fast amd gemm (#9318)
* 50 TFLOP AMD gemm
* add lds tiling
* register tiling
* flip locals
* work
* comment
* remove those
* fix Tensor.view with a tuple arg (#9330)
* reorder binops (#9328)
* reorder binops
* test improvements + fix string tests
* ugh, okay this
* Make const moving not depend on the order (#9245)
  Since floats are not being flipped anymore this should help with const folding for floats
* use empty for test instead of rand (#9332)
* linter

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com>
committed by GitHub
parent 019417743c
commit 94db8426cb
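The change is mechanical: every `Optional[X]` becomes `X|None` and every `Union[A, B]` becomes `A|B`, the PEP 604 union syntax that Python 3.10 evaluates natively (and that older interpreters accept in annotations thanks to the `from __future__ import annotations` already at the top of tensor.py). A minimal before/after sketch of the pattern, with hypothetical function names rather than code from this diff:

    # before: typing-module unions
    from typing import Optional, Union
    def to_old(device: Optional[Union[str, tuple[str, ...]]]) -> str: ...

    # after: PEP 604 unions, as used consistently for Python 3.10
    def to_new(device: str | tuple[str, ...] | None) -> str: ...

Both spellings describe the same type; only the notation changes, so no runtime behavior is affected.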
@@ -2,7 +2,7 @@
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
from contextlib import ContextDecorator
-from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
+from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
@@ -46,7 +46,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:

# **** Tensor helper functions ****

-def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
+def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None):
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)

@@ -63,7 +63,7 @@ def get_shape(x) -> tuple[int, ...]:
if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
return (len(subs),) + (subs[0] if subs else ())

-def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
+def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp:
if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
else:
ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
@@ -74,7 +74,7 @@ def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
return ret

-def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:Union[str, tuple[str, ...]], dtype:DType) -> list[list[Tensor]]:
+def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...], dtype:DType) -> list[list[Tensor]]:
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
for k in range(len(mat[0]))] for dim in range(dims)]

@@ -126,18 +126,18 @@ class Tensor(SimpleMathTrait):
training: ClassVar[bool] = False
no_grad: ClassVar[bool] = False

-def __init__(self, data:Union[None, ConstType, bytes, list, tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
-device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
+def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
+device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None):
if dtype is not None: dtype = to_dtype(dtype)
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)

# tensors can have gradients if you have called .backward
-self.grad: Optional[Tensor] = None
+self.grad:Tensor|None = None

# NOTE: this can be in three states. False and None: no gradient, True: gradient
# None (the default) will be updated to True if it's put in an optimizer
-self.requires_grad: Optional[bool] = requires_grad
+self.requires_grad:bool|None = requires_grad

# create a LazyBuffer from the different types of inputs
if isinstance(data, UOp):
@@ -182,7 +182,7 @@ class Tensor(SimpleMathTrait):
needs_input_grad = [t.requires_grad for t in (self,)+x]
return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)

-def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+def _apply_broadcasted_uop(self, fxn:Callable, x:Tensor|ConstType, reverse=False) -> Tensor:
lhs,rhs = self._broadcasted(x, reverse)
return lhs._apply_uop(fxn, rhs)

@@ -215,7 +215,7 @@ class Tensor(SimpleMathTrait):
return self.shape[0]

@property
-def device(self) -> Union[str, tuple[str, ...]]: return self.lazydata.device
+def device(self) -> str|tuple[str, ...]: return self.lazydata.device

@property
def shape(self) -> tuple[sint, ...]: return self.lazydata.shape
@@ -326,7 +326,7 @@ class Tensor(SimpleMathTrait):

# TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
-def tolist(self) -> Union[Sequence[ConstType], ConstType]:
+def tolist(self) -> Sequence[ConstType]|ConstType:
"""
Returns the value of this tensor as a nested list.
Returns single value for const tensor.
@@ -365,7 +365,7 @@ class Tensor(SimpleMathTrait):
if self.grad is not None: ret.grad = self.grad.clone()
return ret

-def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
+def to(self, device:str|tuple[str, ...]|None) -> Tensor:
"""
Moves the tensor to the given device.
"""
@@ -376,7 +376,7 @@ class Tensor(SimpleMathTrait):
if self.grad is not None: ret.grad = self.grad.to(device)
return ret

-def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
+def to_(self, device:str|tuple[str, ...]|None):
"""
Moves the tensor to the given device in place.
"""
@@ -384,7 +384,7 @@ class Tensor(SimpleMathTrait):
if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
return self.replace(real)

-def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
+def shard(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
"""
Shards the tensor across the given devices. Optionally specify which axis to shard on.

@@ -398,7 +398,7 @@ class Tensor(SimpleMathTrait):
mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)

-def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
+def shard_(self, devices:tuple[str, ...], axis:int|None=None):
"""
Shards the tensor across the given devices in place.
"""
@@ -415,7 +415,7 @@ class Tensor(SimpleMathTrait):
# ***** creation entrypoint *****

@staticmethod
-def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
+def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs):
dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
if isinstance(device, tuple):
return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
@@ -493,7 +493,7 @@ class Tensor(SimpleMathTrait):
return counts0.cat(counts1)

@staticmethod
-def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor:
+def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.

@@ -633,7 +633,7 @@ class Tensor(SimpleMathTrait):
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)

@staticmethod
-def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
+def linspace(start:int|float, stop:int|float, steps:int, **kwargs) -> Tensor:
"""
Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.

@@ -653,7 +653,7 @@ class Tensor(SimpleMathTrait):
return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)

@staticmethod
-def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
+def eye(n:int, m:int|None=None, **kwargs) -> Tensor:
"""
Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.

@@ -740,7 +740,7 @@ class Tensor(SimpleMathTrait):
# ***** rng hlops *****

@staticmethod
-def randn(*shape, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
+def randn(*shape, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
If `dtype` is not specified, the default type is used.
@@ -777,7 +777,7 @@ class Tensor(SimpleMathTrait):
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)

@staticmethod
-def normal(*shape, mean=0.0, std=1.0, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
+def normal(*shape, mean=0.0, std=1.0, requires_grad:bool|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.

@@ -792,7 +792,7 @@ class Tensor(SimpleMathTrait):
return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)

@staticmethod
-def uniform(*shape, low=0.0, high=1.0, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
+def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.

@@ -883,7 +883,7 @@ class Tensor(SimpleMathTrait):

# ***** toposort and backward pass *****

-def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
+def gradient(self, *targets:Tensor, gradient:Tensor|None=None, materialize_grads=False) -> list[Tensor]:
"""
Compute the gradient of the targets with respect to self.

@@ -912,7 +912,7 @@ class Tensor(SimpleMathTrait):
# create returned Tensors
return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]

-def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
+def backward(self, gradient:Tensor|None=None) -> Tensor:
"""
Propagates the gradient of a tensor backwards through the computation graph.
If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
@@ -1007,7 +1007,7 @@ class Tensor(SimpleMathTrait):
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))

-def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
+def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Tensor:
"""
Returns a tensor that shrinks the each axis based on input arg.
`arg` must have the same length as `self.ndim`.
@@ -1027,7 +1027,7 @@ class Tensor(SimpleMathTrait):
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))

-def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
+def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
"""
Returns a tensor with padding applied based on the input `padding`.

@@ -1065,7 +1065,7 @@ class Tensor(SimpleMathTrait):
if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
# group padding
-else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[Optional[tuple[sint, sint]]], padding))
+else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding))
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
if mode == "constant":
@@ -1092,7 +1092,7 @@ class Tensor(SimpleMathTrait):

# ***** movement high level ops *****

-def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
+def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
# wrap single index into a list
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
x, indices = self, list(indices)
@@ -1221,7 +1221,7 @@ class Tensor(SimpleMathTrait):
"""
return self._getitem(indices)

-def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
+def __setitem__(self, indices, v:Tensor|ConstType) -> None:
if isinstance(self.device, str) and self.device.startswith("DISK"):
self._getitem(indices).assign(v)
return
@@ -1293,7 +1293,7 @@ class Tensor(SimpleMathTrait):
# checks for shapes and number of dimensions delegated to cat
return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim)

-def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
+def repeat_interleave(self, repeats:int, dim:int|None=None) -> Tensor:
"""
Repeat elements of a tensor.

@@ -1331,7 +1331,7 @@ class Tensor(SimpleMathTrait):
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
return dim + total if dim < 0 else dim

-def split(self, sizes:Union[int, list[int]], dim:int=0) -> tuple[Tensor, ...]:
+def split(self, sizes:int|Sequence[int], dim:int=0) -> tuple[Tensor, ...]:
"""
Splits the tensor into chunks along the dimension specified by `dim`.
If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
@@ -1380,7 +1380,7 @@ class Tensor(SimpleMathTrait):
dim = self._resolve_dim(dim)
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))

-def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> tuple[Tensor, ...]:
+def meshgrid(self:Tensor, *args:Tensor, indexing:Literal["ij", "xy"]="ij") -> tuple[Tensor, ...]:
"""
Generates coordinate matrices from coordinate vectors.
Input tensors can be scalars or 1D tensors.
@@ -1407,7 +1407,7 @@ class Tensor(SimpleMathTrait):
output_shape = _broadcast_shape(*(t.shape for t in tensors))
return tuple(t._broadcast_to(output_shape) for t in tensors)

-def squeeze(self, dim:Optional[int]=None) -> Tensor:
+def squeeze(self, dim:int|None=None) -> Tensor:
"""
Returns a tensor with specified dimensions of input of size 1 removed.
If `dim` is not specified, all dimensions with size 1 are removed.
@@ -1497,7 +1497,7 @@ class Tensor(SimpleMathTrait):
dim = self._resolve_dim(dim)
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])

-def roll(self, shifts:Union[int, tuple[int, ...]], dims:Union[int, tuple[int, ...]]) -> Tensor:
+def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]) -> Tensor:
"""
Rolls the tensor along specified dimension(s).
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
@@ -1559,13 +1559,13 @@ class Tensor(SimpleMathTrait):

# ***** reduce ops *****

-def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
+def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
if self.ndim == 0: axis = ()
ret = self._apply_uop(UOp.r, op=op, axis=axis)
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))

-def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
+def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
"""
Returns the sum of the elements of the tensor along the specified axis or axes.

@@ -1592,7 +1592,7 @@ class Tensor(SimpleMathTrait):
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret

-def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
+def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
"""
Returns the product of the elements of the tensor along the specified axis or axes.

@@ -1618,7 +1618,7 @@ class Tensor(SimpleMathTrait):
"""
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)

-def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+def max(self, axis:int|Sequence[int]|None=None, keepdim=False):
"""
Returns the maximum value of the tensor along the specified axis or axes.

@@ -1643,7 +1643,7 @@ class Tensor(SimpleMathTrait):

def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()

-def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+def min(self, axis:int|Sequence[int]|None=None, keepdim=False):
"""
Returns the minimum value of the tensor along the specified axis or axes.

@@ -1666,7 +1666,7 @@ class Tensor(SimpleMathTrait):
"""
return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()

-def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+def any(self, axis:int|Sequence[int]|None=None, keepdim=False):
"""
Tests if any element evaluates to `True` along the specified axis or axes.

@@ -1688,7 +1688,7 @@ class Tensor(SimpleMathTrait):
"""
return self.bool().max(axis, keepdim)

-def all(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+def all(self, axis:int|Sequence[int]|None=None, keepdim=False):
"""
Tests if all element evaluates to `True` along the specified axis or axes.

@@ -1730,7 +1730,7 @@ class Tensor(SimpleMathTrait):
is_nan_close = (self.isnan() & other.isnan()) & equal_nan
return is_finite_close | is_infinite_close | is_nan_close

-def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+def mean(self, axis:int|Sequence[int]|None=None, keepdim=False):
"""
Returns the mean value of the tensor along the specified axis or axes.

@@ -1756,7 +1756,7 @@ class Tensor(SimpleMathTrait):
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)

-def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
+def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
"""
Returns the variance of the tensor along the specified axis or axes.

@@ -1782,7 +1782,7 @@ class Tensor(SimpleMathTrait):
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))

-def var_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
+def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
"""
Calculates the variance and mean over the dimensions specified by dim.
Syntactic sugar around `Tensor.var` and `Tensor.mean` to match `torch.var_mean`.
@@ -1799,7 +1799,7 @@ class Tensor(SimpleMathTrait):
"""
return self.var(axis, keepdim, correction), self.mean(axis, keepdim)

-def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
+def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
"""
Returns the standard deviation of the tensor along the specified axis or axes.

@@ -1823,7 +1823,7 @@ class Tensor(SimpleMathTrait):
"""
return self.var(axis, keepdim, correction).sqrt()

-def std_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
+def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
"""
Calculates the standard deviation and mean over the dimensions specified by dim.
Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
@@ -1840,13 +1840,13 @@ class Tensor(SimpleMathTrait):
"""
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)

-def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
+def _softmax(self, axis, dtype:DTypeLike|None=None):
m = self - self.max(axis=axis, keepdim=True).detach()
if dtype is not None: m = m.cast(dtype)
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)

-def softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
+def softmax(self, axis=-1, dtype:DTypeLike|None=None):
"""
Applies the softmax function to the tensor along the specified axis.

@@ -1869,7 +1869,7 @@ class Tensor(SimpleMathTrait):
_, e, ss = self._softmax(axis, dtype)
return e.div(ss)

-def log_softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
+def log_softmax(self, axis=-1, dtype:DTypeLike|None=None):
"""
Applies the log-softmax function to the tensor along the specified axis.

@@ -2005,7 +2005,7 @@ class Tensor(SimpleMathTrait):
return self._inverse().argmax(axis=axis, keepdim=keepdim)

@staticmethod
-def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.

@@ -2051,7 +2051,7 @@ class Tensor(SimpleMathTrait):

# ***** processing ops *****

-def _pool(self, k_:tuple[sint, ...], stride:Union[tuple[int, ...], int]=1, dilation:Union[tuple[int, ...], int]=1) -> Tensor:
+def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Tensor:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
@@ -2076,12 +2076,12 @@ class Tensor(SimpleMathTrait):
x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])

-def _resolve_pool_pads(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
+def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])

-def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:Union[tuple[int, ...], int], d_:Union[tuple[int, ...], int]) -> list[int]:
+def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:int|tuple[int, ...], d_:int|tuple[int, ...]) -> list[int]:
(d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
pads, grouped_pads = list(pads), _flat_to_grouped(pads)
# https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
@@ -2181,8 +2181,8 @@ class Tensor(SimpleMathTrait):
if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))

-def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
-acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
+acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Applies a convolution over a tensor with a given `weight` and optional `bias`.

@@ -2255,7 +2255,7 @@ class Tensor(SimpleMathTrait):

return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()

-def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
+def conv_transpose2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
"""
Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.

@@ -2294,7 +2294,7 @@ class Tensor(SimpleMathTrait):
padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)

-def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+def dot(self, w:Tensor, acc_dtype:DTypeLike|None=None) -> Tensor:

"""
Performs dot product between two tensors.
@@ -2322,7 +2322,7 @@ class Tensor(SimpleMathTrait):
w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w)
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)

-def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+def matmul(self, x:Tensor, reverse=False, acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Performs matrix multiplication between two tensors.

@@ -2487,7 +2487,7 @@ class Tensor(SimpleMathTrait):
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
return src, mask

-def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
+def scatter(self, dim:int, index:Tensor, src:Tensor|ConstType, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `add` or `multiply` reduction operation with `reduce`.
@@ -2552,7 +2552,7 @@ class Tensor(SimpleMathTrait):
```
"""
src, mask = self._pre_scatter(dim, index, src)
-def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
+def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
# TODO: should not overwrite acc_dtype here?
if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
@@ -2821,7 +2821,7 @@ class Tensor(SimpleMathTrait):
"""
return (self.isinf()|self.isnan()).logical_not()

-def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
+def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor:
"""
Linearly interpolates between `self` and `end` by `weight`.

@@ -3167,7 +3167,7 @@ class Tensor(SimpleMathTrait):
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)

-def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
+def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
x: Tensor = self
if not isinstance(y, Tensor):
# make y a Tensor
@@ -3186,7 +3186,7 @@ class Tensor(SimpleMathTrait):
# broadcast
return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)

-def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+def add(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Adds `self` and `x`.
Equivalent to `self + x`.
@@ -3206,7 +3206,7 @@ class Tensor(SimpleMathTrait):
"""
return self._apply_broadcasted_uop(UOp.add, x, reverse)

-def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Subtracts `x` from `self`.
Equivalent to `self - x`.
@@ -3227,7 +3227,7 @@ class Tensor(SimpleMathTrait):
a, b = self._broadcasted(x, reverse)
return a + (-b)

-def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+def mul(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Multiplies `self` and `x`.
Equivalent to `self * x`.
@@ -3247,7 +3247,7 @@ class Tensor(SimpleMathTrait):
"""
return self._apply_broadcasted_uop(UOp.mul, x, reverse)

-def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+def idiv(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Divides `self` by `x`.
Equivalent to `self // x`.
@@ -3260,7 +3260,7 @@ class Tensor(SimpleMathTrait):
"""
return self._apply_broadcasted_uop(UOp.idiv, x, reverse)

-def div(self, x:Union[Tensor, ConstType], reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
+def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
"""
Divides `self` by `x`.
Equivalent to `self / x`.
@@ -3287,7 +3287,7 @@ class Tensor(SimpleMathTrait):
if rounding_mode is not None: raise RuntimeError(f"{rounding_mode=} is not supported")
return d

-def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+def mod(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Mod `self` by `x`.
Equivalent to `self % x`.
@@ -3300,7 +3300,7 @@ class Tensor(SimpleMathTrait):
a, b = self._broadcasted(x, reverse)
return a - a.div(b, rounding_mode="floor") * b

-def bitwise_xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+def bitwise_xor(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Computes bitwise xor of `self` and `x`.
Equivalent to `self ^ x`.
@@ -3316,7 +3316,7 @@ class Tensor(SimpleMathTrait):
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self._apply_broadcasted_uop(UOp.bitwise_xor, x, reverse)

-def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+def bitwise_and(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Compute the bitwise AND of `self` and `x`.
Equivalent to `self & x`.
@@ -3331,7 +3331,7 @@ class Tensor(SimpleMathTrait):
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)

-def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+def bitwise_or(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Compute the bitwise OR of `self` and `x`.
Equivalent to `self | x`.
@@ -3384,7 +3384,7 @@ class Tensor(SimpleMathTrait):
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
return self.idiv(2 ** x)

-def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+def pow(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Computes power of `self` with `x`.
Equivalent to `self ** x`.
@@ -3407,7 +3407,7 @@ class Tensor(SimpleMathTrait):
ret = base._apply_uop(UOp.pow, exponent)
return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret

-def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
+def maximum(self, x:Tensor|ConstType) -> Tensor:
"""
Computes element-wise maximum of `self` and `x`.

@@ -3420,7 +3420,7 @@ class Tensor(SimpleMathTrait):
"""
return self._apply_broadcasted_uop(UOp.maximum, x)

-def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
+def minimum(self, x:Tensor|ConstType) -> Tensor:
"""
Computes element-wise minimum of `self` and `x`.

@@ -3434,7 +3434,7 @@ class Tensor(SimpleMathTrait):
t, x = self._broadcasted(x)
return t._inverse().maximum(x._inverse())._inverse()

-def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
+def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint):
"""
Return a tensor of elements selected from either `x` or `y`, depending on `self`.
`output_i = x_i if self_i else y_i`.
@@ -3458,7 +3458,7 @@ class Tensor(SimpleMathTrait):
cond, y = cond._broadcasted(y, match_dtype=False)
return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))

-def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
+def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType): return mask.where(value, self)

def copysign(self, other) -> Tensor:
"""
@@ -3503,7 +3503,7 @@ class Tensor(SimpleMathTrait):

# ***** functional nn ops *****

-def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
+def linear(self, weight:Tensor, bias:Tensor|None=None):
"""
Applies a linear transformation to `self` using `weight` and `bias`.

@@ -3530,7 +3530,7 @@ class Tensor(SimpleMathTrait):
"""
return functools.reduce(lambda x,f: f(x), ll, self)

-def layernorm(self, axis:Union[int,tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
+def layernorm(self, axis:int|tuple[int,...]=-1, eps:float=1e-5) -> Tensor:
"""
Applies Layer Normalization over a mini-batch of inputs.

@@ -3549,7 +3549,7 @@ class Tensor(SimpleMathTrait):
y = (self - self.mean(axis, keepdim=True))
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())

-def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,tuple[int,...]]=1) -> Tensor:
+def batchnorm(self, weight:Tensor|None, bias:Tensor|None, mean:Tensor, invstd:Tensor, axis:int|tuple[int, ...]=1) -> Tensor:
"""
Applies Batch Normalization over a mini-batch of inputs.

@@ -3722,7 +3722,7 @@ class Tensor(SimpleMathTrait):
ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1)
return ret._do_reduction(reduction)

-def nll_loss(self, Y:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:ReductionStr="mean") -> Tensor:
+def nll_loss(self, Y:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean") -> Tensor:
"""
Compute the negative log likelihood loss between log-probabilities and target labels.

@@ -3805,7 +3805,7 @@ class Tensor(SimpleMathTrait):
"""
return dtypes.is_float(self.dtype)

-def size(self, dim:Optional[int]=None) -> Union[sint, tuple[sint, ...]]:
+def size(self, dim:int|None=None) -> sint|tuple[sint, ...]:
"""
Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.

@@ -3935,7 +3935,7 @@ class Tensor(SimpleMathTrait):

# *** image Tensor function replacements ***

-def image_dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+def image_dot(self, w:Tensor, acc_dtype:DTypeLike|None=None) -> Tensor:
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
x, dx, dw = self, self.ndim, w.ndim
if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
@@ -3951,7 +3951,7 @@ class Tensor(SimpleMathTrait):
cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)

-def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
+def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef

(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape