# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
from collections import defaultdict

from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA
from tinygrad.lazy import LazyBuffer
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars

# **** start with two base classes, Tensor and Function ****

class Function:
  def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
    self.device = device
    self.needs_input_grad = [t.requires_grad for t in tensors]
    self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
    if self.requires_grad: self.parents = tensors
    self.metadata = metadata

  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
  def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")

  @classmethod
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
    ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
    ret = Tensor.__new__(Tensor)
    ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
    ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None  # used by autograd engine
    return ret
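
# A Function subclass defines forward/backward over LazyBuffers and is always invoked through `apply`,
# which constructs the context, runs forward, and wires the result into the autograd graph via `_ctx`.
# Illustrative sketch only (modeled on tinygrad/function.py; `a` and `b` are placeholder Tensors):
#
#   class Mul(Function):
#     def forward(self, x, y):
#       self.x, self.y = x, y
#       return x.e(BinaryOps.MUL, y)
#     def backward(self, grad_output):
#       return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
#              self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
#
#   out = Mul.apply(a, b)  # out._ctx holds the Mul instance and its parents for backward()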

import tinygrad.function as F

def _metaop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
  if isinstance(device, str): return LazyBuffer.metaop(op, shape, dtype, device, arg, src)
  return MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, d, arg, src) for d in device], None)

def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
  import numpy as np
  return dtypes.fields()[np.dtype(npdtype).name]
def _to_np_dtype(dtype:DType) -> Optional[type]:
  import numpy as np
  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None

def _fromnp(x: 'np.ndarray') -> LazyBuffer:  # type: ignore [name-defined] # noqa: F821
  ret = LazyBuffer.metaop(MetaOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
  # fake realize
  ret.buffer.allocate(x)
  del ret.srcs
  return ret

def get_shape(x) -> Tuple[int, ...]:
  if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str): return ()
  # 0-d objects with a shape attribute (e.g. numpy scalars) also have no dimensions
  if hasattr(x, "shape") and x.shape == (): return ()
  subs = [get_shape(xi) for xi in x]
  if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
  return (len(subs),) + (subs[0] if subs else ())

def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
  if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
  else:
    ret = LazyBuffer.metaop(MetaOps.EMPTY, get_shape(x), dtype, "PYTHON")
    assert dtype.fmt is not None, f"{dtype=} has None fmt"
    truncate_function = truncate[dtype]
    data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
  # fake realize
  ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
  del ret.srcs
  return ret

def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
  return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
           for k in range(len(mat[0]))] for dim in range(dims)]

# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
  # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
  # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
  t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:])  # add output dims
  # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
  matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
  # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
  ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
  return ret
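
# For dims=1 the sum above reduces to an ordinary matrix product along the leading axis, which is a useful
# sanity check (illustrative sketch, assuming a small 3x2 matrix B not defined in this file):
#   B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
#   t = Tensor([[1.0, 2.0], [3.0, 4.0]])
#   _apply_winograd_matrix(B, t, dims=1)   # == Tensor(B) @ t, shape (3, 2)
# For dims=2 it applies kron(B, B) to the two leading axes, which is what the winograd conv path uses.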

def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
  max_dim = max(len(shape) for shape in shapes)
  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
  return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
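
# Broadcasting follows the numpy rules: shapes are right-aligned, missing leading dims count as 1,
# and a 0-sized dim propagates. e.g. (illustrative):
#   _pad_left((3,), (2, 1, 3))         -> ((1, 1, 3), (2, 1, 3))
#   _broadcast_shape((3,), (2, 1, 3))  -> (2, 1, 3)
#   _broadcast_shape((4, 1), (1, 5))   -> (4, 5)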

ReductionStr = Literal["mean", "sum", "none"]

class Tensor:
  """
  A `Tensor` is a multi-dimensional matrix containing elements of a single data type.

  ```python exec="true" session="tensor"
  from tinygrad import Tensor, dtypes, nn
  import numpy as np
  import math
  np.set_printoptions(precision=4)
  ```
  """
  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
  __deletable__ = ('_ctx',)
  training: ClassVar[bool] = False
  no_grad: ClassVar[bool] = False

  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, 'np.ndarray', bytes, MultiLazyBuffer, UOp, pathlib.Path],  # type: ignore [name-defined] # noqa: F821
               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
    if dtype is not None: dtype = to_dtype(dtype)
    assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
    if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}"  # keep it on the disk if device is None
    device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)

    # tensors can have gradients if you have called .backward
    self.grad: Optional[Tensor] = None

    # NOTE: this can be in three states. False and None: no gradient, True: gradient
    # None (the default) will be updated to True if it's put in an optimizer
    self.requires_grad: Optional[bool] = requires_grad

    # internal variable used for autograd graph construction
    self._ctx: Optional[Function] = None

    # create a LazyBuffer from the different types of inputs
    if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
    elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
    elif isinstance(data, UOp):
      assert data.op is UOps.ASSIGN and data.src[0].op is UOps.DEFINE_VAR and data.src[1].op is UOps.CONST, f"can't create tensor from UOp {data}"
      data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data)
    elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
    elif isinstance(data, (list, tuple)):
      if dtype is None:
        if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
        else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
      if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
      else: data = _frompy(data, dtype)
    elif data is None: data = _metaop(MetaOps.EMPTY, (0,), dtype or dtypes.default_float, device)
    elif str(type(data)) == "<class 'numpy.ndarray'>":
      import numpy as np
      assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
      if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
      else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)  # type: ignore [name-defined]
    elif isinstance(data, pathlib.Path):
      dtype = dtype or dtypes.uint8
      data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")

    # by this point, it has to be a LazyBuffer
    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
      raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

    # data is a LazyBuffer, but it might be on the wrong device
    if isinstance(device, tuple):
      # if device is a tuple, we should have/construct a MultiLazyBuffer
      if isinstance(data, MultiLazyBuffer):
        assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
        self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
      else:
        self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None)
    else:
      self.lazydata = data if data.device == device else data.copy_to_device(device)

  class train(ContextDecorator):
    def __init__(self, mode:bool = True): self.mode = mode
    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev

  class test(ContextDecorator):
    def __init__(self, mode:bool = True): self.mode = mode
    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
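
  # Both helpers are ContextDecorators, so they work as context managers or as decorators.
  # Illustrative sketch (model/x/y are placeholders, not defined here):
  #
  #   with Tensor.train():      # Tensor.training = True for the block (enables dropout, BN updates, ...)
  #     loss = model(x).sparse_categorical_crossentropy(y)
  #
  #   @Tensor.test()            # Tensor.no_grad = True, so no autograd graph is built
  #   def eval_step(x): return model(x).argmax(axis=-1)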

  def __repr__(self):
    return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"

  # Python has a non-moving GC, so this should be okay
  def __hash__(self): return id(self)

  def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")

  def __len__(self):
    if not self.shape: raise TypeError("len() of a 0-d tensor")
    return self.shape[0]

  @property
  def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device

  @property
  def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape

  @property
  def dtype(self) -> DType: return self.lazydata.dtype

  # ***** data handlers *****

  def schedule_with_vars(self, *lst:Tensor) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
    """
    Creates the schedule needed to realize these Tensor(s), with Variables.

    NOTE: A Tensor can only be scheduled once.
    """
    if getenv("FUZZ_SCHEDULE"):
      from test.external.fuzz_schedule import fuzz_schedule
      fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
    return memory_planner(schedule), var_vals

  def schedule(self, *lst:Tensor) -> List[ScheduleItem]:
    """Creates the schedule needed to realize these Tensor(s)."""
    schedule, var_vals = self.schedule_with_vars(*lst)
    assert len(var_vals) == 0
    return schedule

  def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
    """Triggers the computation needed to create these Tensor(s)."""
    run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats)
    return self
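
  # Tensors are lazy: ops only build a graph of LazyBuffers, schedule() lowers that graph to kernels,
  # and realize() runs them. A minimal sketch of the flow (illustrative):
  #
  #   a, b = Tensor.rand(16), Tensor.rand(16)
  #   c = (a + b) * 2          # nothing has executed yet
  #   sched = c.schedule()     # List[ScheduleItem] describing the kernel(s) to run
  #   c.realize()              # or Tensor.realize(c, d, ...) to realize several outputs together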

  def replace(self, x:Tensor) -> Tensor:
    """
    Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
    """
    # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
    assert not x.requires_grad and getattr(self, '_ctx', None) is None
    assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
    self.lazydata = x.lazydata
    return self

  def assign(self, x) -> Tensor:
    # TODO: this is a hack for writing to DISK. remove with working assign
    if isinstance(self.device, str) and self.device.startswith("DISK"):
      if x.__class__ is not Tensor: x = Tensor(x, device="NPY", dtype=self.dtype)
      self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
      return self
    if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
    if self.lazydata is x.lazydata: return self  # a self assign is a NOOP
    # NOTE: we allow cross device assign
    assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
    assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
    assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
    assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
    assert not x.requires_grad  # self requires_grad is okay?
    if not self.lazydata.is_realized(): return self.replace(x)
    self.lazydata = self.lazydata.assign(x.lazydata)
    return self

  def detach(self) -> Tensor:
    """
    Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
    """
    return Tensor(self.lazydata, device=self.device, requires_grad=False)

  def _data(self) -> memoryview:
    if 0 in self.shape: return memoryview(bytearray(0))
    # NOTE: this realizes on the object from as_buffer being a Python object
    cpu = self.cast(self.dtype.scalar()).contiguous().to("CLANG").realize()
    buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
    if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
    return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)

  def data(self) -> memoryview:
    """
    Returns the data of this tensor as a memoryview.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([1, 2, 3, 4])
    print(np.frombuffer(t.data(), dtype=np.int32))
    ```
    """
    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
    return self._data().cast(self.dtype.fmt, self.shape)

  def item(self) -> ConstType:
    """
    Returns the value of this tensor as a standard Python number.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor(42)
    print(t.item())
    ```
    """
    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
    assert self.numel() == 1, "must have one element for item"
    return self._data().cast(self.dtype.fmt)[0]

  # TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
  # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
  def tolist(self) -> Union[Sequence[ConstType], ConstType]:
    """
    Returns the value of this tensor as a nested list.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([1, 2, 3, 4])
    print(t.tolist())
    ```
    """
    return self.data().tolist()

  def numpy(self) -> 'np.ndarray':  # type: ignore [name-defined] # noqa: F821
    """
    Returns the value of this tensor as a `numpy.ndarray`.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([1, 2, 3, 4])
    print(repr(t.numpy()))
    ```
    """
    import numpy as np
    if self.dtype == dtypes.bfloat16: return self.float().numpy()
    assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
    return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)

  def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
    """
    Moves the tensor to the given device.
    """
    device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
    if device == self.device: return self
    if not isinstance(device, str): return self.shard(device)
    ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
    if self.grad is not None: ret.grad = self.grad.to(device)
    if hasattr(self, '_ctx'): ret._ctx = self._ctx
    return ret

  def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
    """
    Moves the tensor to the given device in place.
    """
    real = self.to(device)
    # TODO: is this assign?
    if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
    self.lazydata = real.lazydata

  def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None) -> Tensor:
    """
    Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.empty(2, 3)
    print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
    ```

    """
    assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
    canonical_devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
    if axis is not None:
      if axis < 0: axis += len(self.shape)
      if splits is None:
        sz = round_up(self.shape[axis], len(devices)) // len(devices)
        splits = tuple([max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))])
      assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
      boundaries = tuple(itertools.accumulate(splits))
      bounds = tuple(zip((0,) + boundaries, boundaries))
    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis, bounds),
                  device=canonical_devices, requires_grad=self.requires_grad)

  def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
    """
    Shards the tensor across the given devices in place.
    """
    self.lazydata = self.shard(devices, axis, splits).lazydata
    return self
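
  # With axis=None every device gets a full replica; with an axis the tensor is split along it.
  # When `splits` is omitted the axis is divided as evenly as possible: e.g. (illustrative) shape[axis] == 10
  # over 4 devices gives splits (3, 3, 3, 1) and bounds ((0, 3), (3, 6), (6, 9), (9, 10)).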

  @staticmethod
  def from_uop(y:UOp, **kwargs) -> Tensor:
    if y.op is UOps.ASSIGN: return Tensor(y, **kwargs, requires_grad=False)  # this is the only UOp allowed in Tensor
    if y.op is UOps.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
    if y.op is UOps.ALU:
      if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
      if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
    raise RuntimeError(f"unhandled Node {y}")

  # ***** creation entrypoint *****

  @staticmethod
  def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
    if isinstance(device, tuple):
      return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
        for d in device], None), device, dtype, **kwargs)
    return Tensor(LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)

  @staticmethod
  def empty(*shape, **kwargs):
    """
    Creates an empty tensor with the given shape.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.empty(2, 3)
    print(t.shape)
    ```
    """
    return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs)

  @staticmethod
  def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:
    """
    Exposes the pointer as a Tensor without taking ownership of the original data.
    The pointer must remain valid for the entire lifetime of the created Tensor.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.
    """

    r = Tensor._metaop(MetaOps.EMPTY, shape, **kwargs)
    r.lazydata.buffer.allocate(external_ptr=ptr)
    del r.lazydata.srcs  # fake realize
    return r

  _seed: int = int(time.time())
  _device_seeds: Dict[str, int] = {}
  _device_rng_counters: Dict[str, Tensor] = {}
  @staticmethod
  def manual_seed(seed=0):
    """
    Sets the seed for random operations.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    print(Tensor.rand(5).numpy())
    print(Tensor.rand(5).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)  # reset to the same seed
    print(Tensor.rand(5).numpy())
    print(Tensor.rand(5).numpy())
    ```
    """
    Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}

  @staticmethod
  def _threefry_random_bits(key0, key1, counts0, counts1):
    x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
    key = (Tensor([key0], device=x.device, dtype=dtypes.uint64, requires_grad=False) << 32) | key1
    x = F.Threefry.apply(*x._broadcasted(key))
    counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
    return counts0.cat(counts1)

  @staticmethod
  def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
    """
    Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    t = Tensor.rand(2, 3)
    print(t.numpy())
    ```
    """
    if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
    if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
    if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
    _device = device = Device.canonicalize(device)

    # when using MOCKGPU and NV generate rand on CLANG
    if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"

    # generate per device seeds and rng counter if we haven't seen this device yet
    if device not in Tensor._device_seeds:
      Tensor._device_seeds[device] = int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff
      Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
      had_counter = False
    else: had_counter = True

    # if shape has 0, return zero tensor
    if (num := math.ceil(((num_ := prod(shape)) * dtype.itemsize) / 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)

    # increment rng counter for devices
    if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)

    # threefry random bits
    counts0 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
    counts1 = counts0 + math.ceil(num / 2)
    bits = Tensor._threefry_random_bits(Tensor._seed, Tensor._device_seeds[device], counts0, counts1)[:num]

    # bitcast to uint with same number of bits
    _, nmant = dtypes.finfo(dtype)
    uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
    bits = bits.bitcast(uint_dtype)
    # only randomize the mantissa bits and set the exponent to 1
    one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
    bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
    # bitcast back to the original dtype and reshape
    out = bits.bitcast(dtype)[:num_].sub(1).reshape(shape)

    # move back to the original device if we were using MOCKGPU
    if getenv("MOCKGPU") and _device: out = out.to(_device)

    out.requires_grad = kwargs.get("requires_grad")
    return out.contiguous()
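
  # rand is a counter-based RNG: each call consumes num = ceil(prod(shape) * itemsize / 4) 32-bit words
  # from the per-device counter. Worked example (illustrative): Tensor.rand(3) with float32 gives num = 3,
  # so counts0 = counter + [0, 1] and counts1 = counts0 + 2 produce 4 threefry words, sliced to 3, and the
  # mantissa trick maps each word into [1, 2) before subtracting 1 to land in [0, 1).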

  # ***** creation helper functions *****

  @staticmethod
  def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
    """
    Creates a tensor with the given shape, filled with the given value.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.full((2, 3), 42).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.full((2, 3), False).numpy())
    ```
    """
    return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
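
  # `full` never materializes one value per element: e.g. (illustrative) Tensor.full((2, 3), 42) is
  # Tensor(42).reshape(1, 1).expand(2, 3), a broadcasted view of a single constant until it is realized.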

  @staticmethod
  def zeros(*shape, **kwargs) -> Tensor:
    """
    Creates a tensor with the given shape, filled with zeros.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.zeros(2, 3).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
    ```
    """
    return Tensor.full(argfix(*shape), 0.0, **kwargs)

  @staticmethod
  def ones(*shape, **kwargs) -> Tensor:
    """
    Creates a tensor with the given shape, filled with ones.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.ones(2, 3).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
    ```
    """
    return Tensor.full(argfix(*shape), 1.0, **kwargs)

  @staticmethod
  def arange(start, stop=None, step=1, **kwargs) -> Tensor:
    """
    Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`.

    If `stop` is not specified, values are generated from `[0, start)` with the given `step`.

    If `stop` is specified, values are generated from `[start, stop)` with the given `step`.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.arange(5).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.arange(5, 10).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.arange(5, 10, 2).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.arange(5.5, 10, 2).numpy())
    ```
    """
    if stop is None: stop, start = start, 0
    assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
    dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
    # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
    if (stop-start)/step <= 0: return Tensor([], dtype=dtype, **kwargs)
    return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
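
  # The implementation is a cumsum over a constant-step tensor. Worked example (illustrative):
  #   Tensor.arange(5, 10, 2) -> full((3,), 2) = [2, 2, 2] -> _cumsum() = [2, 4, 6] -> + (5 - 2) -> [5, 7, 9]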

  @staticmethod
  def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
    """
    Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.eye(3).numpy())
    ```

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.eye(2, 4).numpy())
    ```
    """
    if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
    x = Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n)
    return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m)))
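
  # The identity is built from movement ops alone. Worked example for n=3 (illustrative):
  #   ones((3, 1)) -> pad to (3, 4) rows [1, 0, 0, 0] -> flatten (12,) -> shrink (9,) -> reshape (3, 3)
  # i.e. [1, 0, 0, 0, 1, 0, 0, 0, 1] row-major is exactly eye(3); `m` then pads or shrinks the columns.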

  def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
    """
    Creates a tensor with the same shape as `self`, filled with the given value.
    If `dtype` is not specified, the dtype of `self` is used.

    You can pass in the `device` keyword argument to control device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.ones(2, 3)
    print(Tensor.full_like(t, 42).numpy())
    ```
    """
    return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)

  def zeros_like(self, **kwargs) -> Tensor:
    """
    Creates a tensor with the same shape as `self`, filled with zeros.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.ones(2, 3)
    print(Tensor.zeros_like(t).numpy())
    ```
    """
    return self.full_like(0, **kwargs)

  def ones_like(self, **kwargs) -> Tensor:
    """
    Creates a tensor with the same shape as `self`, filled with ones.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.zeros(2, 3)
    print(Tensor.ones_like(t).numpy())
    ```
    """
    return self.full_like(1, **kwargs)

  def rand_like(self, **kwargs) -> Tensor:
    """
    Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.ones(2, 3)
    print(Tensor.rand_like(t).numpy())
    ```
    """
    dtype = kwargs.pop("dtype", self.dtype)
    device = kwargs.pop("device", self.device)
    if isinstance(self.device, tuple):
      assert isinstance(self.lazydata, MultiLazyBuffer)
      if self.lazydata.axis is not None:
        rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype).lazydata) for lb in self.lazydata.lbs]
        return Tensor(MultiLazyBuffer(rands, self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
      return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
    return Tensor.rand(*self.shape, device=device, dtype=dtype, **kwargs)

  # ***** rng hlops *****

  @staticmethod
  def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
    """
    Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
    If `dtype` is not specified, the default type is used.

    You can pass in the `device` keyword argument to control device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    print(Tensor.randn(2, 3).numpy())
    ```
    """
    # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
    src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
    return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
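
  # Box-Muller: with u1, u2 ~ U[0, 1), z = cos(2*pi*u1) * sqrt(-2 * ln(1 - u2)) is standard normal;
  # using 1 - src[1] keeps the log argument in (0, 1] so it never evaluates log(0).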

  @staticmethod
  def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
    """
    Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
    If `dtype` is not specified, the default type is used.

    You can pass in the `device` keyword argument to control device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    print(Tensor.randint(2, 3, low=5, high=10).numpy())
    ```
    """
    if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
    dtype = kwargs.pop("dtype", dtypes.int32)
    if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
    return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)

  @staticmethod
  def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
    """
    Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    print(Tensor.normal(2, 3, mean=10, std=2).numpy())
    ```
    """
    return (std * Tensor.randn(*shape, **kwargs)) + mean

  @staticmethod
  def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
    """
    Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    print(Tensor.uniform(2, 3, low=2, high=10).numpy())
    ```
    """
    dtype = kwargs.pop("dtype", dtypes.default_float)
    return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low

  @staticmethod
  def scaled_uniform(*shape, **kwargs) -> Tensor:
    """
    Creates a tensor with the given shape, filled with random values from a uniform distribution
    over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    print(Tensor.scaled_uniform(2, 3).numpy())
    ```
    """
    return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)

  # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
  @staticmethod
  def glorot_uniform(*shape, **kwargs) -> Tensor:
    """
    <https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    print(Tensor.glorot_uniform(2, 3).numpy())
    ```
    """
    return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)

  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
  @staticmethod
  def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
    """
    <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    print(Tensor.kaiming_uniform(2, 3).numpy())
    ```
    """
    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
    return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
  @staticmethod
  def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
    """
    <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    print(Tensor.kaiming_normal(2, 3).numpy())
    ```
    """
    std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
    return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

  def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
    assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
    assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
    weight = self.unsqueeze(0) if self.ndim == 1 else self
    cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
    indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
    return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
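
  # multinomial is inverse-CDF sampling: cumulative weights are normalized to a CDF per row, uniforms are
  # drawn, and each sample counts how many CDF entries it exceeds. Illustrative sketch:
  #
  #   probs = Tensor([0.1, 0.2, 0.7])
  #   samples = probs.multinomial(num_samples=8, replacement=True)  # shape (8,), int32 indices in [0, 3)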

  # ***** toposort and backward pass *****

  def _deepwalk(self):
    def _walk(node, visited):
      visited.add(node)
      # if tensor is not leaf, reset grad
      if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
      if ctx:
        for i in node._ctx.parents:
          if i not in visited: yield from _walk(i, visited)
      yield node
    return list(_walk(self, set()))
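
  # _deepwalk is a post-order DFS over _ctx.parents, so leaves come first and `self` comes last;
  # backward() walks the reversed list, ensuring a node's grad is complete before it is pushed to its
  # parents. e.g. (illustrative) for c = a * b; d = c + a, _deepwalk on d returns [a, b, c, d].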

  def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
    """
    Propagates the gradient of a tensor backwards through the computation graph.
    If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
    If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
    t.sum().backward()
    print(t.grad.numpy())
    ```
    """
    toposorted = self._deepwalk()
    if gradient is None:
      assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
      # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
      # this is "implicit gradient creation"
      gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

    assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
    self.grad = gradient
    for t0 in reversed(toposorted):
      if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
      grads = t0._ctx.backward(t0.grad.lazydata)
      _METADATA.reset(token)
      grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
               for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
      for t, g in zip(t0._ctx.parents, grads):
        if g is not None and t.requires_grad:
          assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
          t.grad = g if t.grad is None else (t.grad + g)
      if not retain_graph: del t0._ctx
    return self

  # ***** movement low level ops *****

  def view(self, *shape) -> Tensor:
    """`.view` is an alias for `.reshape`."""
    return self.reshape(shape)

  def reshape(self, shape, *args) -> Tensor:
    """
    Returns a tensor with the same data as the original tensor but with a different shape.
    `shape` can be passed as a tuple or as separate arguments.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.arange(6)
    print(t.reshape(2, 3).numpy())
    ```
    """
    # resolve None and args
    new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
    # resolve -1
    if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
    if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
    return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self

  def expand(self, shape, *args) -> Tensor:
    """
    Returns a tensor that is expanded to the shape that is specified.
    Expand can also increase the number of dimensions that a tensor has.

    Passing a `-1` or `None` to a dimension means that its size will not be changed.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([1, 2, 3])
    print(t.expand(4, -1).numpy())
    ```
    """
    return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))

  def permute(self, order, *args) -> Tensor:
    """
    Returns a tensor that is a permutation of the original tensor.
    The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
    `order` can be passed as a tuple or as separate arguments.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.arange(6).reshape(2, 3)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.permute(1, 0).numpy())
    ```
    """
    order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
    if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
    return F.Permute.apply(self, order=order_arg)

  def flip(self, axis, *args) -> Tensor:
    """
    Returns a tensor that reverses the order of the original tensor along given `axis`.
    `axis` can be passed as a tuple or as separate arguments.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.arange(6).reshape(2, 3)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.flip(0).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.flip((0, 1)).numpy())
    ```
    """
    axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
    if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
    return F.Flip.apply(self, axis=axis_arg)

  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
    """
    Returns a tensor that shrinks each axis based on input arg.
    `arg` must have the same length as `self.ndim`.
    For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.arange(9).reshape(3, 3)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.shrink(((None, (1, 3)))).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.shrink((((0, 2), (0, 2)))).numpy())
    ```
    """
    if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
    return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))

  def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
    """
    Returns a tensor that pads each axis based on input arg.
    `arg` must have the same length as `self.ndim`.
    For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`.
    If `value` is specified, the tensor is padded with `value` instead of `0.0`.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.arange(6).reshape(2, 3)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.pad(((None, (1, 2)))).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.pad(((None, (1, 2))), -2).numpy())
    ```
    """
    if all(x is None or x == (0,0) for x in arg): return self
    ret = F.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
    return ret if 0 == value else ret + F.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)

  # ***** movement high level ops *****

  # Supported Indexing Implementations:
  #   1. Int indexing (no copy)
  #     - for all dims where there's int, shrink -> reshape
  #     - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
  #     - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
  #     - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
  #   2. Slice indexing (no copy)
  #     - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
  #     - first shrink the Tensor to X.shrink(((start, end),))
  #     - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
  #       - flip where dim value is negative
  #       - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
  #       - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
  #       - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
  #     - (a worked example of this path follows this comment block)
  #   3. None indexing (no copy)
  #     - reshape (inject) a dim at the dim where there's None
  #   4. Tensor indexing (copy)
  #     - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
  #     - combine masks together with mul
  #     - apply mask to self by mask * self
  #     - sum reduce away the extra dims added from creating masks
  #   Tiny Things:
  #     1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
  #       - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
  #       - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
  #     2. Bool indexing is not supported
  #     3. Out of bounds Tensor indexing results in 0
  #       - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] indices 4 and 3 are out of bounds
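  # Worked example for the strided-slice path above (illustrative): X[::2] on a length-5 axis is
  #   shrink (0, 5) -> pad to 6 (next multiple of stride 2) -> reshape (3, 2) -> shrink to (3, 1) -> reshape (3,)
  # which selects elements 0, 2, 4 without copying any data.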
  def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
    # 1. indices normalization and validation
    # treat internal tuples and lists as Tensors and standardize indices to list type
    if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
    elif isinstance(indices, (tuple, list)):
      indices = [Tensor(i, self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
    else: indices = [indices]

    # turn scalar Tensors into const val for int indexing if possible
    indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
    # move Tensor indices to the same device as self
    indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]

    # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
    ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
    fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
    num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
    indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)

    # use Dict[type, List[dimension]] to track elements in indices
    type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)

    # record None for dimension injection later and filter None and record rest of indices
    type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
    indices_filtered = [i for i in indices if i is not None]
    for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)

    if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
    for index_type in type_dim:
      if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
    if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")

    # 2. basic indexing, uses only movement ops (no copy)
    # currently indices_filtered: Tuple[Union[int, slice, Tensor], ...]
    # turn indices in indices_filtered to Tuple[new_slice, strides]
    for dim in type_dim[int]:
      if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
        raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
      indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
    for dim in type_dim[slice]:
      if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
      s, e, st = index.indices(self.shape[dim])
      indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
    # skip all Tensor dims for basic indexing
    for dim in type_dim[Tensor]:
      dtype = indices_filtered[dim].dtype
      if not dtypes.is_int(dtype): raise IndexError(f"{dtype=} on {dim=} is not supported, only int tensor indexing is supported")
      indices_filtered[dim] = ((0, self.shape[dim]), 1)

    new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
    # flip negative strides
    ret = self.shrink(new_slice).flip(tuple(i for i, st in enumerate(strides) if st < 0))
    # handle stride != 1 or -1
    if any(abs(st) != 1 for st in strides):
      strides = tuple(abs(s) for s in strides)
      # pad shape to multiple of stride
      ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
      ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
      ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])

    # inject 1 for dim where it's None and collapse dim for int
    new_shape = list(ret.shape)
    for dim in type_dim[None]: new_shape.insert(dim, 1)
    for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)

    ret = ret.reshape(new_shape)

    # 3. advanced indexing (copy)
    if type_dim[Tensor]:
      dim_tensors = [(dim, i) for dim, i in enumerate(indices) if isinstance(i, Tensor)]
      # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
      def calc_dim(tensor_dim:int) -> int:
        return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)

      assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
      # track tensor_dim and tensor_index using a dict
      # calc_dim to get dim and use that to normalize the negative tensor indices
      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in dim_tensors}

      masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
      pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]

      # create masks
      for dim, i in idx.items():
        try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
        except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
        a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
        masks.append(i == a)

      # reduce masks to 1 mask
      mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)

      # inject 1's for the extra dims added in create masks
      reshape_arg = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
      # sum reduce the extra dims introduced in create masks
      ret = (ret.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)

      # special permute case
      if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
        ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))

      # for advanced setitem, returns whole tensor with indices replaced
      if v is not None:
        v = v.cast(self.dtype)._broadcast_to(_broadcast_shape(ret.shape, v.shape))
        # add back reduced dims from sum
        for dim in sum_axis: v = v.unsqueeze(dim)
        # axis to be reduced to match self.shape
        axis = tuple(range(first_dim, first_dim + len(big_shape)))
        # apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains
        v = v * mask
        for dim in axis: v = functools.reduce(lambda x,y: y.where(y, x), v.split(1, dim))
        # reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self
        ret = mask.any(axis).where(v.squeeze(), self)

    return ret
def __getitem__(self, indices) -> Tensor:
|
|
return self._getitem(indices)
|
|
|
|
def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
|
|
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
|
self._getitem(indices).assign(v)
|
|
return
|
|
# NOTE: check that setitem target is valid first
|
|
if not all(lb.st.contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous")
|
|
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
|
|
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
|
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
|
|
|
|
res = self.realize()._getitem(indices, v)
|
|
# if shapes match and data is not shared it's a copy and we assign to self
|
|
if res.shape == self.shape and res.lazydata is not self.lazydata:
|
|
self.assign(res).realize()
|
|
else: # no copy, basic setitem
|
|
v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()
|
|
res.assign(v).realize()
|
|
|
|
def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
|
|
"""
|
|
Gathers values along an axis specified by `dim`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 2], [3, 4]])
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
|
|
```
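With `dim=0`, the indices select along the first axis (each output element comes from `t[index[i][j], j]`):

```python exec="true" source="above" session="tensor" result="python"
print(t.gather(0, Tensor([[0, 1], [1, 1]])).numpy())
```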
|
|
"""
|
|
assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
|
|
dim = self._resolve_dim(dim)
|
|
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
|
index = index.to(self.device)
|
|
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
|
return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
|
|
|
|
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
|
"""
|
|
Concatenates `self` with the other tensors in `args` along the axis specified by `dim`.
|
|
All tensors must have the same shape except in the concatenating dimension.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
|
|
print(t0.cat(t1, t2, dim=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t0.cat(t1, t2, dim=1).numpy())
|
|
```
|
|
"""
|
|
dim = self._resolve_dim(dim)
|
|
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
|
|
catargs = [self, *args]
|
|
cat_dims = [s.shape[dim] for s in catargs]
|
|
cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
|
|
slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
|
|
for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
|
|
return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
|
|
|
|
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
|
"""
|
|
Concatenates `self` with the other tensors in `args` along a new dimension specified by `dim`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
|
|
print(t0.stack(t1, t2, dim=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t0.stack(t1, t2, dim=1).numpy())
|
|
```
|
|
"""
|
|
# checks for shapes and number of dimensions delegated to cat
|
|
return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)
|
|
|
|
def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
|
|
"""
|
|
Repeat elements of a tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([1, 2, 3])
|
|
print(t.repeat_interleave(2).numpy())
|
|
```
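Passing `dim` repeats the elements along that axis instead of over the flattened tensor:

```python exec="true" source="above" session="tensor" result="python"
print(Tensor([[1, 2], [3, 4]]).repeat_interleave(2, dim=0).numpy())
```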
|
|
"""
|
|
x, dim = (self.flatten(), 0) if dim is None else (self, dim)
|
|
shp = x.shape
|
|
return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])
|
|
|
|
def repeat(self, repeats, *args) -> Tensor:
|
|
"""
|
|
Repeats the tensor the given number of times along each dimension specified by `repeats`.
|
|
`repeats` can be passed as a tuple or as separate arguments.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([1, 2, 3])
|
|
print(t.repeat(4, 2).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.repeat(4, 2, 1).shape)
|
|
```
|
|
"""
|
|
repeats = argfix(repeats, *args)
|
|
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
|
|
new_shape = [x for b in base_shape for x in [1, b]]
|
|
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
|
|
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
|
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
|
|
|
|
def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
|
|
if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
|
|
raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
|
|
return dim + self.ndim+outer if dim < 0 else dim
|
|
|
|
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
|
|
"""
|
|
Splits the tensor into chunks along the dimension specified by `dim`.
|
|
If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
|
|
If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `sizes`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(10).reshape(5, 2)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
split = t.split(2)
|
|
print("\\n".join([repr(x.numpy()) for x in split]))
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
split = t.split([1, 4])
|
|
print("\\n".join([repr(x.numpy()) for x in split]))
|
|
```
|
|
"""
|
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
|
dim = self._resolve_dim(dim)
|
|
if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
|
|
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
|
|
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
|
|
|
|
def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
|
|
"""
|
|
Splits the tensor into `chunks` number of chunks along the dimension `dim`.
|
|
If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
|
|
The function may return fewer than the specified number of chunks.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
chunked = Tensor.arange(11).chunk(6)
|
|
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
chunked = Tensor.arange(12).chunk(6)
|
|
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
chunked = Tensor.arange(13).chunk(6)
|
|
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
|
```
|
|
"""
|
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
|
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
|
|
dim = self._resolve_dim(dim)
|
|
return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim))
|
|
|
|
def squeeze(self, dim:Optional[int]=None) -> Tensor:
|
|
"""
|
|
Returns a tensor with the specified dimensions of size 1 removed.
|
|
If `dim` is not specified, all dimensions with size 1 are removed.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.zeros(2, 1, 2, 1, 2)
|
|
print(t.squeeze().shape)
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.squeeze(0).shape)
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.squeeze(1).shape)
|
|
```
|
|
"""
|
|
if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
|
|
dim = self._resolve_dim(dim)
|
|
return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
|
|
|
|
def unsqueeze(self, dim:int) -> Tensor:
|
|
"""
|
|
Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([1, 2, 3, 4])
|
|
print(t.unsqueeze(0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.unsqueeze(1).numpy())
|
|
```
|
|
"""
|
|
dim = self._resolve_dim(dim, outer=True)
|
|
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
|
|
|
def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
|
|
"""
|
|
Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
|
|
If `value` is specified, the tensor is padded with `value` instead of `0.0`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(9).reshape(1, 1, 3, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
|
|
```
|
|
"""
|
|
pads = tuple((max(p0, 0), max(p1, 0)) for p0, p1 in zip(padding[::2], padding[1::2]))[::-1]
|
|
padded = self.pad((None,) * (self.ndim - len(padding) // 2) + tuple(pads), value=value)
|
|
shrink = tuple((-min(p0, 0), min(p1 + s, s)) for p0, p1, s in zip(padding[::2], padding[1::2], padded.shape[::-1]))[::-1]
|
|
return padded.shrink((None,) * (self.ndim - len(padding) // 2) + shrink)
|
|
|
|
@property
|
|
def T(self) -> Tensor:
|
|
"""`.T` is an alias for `.transpose()`."""
|
|
return self.transpose()
|
|
|
|
def transpose(self, dim0=1, dim1=0) -> Tensor:
|
|
"""
|
|
Returns a tensor that is a transposed version of the original tensor.
|
|
The given dimensions `dim0` and `dim1` are swapped.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(6).reshape(2, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.transpose(0, 1).numpy())
|
|
```
|
|
"""
|
|
order = list(range(self.ndim))
|
|
order[dim0], order[dim1] = order[dim1], order[dim0]
|
|
return self.permute(order)
|
|
|
|
def flatten(self, start_dim=0, end_dim=-1):
|
|
"""
|
|
Flattens the tensor by reshaping it into a one-dimensional tensor.
|
|
If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(8).reshape(2, 2, 2)
|
|
print(t.flatten().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.flatten(start_dim=1).numpy())
|
|
```
|
|
"""
|
|
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
|
|
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
|
|
|
|
def unflatten(self, dim:int, sizes:Tuple[int,...]):
|
|
"""
|
|
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
|
|
```
|
|
"""
|
|
dim = self._resolve_dim(dim)
|
|
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
|
|
|
|
def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
|
|
"""
|
|
Rolls the tensor along specified dimension(s).
|
|
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor.rand(3, 4, 1).roll(shifts=1, dims=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor.rand(3, 4, 1).roll(shifts=-1, dims=0).numpy())
|
|
```
|
|
"""
|
|
dims, rolled = tuple(self._resolve_dim(d) for d in make_pair(dims, 1)), self
|
|
for dim, shift in zip(dims, make_pair(shifts, 1)):
|
|
shift = shift % self.shape[dim]
|
|
rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))],
|
|
rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
|
|
return rolled
|
|
|
|
# ***** reduce ops *****
|
|
|
|
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
|
|
if self.ndim == 0:
|
|
if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]")
|
|
axis = ()
|
|
axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
|
|
axis_ = tuple(self._resolve_dim(x) for x in axis_)
|
|
ret = fxn.apply(self, axis=axis_)
|
|
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
|
|
|
|
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
|
"""
|
|
Returns the sum of the elements of the tensor along the specified axis or axes.
|
|
|
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
|
which the sum is computed and whether the reduced dimensions are retained.
|
|
|
|
You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
|
|
If not specified, the accumulation data type is chosen based on the input tensor's data type.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(6).reshape(2, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.sum().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.sum(axis=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.sum(axis=1).numpy())
|
|
```
|
|
"""
|
|
ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
|
|
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
|
|
|
|
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
|
"""
|
|
Returns the product of the elements of the tensor along the specified axis or axes.
|
|
|
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
|
which the product is computed and whether the reduced dimensions are retained.
|
|
|
|
You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
|
|
If not specified, the accumulation data type is chosen based on the input tensor's data type.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([-1, -2, -3, 1, 2, 3]).reshape(2, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.prod().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.prod(axis=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.prod(axis=1).numpy())
|
|
```
|
|
"""
|
|
return self.cast(acc_dtype or self.dtype)._reduce(F.Prod, axis, keepdim)
|
|
|
|
def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
|
"""
|
|
Returns the maximum value of the tensor along the specified axis or axes.
|
|
|
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
|
which the maximum is computed and whether the reduced dimensions are retained.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 0, 2], [5, 4, 3]])
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.max().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.max(axis=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.max(axis=1, keepdim=True).numpy())
|
|
```
|
|
"""
|
|
return self._reduce(F.Max, axis, keepdim)
|
|
|
|
def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
|
"""
|
|
Returns the minimum value of the tensor along the specified axis or axes.
|
|
|
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
|
which the minimum is computed and whether the reduced dimensions are retained.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 0, 2], [5, 4, 3]])
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.min().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.min(axis=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.min(axis=1, keepdim=True).numpy())
|
|
```
|
|
"""
|
|
return -((-self).max(axis=axis, keepdim=keepdim))
|
|
|
|
def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
|
"""
|
|
Tests if any element evaluates to `True` along the specified axis or axes.
|
|
|
|
You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[True, True], [True, False], [False, False]])
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.any().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.any(axis=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.any(axis=1, keepdim=True).numpy())
|
|
```
|
|
"""
|
|
return self.bool().max(axis, keepdim)
|
|
|
|
def all(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
|
"""
|
|
Tests if all elements evaluate to `True` along the specified axis or axes.
|
|
|
|
You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[True, True], [True, False], [False, False]])
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.all().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.all(axis=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.all(axis=1, keepdim=True).numpy())
|
|
```
|
|
"""
|
|
return self.logical_not().any(axis, keepdim).logical_not()
|
|
|
|
def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
|
"""
|
|
Returns the mean value of the tensor along the specified axis or axes.
|
|
|
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
|
which the mean is computed and whether the reduced dimensions are retained.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.normal(2, 3, mean=2.5, std=0.5)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.mean().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.mean(axis=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.mean(axis=1).numpy())
|
|
```
|
|
"""
|
|
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
|
|
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
|
|
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
|
|
|
|
def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
|
|
"""
|
|
Returns the variance of the tensor along the specified axis or axes.
|
|
|
|
You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
|
|
which the variance is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.normal(2, 3, mean=2.5, std=0.5)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.var().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.var(axis=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.var(axis=1).numpy())
|
|
```
|
|
"""
|
|
squares = (self - self.mean(axis=axis, keepdim=True)).square()
|
|
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
|
|
return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
|
|
|
|
def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
|
|
"""
|
|
Returns the standard deviation of the tensor along the specified axis or axes.
|
|
|
|
You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
|
|
which the standard deviation is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.normal(2, 3, mean=2.5, std=0.5)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.std().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.std(axis=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.std(axis=1).numpy())
|
|
```
|
|
"""
|
|
return self.var(axis, keepdim, correction).sqrt()
|
|
|
|
def std_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
|
|
"""
|
|
Calculates the standard deviation and mean over the dimensions specified by `axis`.
|
|
Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.normal(2, 3, mean=2.5, std=0.5)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
std, mean = t.std_mean()
|
|
print(std.numpy(), mean.numpy())
|
|
```
|
|
"""
|
|
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
|
|
|
|
def _softmax(self, axis):
|
|
m = self - self.max(axis=axis, keepdim=True)
|
|
e = m.exp()
|
|
return m, e, e.sum(axis=axis, keepdim=True)
|
|
|
|
def softmax(self, axis=-1):
|
|
"""
|
|
Applies the softmax function to the tensor along the specified axis.
|
|
|
|
Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1.
|
|
|
|
You can pass in the `axis` keyword argument to control the axis along which the softmax is computed.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.randn(2, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.softmax().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.softmax(axis=0).numpy())
|
|
```
|
|
"""
|
|
_, e, ss = self._softmax(axis)
|
|
return e.div(ss)
|
|
|
|
def log_softmax(self, axis=-1):
|
|
"""
|
|
Applies the log-softmax function to the tensor along the specified axis.
|
|
|
|
The log-softmax function is a numerically stable alternative to the softmax function in log space.
|
|
|
|
You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.randn(2, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.log_softmax().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.log_softmax(axis=0).numpy())
|
|
```
|
|
"""
|
|
m, _, ss = self._softmax(axis)
|
|
return m - ss.log()
|
|
|
|
def logsumexp(self, axis=None, keepdim=False):
|
|
"""
|
|
Computes the log-sum-exp of the tensor along the specified axis or axes.
|
|
|
|
The log-sum-exp function is a numerically stable way to compute the logarithm of the sum of exponentials.
|
|
|
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
|
which the log-sum-exp is computed and whether the reduced dimensions are retained.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.randn(2, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.logsumexp().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.logsumexp(axis=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.logsumexp(axis=1).numpy())
|
|
```
|
|
"""
|
|
m = self.max(axis=axis, keepdim=True)
|
|
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
|
|
|
|
def argmax(self, axis=None, keepdim=False):
|
|
"""
|
|
Returns the indices of the maximum value of the tensor along the specified axis.
|
|
|
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
|
which the maximum is computed and whether the reduced dimensions are retained.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 0, 2], [5, 4, 3]])
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.argmax().numpy()) # Returns the index of the maximum value in the flattened tensor.
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.argmax(axis=0).numpy()) # Returns the indices of the maximum values along axis 0.
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
|
|
```
|
|
"""
|
|
if axis is None: return self.flatten().argmax(0)
|
|
axis = self._resolve_dim(axis)
|
|
m = self == self.max(axis=axis, keepdim=True)
|
|
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
|
|
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
|
|
|
|
def argmin(self, axis=None, keepdim=False):
|
|
"""
|
|
Returns the indices of the minimum value of the tensor along the specified axis.
|
|
|
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
|
which the minimum is computed and whether the reduced dimensions are retained.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 0, 2], [5, 4, 3]])
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.argmin().numpy()) # Returns the index of the minimum value in the flattened tensor.
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.argmin(axis=0).numpy()) # Returns the indices of the minimum values along axis 0.
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
|
|
```
|
|
"""
|
|
return (-self).argmax(axis=axis, keepdim=keepdim)
|
|
|
|
def rearrange(self, formula: str, **sizes) -> Tensor:
|
|
"""
|
|
Rearranges the input tensor according to the given formula.
|
|
|
|
See: https://einops.rocks/api/rearrange/
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
x = Tensor([[1, 2], [3, 4]])
|
|
print(Tensor.rearrange(x, "batch channel -> (batch channel)).numpy())
|
|
```
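Axes can also be reordered, for example swapping the two dimensions:

```python exec="true" source="above" session="tensor" result="python"
print(Tensor.rearrange(x, "batch channel -> channel batch").numpy())
```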
|
|
"""
|
|
def parse_formula(formula: str):
|
|
tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
|
|
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
|
|
pairs = list(zip(lparens, rparens))
|
|
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
|
|
return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
|
|
|
|
assert formula.count("->") == 1, 'need exactly one "->" in formula'
|
|
|
|
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
|
|
|
|
for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
|
|
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
|
|
for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
|
|
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
|
|
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
|
|
|
|
# resolve ellipsis
|
|
if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
|
|
lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
|
|
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
|
|
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
|
|
|
|
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
|
|
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
|
|
for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
|
|
t = t.permute([lhs.index(name) for name in rhs])
|
|
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
|
|
|
|
@staticmethod
|
|
def einsum(formula:str, *raw_xs, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
|
"""
|
|
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
|
|
|
|
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
x = Tensor([[1, 2], [3, 4]])
|
|
y = Tensor([[5, 6], [7, 8]])
|
|
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
|
```
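A matrix multiplication can be expressed in the same notation:

```python exec="true" source="above" session="tensor" result="python"
print(Tensor.einsum("ij,jk->ik", x, y).numpy())
```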
|
|
"""
|
|
def parse_formula(formula: str, *operands: Tensor):
|
|
if "." in formula:
|
|
ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0
|
|
for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
|
|
if (ell_count := max(operands[i].ndim, 1) - (len(inp) - 3)) > ell_longest: ell_longest = ell_count
|
|
inputs[i] = inp.replace("...", "" if ell_count == 0 else ell_chars[-ell_count:])
|
|
inputs_str, out_ellipse = ",".join(inputs), "" if ell_longest == 0 else ell_chars[-ell_longest:]
|
|
return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else (inputs_str, \
|
|
out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
|
|
return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
|
|
|
|
xs:Tuple[Tensor, ...] = argfix(*raw_xs)
|
|
inputs_str, output = parse_formula(formula.replace(" ", ""), *xs)
|
|
inputs = inputs_str.split(",")
|
|
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
|
|
|
|
# map the value of each letter in the formula
|
|
letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
|
|
|
|
xs_:List[Tensor] = []
|
|
lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
|
|
for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
|
|
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
|
|
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
|
|
|
|
# determine the inverse permutation to revert back to original order
|
|
rhs_letter_order = argsort(list(output))
|
|
rhs_order = argsort(rhs_letter_order)
|
|
|
|
# sum over all axes that's not in the output, then permute to the output order
|
|
return functools.reduce(lambda a,b:a*b, xs_) \
|
|
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order)
|
|
|
|
# ***** processing ops *****
|
|
|
|
def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
|
|
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
|
assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
|
|
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
|
|
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
|
noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
|
|
o_ = [math.ceil((i - d * (k-1))/s) for i,d,k,s in zip(i_, d_, k_, s_)]
|
|
if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
|
|
# repeats such that we don't need padding
|
|
xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
|
|
# handle dilation
|
|
xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
|
|
# handle stride
|
|
xup = xup.shrink(
|
|
tuple(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
|
|
xup = xup.shrink(tuple(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
|
|
# permute to move reduce to the end
|
|
return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
|
|
# TODO: once the shapetracker can optimize well, remove this alternative implementation
|
|
xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
|
|
xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
|
|
xup = xup.shrink(tuple(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_))))
|
|
return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
|
|
|
|
def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
|
|
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
|
|
|
|
# NOTE: these work for more than 2D
|
|
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
|
|
"""
|
|
Applies average pooling over a tensor.
|
|
|
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
|
|
|
See: https://paperswithcode.com/method/average-pooling
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(25).reshape(1, 1, 5, 5)
|
|
print(t.avg_pool2d().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.avg_pool2d(padding=1).numpy())
|
|
```
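When padding is used, `count_include_pad=False` excludes the padded zeros from the average:

```python exec="true" source="above" session="tensor" result="python"
print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
```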
|
|
"""
|
|
padding_, axis = self._padding2d(padding, len(k_ := make_pair(kernel_size))), tuple(range(-len(k_), 0))
|
|
def pool(x:Tensor) -> Tensor: return x.pad2d(padding_)._pool(k_, stride if stride is not None else k_, dilation)
|
|
return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)
|
|
|
|
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
|
|
"""
|
|
Applies max pooling over a tensor.
|
|
|
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
|
|
|
See: https://paperswithcode.com/method/max-pooling
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(25).reshape(1, 1, 5, 5)
|
|
print(t.max_pool2d().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.max_pool2d(padding=1).numpy())
|
|
```
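Passing `stride` controls the step between pooling windows; with `stride=1` the windows overlap:

```python exec="true" source="above" session="tensor" result="python"
print(t.max_pool2d(stride=1).numpy())
```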
|
|
"""
|
|
padding_ = self._padding2d(padding, len(k_ := make_pair(kernel_size)))
|
|
return self.pad2d(padding_, value=float('-inf'))._pool(k_, stride if stride is not None else k_, dilation).max(axis=tuple(range(-len(k_), 0)))
|
|
|
|
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:DTypeLike|None=None) -> Tensor:
|
|
"""
|
|
Applies a convolution over a tensor with a given `weight` and optional `bias`.
|
|
|
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
|
|
|
|
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(9).reshape(1, 1, 3, 3)
|
|
w = Tensor.ones(1, 1, 2, 2)
|
|
print(t.conv2d(w).numpy())
|
|
```
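The `padding`, `stride`, `dilation`, and `groups` keyword arguments are also supported; for example, with `padding=1`:

```python exec="true" source="above" session="tensor" result="python"
print(t.conv2d(w, padding=1).numpy())
```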
|
|
"""
|
|
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
|
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
|
|
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
|
|
padding_ = self._padding2d(padding, len(HW))
|
|
|
|
# conv2d is a pooling op (with padding)
|
|
x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
|
|
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
|
|
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
|
|
# normal conv
|
|
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
|
|
|
|
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
|
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
|
|
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
|
|
|
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
|
|
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
|
|
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
|
|
winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time
|
|
|
|
# todo: stride == dilation
|
|
# use padding to round up to 4x4 output tiles
|
|
# (bs, cin_, tyx, HWI)
|
|
d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
|
|
# move HW to the front: # (HWI, bs, cin_, tyx)
|
|
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
|
|
tyx = d.shape[-len(HWI):] # dim of tiling
|
|
|
|
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
|
|
|
|
# compute 6x6 winograd tiles: GgGt, BtdB
|
|
# (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
|
|
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
|
|
# (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
|
|
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
|
|
|
|
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
|
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
|
|
|
|
# interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
|
|
ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
|
|
# merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
|
|
ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx]))
|
|
|
|
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
|
|
|
|
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
|
|
"""
|
|
Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
|
|
|
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
|
|
|
|
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(9).reshape(1, 1, 3, 3)
|
|
w = Tensor.ones(1, 1, 2, 2)
|
|
print(t.conv_transpose2d(w).numpy())
|
|
```
|
|
"""
|
|
x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
|
|
HW = weight.shape[2:]
|
|
stride, dilation, padding, output_padding = [make_pair(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
|
|
if any(s>1 for s in stride):
|
|
# handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
|
|
x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
|
|
x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
|
|
x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
|
|
x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
|
|
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
|
|
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
|
|
|
def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
|
"""
|
|
Performs dot product between two tensors.
|
|
|
|
You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
a = Tensor([[1, 2], [3, 4]])
|
|
b = Tensor([[5, 6], [7, 8]])
|
|
print(a.dot(b).numpy())
|
|
```
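1-D tensors are also supported, giving the inner product:

```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1, 2, 3]).dot(Tensor([1, 1, 1])).numpy())
```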
|
|
"""
|
|
n1, n2 = len(self.shape), len(w.shape)
|
|
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
|
if (L:=self.shape[-1]) != (R:=w.shape[-min(n2, 2)]): raise AssertionError(f"shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})")
|
|
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
|
|
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
|
|
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
|
|
|
|
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
|
"""
|
|
Performs matrix multiplication between two tensors.
|
|
|
|
You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
|
|
You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
a = Tensor([[1, 2], [3, 4]])
|
|
b = Tensor([[5, 6], [7, 8]])
|
|
print(a.matmul(b).numpy())
|
|
```
|
|
"""
|
|
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
|
|
|
|
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
|
|
assert self.shape[axis] != 0
|
|
pl_sz = self.shape[axis] - int(not _first_zero)
|
|
return self.transpose(axis,-1).pad2d((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
|
|
def cumsum(self, axis:int=0) -> Tensor:
|
|
"""
|
|
Computes the cumulative sum of the tensor along the specified axis.
|
|
|
|
You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.ones(2, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.cumsum(1).numpy())
|
|
```
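The cumulative sum of the same tensor along axis 0:

```python exec="true" source="above" session="tensor" result="python"
print(t.cumsum(0).numpy())
```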
|
|
"""
|
|
axis = self._resolve_dim(axis)
|
|
if self.ndim == 0 or 0 in self.shape: return self
|
|
# TODO: someday the optimizer will find this on its own
|
|
# for now this is a two stage cumsum
|
|
SPLIT = 256
|
|
if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
|
|
ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
|
|
ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
|
|
base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
|
|
base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
|
|
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
|
|
return fix(ret) + fix(base_add)
|
|
|
|
@staticmethod
|
|
def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
|
|
assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}"
|
|
if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs)
|
|
if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs)
|
|
s = r+c-1
|
|
# build a (s, s) upper triangle
|
|
t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s)))
|
|
return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c]
|
|
|
|
def triu(self, diagonal:int=0) -> Tensor:
|
|
"""
|
|
Returns the upper triangular part of the tensor, the other elements are set to 0.
|
|
|
|
The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
|
|
Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.triu(diagonal=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.triu(diagonal=1).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.triu(diagonal=-1).numpy())
|
|
```
|
|
"""
|
|
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, 0).cast(self.dtype)
|
|
|
|
def tril(self, diagonal:int=0) -> Tensor:
|
|
"""
|
|
Returns the lower triangular part of the tensor, the other elements are set to 0.
|
|
|
|
The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
|
|
Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.tril(diagonal=0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.tril(diagonal=1).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.tril(diagonal=-1).numpy())
|
|
```
|
|
"""
|
|
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
|
|
|
|
def interpolate(self, size:Tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
|
|
"""
|
|
Downsamples or upsamples to the given `size`, accepting 0 to N batch dimensions.
|
|
|
|
The interpolation algorithm is selected with `mode` which currently only supports `linear`, `nearest` and `nearest-exact`.
|
|
To run `bilinear` or `trilinear`, pass in a 2D or 3D size.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]])
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.interpolate(size=(2,3), mode="linear").numpy())
|
|
```
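The same resize with `mode="nearest"`:

```python exec="true" source="above" session="tensor" result="python"
print(t.interpolate(size=(2,3), mode="nearest").numpy())
```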
|
|
"""
|
|
assert isinstance(size, (tuple,list)) and all_int(size) and 0 < len(size) <= self.ndim, f"invalid {size=}"
|
|
assert mode in ("linear", "nearest", "nearest-exact"), "only supports linear, nearest or nearest-exact interpolate"
|
|
assert not (align_corners and mode != "linear"), "align_corners option can only be set with the interpolating mode linear"
|
|
x, expand = self, list(self.shape)
|
|
for i in range(-1,-len(size)-1,-1):
|
|
scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners))
|
|
arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim
|
|
reshape[i] = expand[i] = size[i]
|
|
if mode == "linear":
|
|
index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
|
|
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
|
|
x = x.gather(i, low).lerp(x.gather(i, high), perc)
|
|
else:
|
|
index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
|
|
x = x.gather(i, index)
|
|
return x.cast(self.dtype)
|
|
|
|
# ***** unary ops *****
|
|
|
|
def logical_not(self):
|
|
"""
|
|
Computes the logical NOT of the tensor element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([False, True]).logical_not().numpy())
|
|
```
|
|
"""
|
|
return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
|
|
def neg(self):
|
|
"""
|
|
Negates the tensor element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
|
|
```
|
|
"""
|
|
return self*-1 if self.dtype != dtypes.bool else self.logical_not()
|
|
def contiguous(self):
|
|
"""
|
|
Returns a contiguous tensor.
|
|
"""
|
|
return F.Contiguous.apply(self)
|
|
def contiguous_backward(self):
|
|
"""
|
|
Inserts a contiguous operation in the backward pass.
|
|
"""
|
|
return F.ContiguousBackward.apply(self)
|
|
def log(self):
|
|
"""
|
|
Computes the natural logarithm element-wise.
|
|
|
|
See: https://en.wikipedia.org/wiki/Logarithm
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([1., 2., 4., 8.]).log().numpy())
|
|
```
|
|
"""
|
|
return F.Log.apply(self.cast(least_upper_float(self.dtype)))
|
|
def log2(self):
|
|
"""
|
|
Computes the base-2 logarithm element-wise.
|
|
|
|
See: https://en.wikipedia.org/wiki/Logarithm
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([1., 2., 4., 8.]).log2().numpy())
|
|
```
|
|
"""
|
|
return self.log()/math.log(2)
|
|
def exp(self):
|
|
"""
|
|
Computes the exponential function element-wise.
|
|
|
|
See: https://en.wikipedia.org/wiki/Exponential_function
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([0., 1., 2., 3.]).exp().numpy())
|
|
```
|
|
"""
|
|
return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
|
|
def exp2(self):
|
|
"""
|
|
Computes the base-2 exponential function element-wise.
|
|
|
|
See: https://en.wikipedia.org/wiki/Exponential_function
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([0., 1., 2., 3.]).exp2().numpy())
|
|
```
|
|
"""
|
|
return F.Exp.apply(self*math.log(2))
|
|
def relu(self):
|
|
"""
|
|
Applies the Rectified Linear Unit (ReLU) function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/relu
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
|
|
```
|
|
"""
|
|
return F.Relu.apply(self)
|
|
def sigmoid(self):
|
|
"""
|
|
Applies the Sigmoid function element-wise.
|
|
|
|
- Described: https://en.wikipedia.org/wiki/Sigmoid_function
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
|
|
```
|
|
"""
|
|
return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
|
|
def sqrt(self):
|
|
"""
|
|
Computes the square root of the tensor element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
|
|
```
|
|
"""
|
|
return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
|
|
def rsqrt(self):
|
|
"""
|
|
Computes the reciprocal of the square root of the tensor element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
|
|
```
|
|
"""
|
|
return self.reciprocal().sqrt()
|
|
def sin(self):
|
|
"""
|
|
Computes the sine of the tensor element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
|
|
```
|
|
"""
|
|
return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
|
|
def cos(self):
|
|
"""
|
|
Computes the cosine of the tensor element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy())
|
|
```
|
|
"""
|
|
return ((math.pi/2)-self).sin()
|
|
def tan(self):
|
|
"""
|
|
Computes the tangent of the tensor element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy())
|
|
```
|
|
"""
|
|
return self.sin() / self.cos()
|
|
|
|
# ***** math functions *****
|
|
|
|
def trunc(self: Tensor) -> Tensor:
|
|
"""
|
|
Truncates the tensor element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
|
|
```
|
|
"""
|
|
return self.cast(dtypes.int32).cast(self.dtype)
|
|
def ceil(self: Tensor) -> Tensor:
|
|
"""
|
|
Rounds the tensor element-wise towards positive infinity.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy())
|
|
```
|
|
"""
|
|
return (self > (b := self.trunc())).where(b+1, b)
|
|
def floor(self: Tensor) -> Tensor:
|
|
"""
|
|
Rounds the tensor element-wise towards negative infinity.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy())
|
|
```
|
|
"""
|
|
return (self < (b := self.trunc())).where(b-1, b)
|
|
def round(self: Tensor) -> Tensor:
|
|
"""
|
|
Rounds the tensor element-wise with rounding half to even.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy())
|
|
```
|
|
"""
|
|
return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
|
|
|
|
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
|
|
"""
|
|
Linearly interpolates between `self` and `end` by `weight`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
|
|
```
|
|
"""
|
|
if self.dtype == dtypes.uint8 and isinstance(weight, Tensor):
|
|
w_i = (weight * (1<<(W_PREC:=7)) + 0.5).cast(dtypes.int16)
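# quantize the weight to 7 fractional bits (steps of 1/128); the +0.5 rounds to the nearest step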
|
|
return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
|
|
return self + (end - self) * weight
|
|
|
|
def square(self):
|
|
"""
|
|
Squares the tensor element-wise.
|
|
Equivalent to `self*self`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy())
|
|
```
|
|
"""
|
|
return self*self
|
|
def clamp(self, min_=None, max_=None):
|
|
"""
|
|
Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
|
|
If `min_` is `None`, there is no lower bound. If `max_` is `None`, there is no upper bound.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
|
|
```
|
|
"""
|
|
if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
|
|
ret = self.maximum(min_) if min_ is not None else self
|
|
return ret.minimum(max_) if max_ is not None else ret
|
|
def clip(self, min_=None, max_=None):
|
|
"""
|
|
Alias for `Tensor.clamp`.
|
|
"""
|
|
return self.clamp(min_, max_)
|
|
def sign(self):
|
|
"""
|
|
Returns the sign of the tensor element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
|
|
```
|
|
"""
|
|
return F.Sign.apply(self)
|
|
def abs(self):
|
|
"""
|
|
Computes the absolute value of the tensor element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy())
|
|
```
|
|
"""
|
|
return self * self.sign()
|
|
def reciprocal(self):
|
|
"""
|
|
Compute `1/x` element-wise.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
|
|
```
|
|
"""
|
|
return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
|
|
|
|
# ***** activation functions *****
|
|
|
|
def elu(self, alpha=1.0):
|
|
"""
|
|
Applies the Exponential Linear Unit (ELU) function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/elu
|
|
- Paper: https://arxiv.org/abs/1511.07289v5
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy())
|
|
```
|
|
"""
|
|
return self.relu() - alpha*(1-self.exp()).relu()
|
|
|
|
def celu(self, alpha=1.0):
|
|
"""
|
|
Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/celu
|
|
- Paper: https://arxiv.org/abs/1704.07483
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy())
|
|
```
|
|
"""
|
|
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
|
|
|
|
def swish(self):
|
|
"""
|
|
See `.silu()`
|
|
|
|
- Paper: https://arxiv.org/abs/1710.05941v1
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy())
|
|
```
|
|
"""
|
|
return self * self.sigmoid()
|
|
|
|
def silu(self):
|
|
"""
|
|
Applies the Sigmoid Linear Unit (SiLU) function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/silu
|
|
- Paper: https://arxiv.org/abs/1606.08415
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy())
|
|
```
|
|
"""
|
|
return self.swish() # The SiLU function is also known as the swish function.
|
|
|
|
def relu6(self):
|
|
"""
|
|
Applies the ReLU6 function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/relu6
|
|
- Paper: https://arxiv.org/abs/1704.04861v1
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy())
|
|
```
|
|
"""
|
|
return self.relu() - (self-6).relu()
|
|
|
|
def hardswish(self):
|
|
"""
|
|
Applies the Hardswish function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/hard-swish
|
|
- Paper: https://arxiv.org/abs/1905.02244v5
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy())
|
|
```
|
|
"""
|
|
return self * (self+3).relu6() * (1/6)
|
|
|
|
def tanh(self):
|
|
"""
|
|
Applies the Hyperbolic Tangent (tanh) function element-wise.
|
|
|
|
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy())
|
|
```
|
|
"""
|
|
return 2.0 * ((2.0 * self).sigmoid()) - 1.0
|
|
|
|
def sinh(self):
|
|
"""
|
|
Applies the Hyperbolic Sine (sinh) function element-wise.
|
|
|
|
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy())
|
|
```
|
|
"""
|
|
return (self.exp() - self.neg().exp()) / 2
|
|
|
|
def cosh(self):
|
|
"""
|
|
Applies the Hyperbolic Cosine (cosh) function element-wise.
|
|
|
|
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy())
|
|
```
|
|
"""
|
|
return (self.exp() + self.neg().exp()) / 2
|
|
|
|
def atanh(self):
|
|
"""
|
|
Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.
|
|
|
|
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).atanh().numpy())
|
|
```
|
|
"""
|
|
return ((1 + self)/(1 - self)).log() / 2
|
|
|
|
def asinh(self):
|
|
"""
|
|
Applies the Inverse Hyperbolic Sine (asinh) function element-wise.
|
|
|
|
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).asinh().numpy())
|
|
```
|
|
"""
|
|
return (self + (self.square() + 1).sqrt()).log()
|
|
|
|
def acosh(self):
|
|
"""
|
|
Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.
|
|
|
|
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).acosh().numpy())
|
|
```
|
|
"""
|
|
return (self + (self.square() - 1).sqrt()).log()
|
|
|
|
def hardtanh(self, min_val=-1, max_val=1):
|
|
"""
|
|
Applies the Hardtanh function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/hardtanh-activation
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
|
|
```
|
|
"""
|
|
return self.clip(min_val, max_val)
|
|
|
|
def gelu(self):
|
|
"""
|
|
Applies the Gaussian Error Linear Unit (GELU) function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/gelu
|
|
- Paper: https://arxiv.org/abs/1606.08415v5
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
|
|
```
|
|
"""
|
|
return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())
|
|
|
|
def quick_gelu(self):
|
|
"""
|
|
Applies the Sigmoid GELU approximation element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/gelu
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
|
|
```
|
|
"""
|
|
return self * (self * 1.702).sigmoid()
|
|
|
|
def leakyrelu(self, neg_slope=0.01):
|
|
"""
|
|
Applies the Leaky ReLU function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/leaky-relu
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
|
|
```
|
|
"""
|
|
return self.relu() - (-neg_slope*self).relu()
|
|
|
|
def mish(self):
|
|
"""
|
|
Applies the Mish function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/mish
|
|
- Paper: https://arxiv.org/abs/1908.08681v3
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).mish().numpy())
|
|
```
|
|
"""
|
|
return self * self.softplus().tanh()
|
|
|
|
def softplus(self, beta=1):
|
|
"""
|
|
Applies the Softplus function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/softplus
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy())
|
|
```
|
|
"""
|
|
return (1/beta) * (1 + (self*beta).exp()).log()
|
|
|
|
def softsign(self):
|
|
"""
|
|
Applies the Softsign function element-wise.
|
|
|
|
- Described: https://paperswithcode.com/method/softsign
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
|
|
```
|
|
"""
|
|
return self / (1 + self.abs())
|
|
|
|
# ***** broadcasted elementwise ops *****
|
|
def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
|
|
if self.shape == shape: return self
|
|
if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
|
|
# first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
|
|
padded, _ = _pad_left(self.shape, shape)
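# e.g. broadcasting (3,1) to (2,3,4): padded is (1,3,1), and only the dimensions that are 1 get expanded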
|
|
# for each dimension, check either from_ is 1, or it does not change
|
|
if any(resolve(from_ != 1, False) and resolve(from_ != to, False) for from_,to in zip(padded, shape)):
|
|
raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
|
|
return F.Expand.apply(self.reshape(padded), shape=shape)
|
|
|
|
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
|
|
x: Tensor = self
|
|
if not isinstance(y, Tensor):
|
|
# make y a Tensor
|
|
assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
|
|
if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
|
|
elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
|
|
if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
|
|
else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
|
|
|
|
if match_dtype and x.dtype != y.dtype:
|
|
output_dtype = least_upper_dtype(x.dtype, y.dtype)
|
|
x, y = x.cast(output_dtype), y.cast(output_dtype)
|
|
|
|
if reverse: x, y = y, x
|
|
|
|
# broadcast
|
|
out_shape = _broadcast_shape(x.shape, y.shape)
|
|
return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
|
|
|
|
def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
|
|
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
|
|
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
|
|
|
|
def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
|
"""
|
|
Adds `self` and `x`.
|
|
Equivalent to `self + x`.
|
|
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.randn(4)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.add(20).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.add(Tensor([[2.0], [3.5]])).numpy())
|
|
```
|
|
"""
|
|
return F.Add.apply(*self._broadcasted(x, reverse))
|
|
|
|
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
|
"""
|
|
Subtracts `x` from `self`.
|
|
Equivalent to `self - x`.
|
|
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.randn(4)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.sub(20).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.sub(Tensor([[2.0], [3.5]])).numpy())
|
|
```
|
|
"""
|
|
a, b = self._broadcasted(x, reverse)
|
|
return a + (-b)
|
|
|
|
def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
|
"""
|
|
Multiplies `self` and `x`.
|
|
Equivalent to `self * x`.
|
|
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.randn(4)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.mul(3).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
|
|
```
|
|
"""
|
|
return F.Mul.apply(*self._broadcasted(x, reverse))
|
|
|
|
def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
|
|
"""
|
|
Divides `self` by `x`.
|
|
Equivalent to `self / x`.
|
|
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
|
By default, `div` performs true division. Set `upcast` to `False` for integer division.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.randn(4)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.div(3).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4]), upcast=False).numpy())
|
|
```
|
|
"""
|
|
numerator, denominator = self._broadcasted(x, reverse)
|
|
if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
|
|
return (numerator * denominator.reciprocal()) if dtypes.is_float(numerator.dtype) else F.IDiv.apply(numerator, denominator)
|
|
|
|
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
|
"""
|
|
Computes bitwise xor of `self` and `x`.
|
|
Equivalent to `self ^ x`.
|
|
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-1, -2, 3]).xor(Tensor([1, 0, 3])).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
|
|
```
|
|
"""
|
|
return F.Xor.apply(*self._broadcasted(x, reverse))
|
|
|
|
def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
|
"""
|
|
Compute the bit-wise AND of `self` and `x`.
|
|
Equivalent to `self & x`.
|
|
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
|
|
```
|
|
"""
|
|
assert dtypes.is_int(self.dtype)
|
|
return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
|
|
|
|
def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
|
"""
|
|
Compute the bit-wise OR of `self` and `x`.
|
|
Equivalent to `self | x`.
|
|
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
|
|
```
|
|
"""
|
|
assert dtypes.is_int(self.dtype)
|
|
return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
|
|
|
|
def lshift(self, x:int):
|
|
"""
|
|
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
|
Equivalent to `self << x`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([1, 3, 31], dtype=dtypes.uint8).lshift(2).numpy())
|
|
```
|
|
"""
|
|
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
|
|
return self.mul(2 ** x)
|
|
|
|
def rshift(self, x:int):
|
|
"""
|
|
Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
|
Equivalent to `self >> x`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([4, 13, 125], dtype=dtypes.uint8).rshift(2).numpy())
|
|
```
|
|
"""
|
|
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
|
|
return self.div(2 ** x, upcast=False)
|
|
|
|
def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
|
"""
|
|
Computes power of `self` with `x`.
|
|
Equivalent to `self ** x`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-1, 2, 3]).pow(2).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print((2 ** Tensor([-1, 2, 3])).numpy())
|
|
```
|
|
"""
|
|
x = self._to_const_val(x)
|
|
if not isinstance(x, Tensor) and not reverse:
|
|
# simple pow identities
|
|
if x < 0: return self.reciprocal().pow(-x)
|
|
if x == 0: return 1 + self * 0
|
|
if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
|
|
if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
|
|
|
|
# positive const ** self
|
|
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
|
|
|
|
base, exponent = self._broadcasted(x, reverse=reverse)
|
|
# start with b ** e = exp(e * log(b))
|
|
ret = base.abs().log().mul(exponent).exp()
|
|
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
|
|
negative_base = (base < 0).detach().where(1, 0)
|
|
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
|
|
correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
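# e.g. base=-2, exponent=3: cos(3*pi) = -1, so correct_sign = 1 + 1*(-1 - 1) = -1 and the result is -exp(3*log(2)) = -8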
|
|
# inject nan for negative base and non-integer exponent
|
|
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
|
|
# apply correct_sign and inject_nan, and fix 0 ** 0 = 1
|
|
return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
|
|
|
|
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
|
|
"""
|
|
Computes element-wise maximum of `self` and `x`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-1, 2, 3]).maximum(1).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
|
|
```
|
|
"""
|
|
return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
|
|
|
|
def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
|
|
"""
|
|
Computes element-wise minimum of `self` and `x`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-1, 2, 3]).minimum(1).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
|
|
```
|
|
"""
|
|
return -((-self).maximum(-x))
|
|
|
|
def where(self:Tensor, x:Union[Tensor, ConstType], y:Union[Tensor, ConstType]):
|
|
"""
|
|
Return a tensor of elements selected from either `x` or `y`, depending on `self`.
|
|
`output_i = x_i if self_i else y_i`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
cond = Tensor([[True, True, False], [True, False, False]])
|
|
print(cond.where(1, 3).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
cond = Tensor.randn(2, 3)
|
|
print(cond.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print((cond > 0).where(cond, -float("inf")).numpy())
|
|
```
|
|
"""
|
|
if isinstance(x, Tensor): x, y = x._broadcasted(y)
|
|
elif isinstance(y, Tensor): y, x = y._broadcasted(x)
|
|
cond, x = self._broadcasted(x, match_dtype=False)
|
|
cond, y = cond._broadcasted(y, match_dtype=False)
|
|
return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
|
|
|
|
def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
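# fills entries where `mask` is True with `value`, e.g. Tensor([1., 2., 3.]).masked_fill(Tensor([True, False, True]), 5.0) -> [5., 2., 5.]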
|
|
|
|
# ***** op wrappers *****
|
|
|
|
def __neg__(self) -> Tensor: return self.neg()
|
|
|
|
def __add__(self, x) -> Tensor: return self.add(x)
|
|
def __sub__(self, x) -> Tensor: return self.sub(x)
|
|
def __mul__(self, x) -> Tensor: return self.mul(x)
|
|
def __pow__(self, x) -> Tensor: return self.pow(x)
|
|
def __truediv__(self, x) -> Tensor: return self.div(x)
|
|
def __floordiv__(self, x) -> Tensor: return self.div(x, upcast=False)
|
|
def __matmul__(self, x) -> Tensor: return self.matmul(x)
|
|
def __and__(self, x) -> Tensor: return self.bitwise_and(x)
|
|
def __or__(self, x) -> Tensor: return self.bitwise_or(x)
|
|
def __xor__(self, x) -> Tensor: return self.xor(x)
|
|
def __lshift__(self, x) -> Tensor: return self.lshift(x)
|
|
def __rshift__(self, x) -> Tensor: return self.rshift(x)
|
|
|
|
def __radd__(self, x) -> Tensor: return self.add(x, True)
|
|
def __rsub__(self, x) -> Tensor: return self.sub(x, True)
|
|
def __rmul__(self, x) -> Tensor: return self.mul(x, True)
|
|
def __rpow__(self, x) -> Tensor: return self.pow(x, True)
|
|
def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
|
|
def __rfloordiv__(self, x) -> Tensor: return self.div(x, True, upcast=False)
|
|
def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
|
|
def __rand__(self, x) -> Tensor: return self.bitwise_and(x, True)
|
|
def __ror__(self, x) -> Tensor: return self.bitwise_or(x, True)
|
|
def __rxor__(self, x) -> Tensor: return self.xor(x, True)
|
|
|
|
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
|
|
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
|
|
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
|
|
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
|
|
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
|
|
def __ifloordiv__(self, x) -> Tensor: return self.assign(self.div(x, upcast=False))
|
|
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
|
|
def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x))
|
|
def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x))
|
|
def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
|
|
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
|
|
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
|
|
|
|
def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
|
|
def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
|
|
def __ge__(self, x) -> Tensor: return (self<x).logical_not()
|
|
def __le__(self, x) -> Tensor: return (self>x).logical_not()
|
|
def __ne__(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x)) # type: ignore[override]
|
|
def __eq__(self, x) -> Tensor: return (self!=x).logical_not() # type: ignore[override]
|
|
|
|
# ***** functional nn ops *****
|
|
|
|
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
|
|
"""
|
|
Applies a linear transformation to `self` using `weight` and `bias`.
|
|
|
|
See: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 2], [3, 4]])
|
|
weight = Tensor([[1, 2], [3, 4]])
|
|
bias = Tensor([1, 2])
|
|
print(t.linear(weight, bias).numpy())
|
|
```
|
|
"""
|
|
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
|
|
return x.add(bias) if bias is not None else x
|
|
|
|
def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
|
|
"""
|
|
Applies a sequence of functions to `self`, chaining the output of each function to the input of the next.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([1, 2, 3])
|
|
print(t.sequential([lambda x: x * 2, lambda x: x + 1]).numpy())
|
|
```
|
|
"""
|
|
return functools.reduce(lambda x,f: f(x), ll, self)
|
|
|
|
def layernorm(self, axis:Union[int,Tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
|
|
"""
|
|
Applies Layer Normalization over a mini-batch of inputs.
|
|
|
|
- Described: https://paperswithcode.com/method/layer-normalization
|
|
- Paper: https://arxiv.org/abs/1607.06450v1
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.randn(8, 10, 16) * 2 + 8
|
|
print(t.mean().item(), t.std().item())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = t.layernorm()
|
|
print(t.mean().item(), t.std().item())
|
|
```
|
|
"""
|
|
y = (self - self.mean(axis, keepdim=True))
|
|
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
|
|
|
|
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
|
|
"""
|
|
Applies Batch Normalization over a mini-batch of inputs.
|
|
|
|
- Described: https://paperswithcode.com/method/batch-normalization
|
|
- Paper: https://arxiv.org/abs/1502.03167
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.randn(8, 4, 16, 16) * 2 + 8
|
|
print(t.mean().item(), t.std().item())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = t.batchnorm(None, None, t.mean(axis=(0,2,3)), t.var(axis=(0,2,3)).add(1e-5).rsqrt())
|
|
print(t.mean().item(), t.std().item())
|
|
```
|
|
"""
|
|
axis_ = argfix(axis)
|
|
shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
|
|
x = self - mean.reshape(shape)
|
|
if weight is not None: x = x * weight.reshape(shape)
|
|
ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
|
|
return (ret + bias.reshape(shape)) if bias is not None else ret
|
|
|
|
def dropout(self, p=0.5) -> Tensor:
|
|
"""
|
|
Applies dropout to `self`.
|
|
|
|
NOTE: dropout is only applied when `Tensor.training` is `True`.
|
|
|
|
- Described: https://paperswithcode.com/method/dropout
|
|
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
Tensor.manual_seed(42)
|
|
t = Tensor.randn(2, 2)
|
|
with Tensor.train():
|
|
print(t.dropout().numpy())
|
|
```
|
|
"""
|
|
if not Tensor.training or p == 0: return self
|
|
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p).where(self, 0) * (1/(1.0 - p))
|
|
|
|
def one_hot(self, num_classes:int=-1) -> Tensor:
|
|
"""
|
|
Converts `self` to a one-hot tensor.
|
|
|
|
`num_classes` defaults to -1, which means `num_classes` will be inferred as `max(self) + 1`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([0, 1, 3, 3, 4])
|
|
print(t.one_hot(5).numpy())
|
|
```
|
|
"""
|
|
if num_classes == -1: num_classes = (self.max()+1).item()
|
|
return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
|
|
|
|
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
|
|
dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
|
|
"""
|
|
Computes scaled dot-product attention.
|
|
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
|
|
|
|
- Described: https://paperswithcode.com/method/scaled
|
|
- Paper: https://arxiv.org/abs/1706.03762v7
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
q = Tensor.randn(2, 4, 8)
|
|
k = Tensor.randn(2, 4, 8)
|
|
v = Tensor.randn(2, 4, 8)
|
|
print(q.scaled_dot_product_attention(k, v).numpy())
|
|
```
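Setting `is_causal=True` applies a lower-triangular mask so each position only attends to itself and earlier positions (reusing `q`, `k`, `v` from above):

```python exec="true" source="above" session="tensor" result="python"
print(q.scaled_dot_product_attention(k, v, is_causal=True).numpy())
```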
|
|
"""
|
|
# NOTE: it also works when `key` and `value` have symbolic shape.
|
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
|
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
|
|
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
|
|
qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
|
|
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
|
|
|
|
def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
|
|
if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
|
|
reductions: Dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
|
|
return reductions[reduction](self)
|
|
|
|
def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
|
|
"""
|
|
Computes the binary cross-entropy loss between `self` and `Y`.
|
|
|
|
See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([0.1, 0.9, 0.2])
|
|
Y = Tensor([0, 1, 0])
|
|
print(t.binary_crossentropy(Y).item())
|
|
```
|
|
"""
|
|
return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)
|
|
|
|
def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
|
|
"""
|
|
Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits.
|
|
|
|
See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([-1, 2, -3])
|
|
Y = Tensor([0, 1, 0])
|
|
print(t.binary_crossentropy_logits(Y).item())
|
|
```
|
|
"""
|
|
return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction)
|
|
|
|
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
|
|
"""
|
|
Computes the sparse categorical cross-entropy loss between `self` and `Y`.
|
|
|
|
NOTE: `self` is logits and `Y` is the target labels.
|
|
NOTE: unlike PyTorch, this function expects the class axis to be -1.
|
|
|
|
See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[-1, 2, -3], [1, -2, 3]])
|
|
Y = Tensor([1, 2])
|
|
print(t.sparse_categorical_crossentropy(Y).item())
|
|
```
|
|
"""
|
|
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
|
|
assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
|
|
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
|
|
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
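# build a one-hot matrix of the targets; rows whose label equals ignore_index are zeroed out via loss_mask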
|
|
y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
|
|
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
|
|
unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
|
|
# NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
|
|
return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced))
|
|
|
|
def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
|
|
"""
|
|
Compute the cross entropy loss between input logits and target.
|
|
|
|
NOTE: `self` are logits and `Y` are the target labels or class probabilities.
|
|
|
|
See: https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[-1, 2, -3], [1, -2, 3]])
|
|
Y = Tensor([1, 2])
|
|
print(t.cross_entropy(Y).item())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[-1, 2, -3], [1, -2, 3]])
|
|
Y = Tensor([1, 2])
|
|
print(t.cross_entropy(Y, reduction='none').numpy())
|
|
```
|
|
"""
|
|
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
|
|
Y = Y.one_hot(num_classes=cast(int, self.shape[1])) if Y.ndim < 2 else Y
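# label smoothing mixes the one-hot targets with a uniform distribution over the classes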
|
|
Y = (1 - label_smoothing)*Y + label_smoothing / cast(int, Y.shape[1])
|
|
ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1)
|
|
return ret._do_reduction(reduction)
|
|
|
|
# ***** Tensor Properties *****
|
|
|
|
@property
|
|
def ndim(self) -> int:
|
|
"""
|
|
Returns the number of dimensions in the tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 2], [3, 4]])
|
|
print(t.ndim)
|
|
```
|
|
"""
|
|
return len(self.shape)
|
|
|
|
def numel(self) -> sint:
|
|
"""
|
|
Returns the total number of elements in the tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
|
|
print(t.numel())
|
|
```
|
|
"""
|
|
return prod(self.shape)
|
|
|
|
def element_size(self) -> int:
|
|
"""
|
|
Returns the size in bytes of an individual element in the tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([5], dtype=dtypes.int16)
|
|
print(t.element_size())
|
|
```
|
|
"""
|
|
return self.dtype.itemsize
|
|
|
|
def nbytes(self) -> int:
|
|
"""
|
|
Returns the total number of bytes of all elements in the tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([8, 9], dtype=dtypes.float)
|
|
print(t.nbytes())
|
|
```
|
|
"""
|
|
return self.numel() * self.element_size()
|
|
|
|
def is_floating_point(self) -> bool:
|
|
"""
|
|
Returns `True` if the tensor contains floating point types, i.e. is one of `dtypes.float64`, `dtypes.float32`,
`dtypes.float16`, `dtypes.bfloat16`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([8, 9], dtype=dtypes.float32)
|
|
print(t.is_floating_point())
|
|
```
|
|
"""
|
|
return dtypes.is_float(self.dtype)
|
|
|
|
def size(self, dim:Optional[int]=None) -> Union[sint, Tuple[sint, ...]]:
|
|
"""
|
|
Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[4, 5, 6], [7, 8, 9]])
|
|
print(t.size())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.size(dim=1))
|
|
```
|
|
"""
|
|
return self.shape if dim is None else self.shape[dim]
|
|
|
|
# ***** cast ops *****
|
|
|
|
def llvm_bf16_cast(self, dtype:DTypeLike):
|
|
# hack for devices that don't support bfloat16
|
|
assert self.dtype == dtypes.bfloat16
|
|
return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
|
|
|
|
def cast(self, dtype:DTypeLike) -> Tensor:
|
|
"""
|
|
Casts `self` to the given `dtype`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = t.cast(dtypes.int32)
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
"""
|
|
return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
|
|
|
|
def bitcast(self, dtype:DTypeLike) -> Tensor:
|
|
"""
|
|
Bitcasts `self` to the given `dtype` of the same itemsize.
|
|
|
|
`self` must not require a gradient.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = t.bitcast(dtypes.uint32)
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
"""
|
|
if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
|
|
dt = to_dtype(dtype)
|
|
if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and (ns:=dt.itemsize) != (os:=self.dtype.itemsize):
|
|
if (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
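# when the itemsize changes, reinterpret through unsigned ints: widen by combining ns//os neighbours with shifts, narrow by splitting each element into os//ns pieces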
|
|
new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
|
|
tmp = self.bitcast(old_uint)
|
|
if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
|
|
return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
|
|
return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
|
|
|
|
def float(self) -> Tensor:
|
|
"""
|
|
Convenience method to cast `self` to a `float32` Tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = t.float()
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
"""
|
|
return self.cast(dtypes.float32)
|
|
|
|
def half(self) -> Tensor:
|
|
"""
|
|
Convenience method to cast `self` to a `float16` Tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = t.half()
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
"""
|
|
return self.cast(dtypes.float16)
|
|
|
|
def int(self) -> Tensor:
|
|
"""
|
|
Convenience method to cast `self` to a `int32` Tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = t.int()
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
"""
|
|
return self.cast(dtypes.int32)
|
|
|
|
def bool(self) -> Tensor:
|
|
"""
|
|
Convenience method to cast `self` to a `bool` Tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([-1, 0, 1])
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = t.bool()
|
|
print(t.dtype, t.numpy())
|
|
```
|
|
"""
|
|
return self.cast(dtypes.bool)
|
|
|
|
# *** image Tensor function replacements ***
|
|
|
|
def image_dot(self, w:Tensor, acc_dtype=None):
|
|
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
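# e.g. a (64,128) @ (128,32) matmul runs as (1,128,64,1).conv2d(w) with w reshaped to (32,128,1,1)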
|
|
n1, n2 = len(self.shape), len(w.shape)
|
|
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
|
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
|
|
bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
|
|
out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
|
|
|
|
# NOTE: with NHWC we can remove the transposes
|
|
# bs x groups*cin x H x W
|
|
cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
|
|
# groups*cout x cin x H, W
|
|
cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
|
|
return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
|
|
|
|
def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
|
|
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
|
|
|
|
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
|
|
x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
|
|
|
|
# hack for non multiples of 4 on cin
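# e.g. cin=3 gets one zero channel appended to both x and w so channels pack into groups of 4 for the image dtype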
|
|
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
|
|
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
|
|
added_input_channels = 4 - (cin % 4)
|
|
w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
|
|
x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
|
|
cin = cin + added_input_channels
|
|
x = x.reshape(bs, groups*cin, iy, ix)
|
|
|
|
# hack for non multiples of 4 on rcout
|
|
added_output_channels = 0
|
|
if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
|
|
added_output_channels = 4 - (rcout % 4)
|
|
rcout += added_output_channels
|
|
cout = groups * rcout
|
|
w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
|
|
|
|
# packed (note: flipping bs and iy would make the auto-padding work)
|
|
x = x.permute(0,2,3,1)
|
|
cin_last = iy == 1 and ix == 1
|
|
if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
|
|
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
|
|
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
|
|
|
|
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
|
|
if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
|
|
x, w = x.contiguous(), w.contiguous()
|
|
|
|
# expand out
|
|
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
|
|
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
|
|
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
|
|
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
|
|
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
|
|
|
|
# prepare input
|
|
x = x.permute(0,3,4,5,1,2).pad2d(self._padding2d(padding, 2))._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
|
|
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
|
|
|
|
# prepare weights
|
|
w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
|
|
|
|
# the conv!
|
|
ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
|
|
|
|
# undo hack for non multiples of 4 on C.rcout
|
|
if added_output_channels != 0:
|
|
ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
|
|
cout = groups * (rcout - added_output_channels)
|
|
|
|
# NCHW output
|
|
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
|
|
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
|
|
|
|
if IMAGE:
|
|
# if IMAGE>0 we install these replacement functions in Tensor (hack!)
|
|
setattr(Tensor, "conv2d", Tensor.image_conv2d)
|
|
setattr(Tensor, "dot", Tensor.image_dot)
|
|
|
|
def _metadata_wrapper(fn):
|
|
def _wrapper(*args, **kwargs):
|
|
if _METADATA.get() is not None: return fn(*args, **kwargs)
|
|
|
|
if TRACEMETA >= 2:
|
|
caller_frame = sys._getframe(frame := 1)
|
|
caller_module = caller_frame.f_globals.get("__name__", None)
|
|
caller_func = caller_frame.f_code.co_name
|
|
if caller_module is None: return fn(*args, **kwargs)
|
|
|
|
# if it's called from nn, we want to step up frames until we are out of nn
|
|
while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
|
|
caller_frame = sys._getframe(frame := frame + 1)
|
|
caller_module = caller_frame.f_globals.get("__name__", None)
|
|
if caller_module is None: return fn(*args, **kwargs)
|
|
|
|
# if it's called from a lambda in tinygrad, we want to look two more frames up
|
|
if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
|
|
caller_module = caller_frame.f_globals.get("__name__", None)
|
|
if caller_module is None: return fn(*args, **kwargs)
|
|
caller_func = caller_frame.f_code.co_name
|
|
caller_lineno = caller_frame.f_lineno
|
|
|
|
caller = f"{caller_module}:{caller_lineno}::{caller_func}"
|
|
else: caller = ""
|
|
|
|
token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
|
|
ret = fn(*args, **kwargs)
|
|
_METADATA.reset(token)
|
|
return ret
|
|
return _wrapper
|
|
|
|
if TRACEMETA >= 1:
|
|
for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
|
|
if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
|
|
setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))
|