# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
from collections import defaultdict
import numpy as np

from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps
from tinygrad.device import Buffer, Device
from tinygrad.shape.symbolic import sint
from tinygrad.realize import run_schedule, create_schedule

# **** start with two base classes, Tensor and Function ****

class Function:
  def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor):
    self.device = device
    self.needs_input_grad = [t.requires_grad for t in tensors]
    self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
    if self.requires_grad: self.parents = tensors

  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
  def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")

  @classmethod
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
    ctx = fxn(x[0].device, *x)
    ret = Tensor.__new__(Tensor)
    ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
    ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
    return ret
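# Rough sketch of how this is used (illustrative): the ops in tinygrad.mlops subclass Function,
# and the Tensor methods below call e.g. mlops.Relu.apply(self). apply() runs forward() on the
# LazyBuffers and, when gradients are required, stores the ctx on the result so backward() below
# can walk the graph through _ctx.parents.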

import tinygrad.mlops as mlops

def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
  if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
  return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)

def _fromcpu(x: np.ndarray) -> LazyBuffer:
  return LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "EXT",
    _buf=Buffer("EXT", 0, dtypes.from_np(x.dtype), (memoryview(bytearray()), None)) if x.size == 0 else \
    Buffer("EXT", prod(x.shape), dtypes.from_np(x.dtype), (flat_mv(np.require(x, requirements='C').data), x)))

def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
  return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
    for k in range(len(mat[0]))] for dim in range(dims)]

# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
  # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
  # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
  t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
  # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
  matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
  # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
  ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
  return ret
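# Illustrative: for dims=1 this builds output row j as sum_k mat[j][k] * t[k] along the leading
# axis (i.e. mat @ t), but with the matrix entries baked in as broadcasted constant Tensors so the
# lazy scheduler can fold them; dims=2 does the same with kron(mat, mat).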

class Tensor:
  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
  __deletable__ = ('_ctx',)
  training: ClassVar[bool] = False
  class train(ContextDecorator):
    def __init__(self, mode:bool = True): self.mode = mode
    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev

  no_grad: ClassVar[bool] = False
  class inference_mode(ContextDecorator):
    def __init__(self, mode:bool = True): self.mode = mode
    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
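  # Usage sketch (illustrative): both helpers are ContextDecorators, so they can be used either as
  # a context manager or as a decorator:
  #   with Tensor.train(): ...            # Tensor.training is True inside the block
  #   with Tensor.inference_mode(): ...   # Tensor.no_grad is True, so apply() records no _ctx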
  def __init__(self, data:Union[None, Scalar, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer],
               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
    assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
    device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
    # tensors have gradients, buffers do not
    self.grad: Optional[Tensor] = None

    # NOTE: this can be in three states. False and None: no gradient, True: gradient
    # None (the default) will be updated to True if it's put in an optimizer
    self.requires_grad: Optional[bool] = requires_grad

    # internal variables used for autograd graph construction
    self._ctx: Optional[Function] = None
    if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
    elif isinstance(data, get_args(Scalar)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
    elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
    elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
    elif isinstance(data, list):
      if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
      elif d and all_int(d): dtype = dtype or dtypes.default_int
      else: dtype = dtype or dtypes.default_float
      if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
      else: data = _fromcpu(np.array(data, dtype.np))
    elif isinstance(data, np.ndarray):
      if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
      else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)

    # data is a LazyBuffer, but it might be on the wrong device
    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
    if isinstance(device, tuple):
      # TODO: what if it's a MultiLazyBuffer on other devices?
      self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data
    else:
      self.lazydata = data if data.device == device else data.copy_to_device(device)

  def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"

  # Python has a non moving GC, so this should be okay
  def __hash__(self): return id(self)

  def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")

  @property
  def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device

  @property
  def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape

  @property
  def dtype(self) -> DType: return self.lazydata.dtype

  # ***** data handlers ****

  @staticmethod
  def corealize(lst:Iterable[Tensor]):
    run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))

  def realize(self) -> Tensor:
    Tensor.corealize([self])
    return self

  def replace(self, x:Tensor) -> Tensor:
    # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
    assert not x.requires_grad and getattr(self, '_ctx', None) is None
    assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
    self.lazydata = x.lazydata
    return self

  def assign(self, x) -> Tensor:
    # TODO: this is a hack for writing to DISK. remove with working assign
    if isinstance(self.device, str) and self.device.startswith("DISK"):
      if x.__class__ is not Tensor: x = Tensor(x, device="EXT", dtype=self.dtype)
      self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
      return self
    if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
    if self.lazydata is x.lazydata: return self # a self assign is a NOOP
    # NOTE: we allow cross device assign
    assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
    assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
    assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
    assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
    assert not x.requires_grad # self requires_grad is okay?
    if not self.lazydata.is_realized(): return self.replace(x)
    self.lazydata = self.lazydata.assign(x.lazydata)
    return self
  def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)

  def _data(self) -> memoryview:
    if 0 in self.shape: return memoryview(bytearray(0))
    t = self if isinstance(self.device, str) else self.to(self.device[0]) # deal with multitensor
    return cast(Buffer, t.cast(t.dtype.scalar()).contiguous().realize().lazydata.base.realized).as_buffer()

  def data(self) -> memoryview:
    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
    return self._data().cast(self.dtype.fmt, self.shape if len(self.shape) else (1,))
  def item(self) -> Scalar:
    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
    assert self.numel() == 1, "must have one element for item"
    return self._data().cast(self.dtype.fmt)[0]
  def numpy(self) -> np.ndarray:
    if self.dtype == dtypes.bfloat16: return self.float().numpy()
    assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
    return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
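  # Usage sketch (illustrative) for getting values out of a Tensor:
  #   Tensor([[1, 2], [3, 4]]).numpy()    # -> 2x2 np.ndarray in the default int dtype
  #   Tensor(3.5).item()                  # -> 3.5, only valid when numel() == 1
  #   Tensor([1, 2]).data()               # -> memoryview cast to the dtype's struct format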

  def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
    device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
    if device is None or device == self.device: return self
    if not isinstance(device, str): return self.shard(device)
    ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
    if self.grad is not None: ret.grad = self.grad.to(device)
    if hasattr(self, '_ctx'): ret._ctx = self._ctx
    return ret

  def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
    real = self.to(device)
    # TODO: is this assign?
    if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
    self.lazydata = real.lazydata

  def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
    assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
    canonical_devices = tuple(Device.canonicalize(x) for x in devices)
    if axis is not None and axis < 0: axis += len(self.shape)
    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)

  def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
    self.lazydata = self.shard(devices, axis).lazydata
    return self

  # ***** creation llop entrypoint *****

  @staticmethod
  def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
    if isinstance(device, tuple):
      return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
        for d in device], None), device, dtype, **kwargs)
    return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)

  @staticmethod
  def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)

  _seed: int = int(time.time())
  _rng_counter: Optional[Tensor] = None
  @staticmethod
  def manual_seed(seed=0): Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)

  @staticmethod
  def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
    if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
    if not THREEFRY.value:
      if dtype == dtypes.bfloat16:
        return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
      return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)

    # threefry
    if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
    counts = (Tensor.arange(num, device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize().pad(((0,num%2),))
    Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()

    rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
    ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]

    x = [(c := counts.chunk(2))[0] + ks[-1], c[1] + ks[0]]
    for i in range(5):
      for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] * (2 ** r)) + (x[1].div(2 ** (32 - r), upcast=False)))
      x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
    out = x[0].cat(x[1])[:num].div(2 ** 8, upcast=False).cast(dtypes.float32).div(2 ** 24)
    out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
    out.requires_grad = kwargs.get("requires_grad")
    return out.contiguous()
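  # Usage sketch (illustrative): seed once, then draw; with THREEFRY=1 the stream is a pure
  # function of (seed, counter), so the same seed and call order reproduce the same values:
  #   Tensor.manual_seed(1337)
  #   a = Tensor.rand(2, 3)               # uniform in [0, 1), dtypes.default_float
  #   b = Tensor.rand(2, 3, dtype=dtypes.float16)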

  # ***** creation helper functions *****

  @staticmethod
  def full(shape:Tuple[sint, ...], fill_value:Scalar, **kwargs):
    return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)

  @staticmethod
  def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs)

  @staticmethod
  def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1.0, **kwargs)

  @staticmethod
  def arange(start, stop=None, step=1, **kwargs):
    if stop is None: stop, start = start, 0
    assert all(isinstance(s, (int, float)) for s in (start, stop, step)), "symbolic arange not supported"
    dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
    return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
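  # Illustrative: Tensor.arange(5) -> [0, 1, 2, 3, 4] in default_int, while
  # Tensor.arange(0, 1, 0.25) -> [0.0, 0.25, 0.5, 0.75] in default_float; it is built above as a
  # cumsum of `step` shifted by (start - step).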

  @staticmethod
  def eye(dim:int, **kwargs):
    return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)

  def full_like(self, fill_value:Scalar, **kwargs):
    return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
  def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
  def ones_like(self, **kwargs): return self.full_like(1, **kwargs)

  # ***** rng hlops *****

  @staticmethod
  def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
    # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
    src = Tensor.rand((2, *argfix(*shape)), **kwargs)
    return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)

  @staticmethod
  def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32, **kwargs)

  @staticmethod
  def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean

  @staticmethod
  def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
    dtype = kwargs.pop("dtype", dtypes.default_float)
    return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low

  @staticmethod
  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)

  # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
  @staticmethod
  def glorot_uniform(*shape, **kwargs) -> Tensor:
    return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)

  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
  @staticmethod
  def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
    return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
  @staticmethod
  def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
    std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
    return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

  def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
    assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
    assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
    weight = self.unsqueeze(0) if self.ndim == 1 else self
    cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
    indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
    return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)

  # ***** toposort and backward pass *****

  def deepwalk(self):
    def _deepwalk(node, visited):
      visited.add(node)
      if getattr(node, "_ctx", None):
        for i in node._ctx.parents:
          if i not in visited: yield from _deepwalk(i, visited)
        yield node
    return list(_deepwalk(self, set()))

  def backward(self) -> Tensor:
    assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape}"

    # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
    # this is "implicit gradient creation"
    self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

    for t0 in reversed(self.deepwalk()):
      if t0.grad is None: raise RuntimeError("tensor has no grad")
      grads = t0._ctx.backward(t0.grad.lazydata)
      grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
        for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
      for t, g in zip(t0._ctx.parents, grads):
        if g is not None and t.requires_grad:
          assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
          t.grad = g if t.grad is None else (t.grad + g)
      del t0._ctx
    return self
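  # Usage sketch (illustrative): backward starts from a scalar and accumulates into .grad:
  #   x = Tensor([2.0, 3.0], requires_grad=True)
  #   (x * x).sum().backward()
  #   x.grad.numpy()                      # -> [4., 6.], i.e. d/dx sum(x**2) = 2x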

  # ***** movement mlops *****

  def reshape(self, shape, *args) -> Tensor:
    new_shape = argfix(shape, *args)
    new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
    return mlops.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
  def expand(self, shape, *args) -> Tensor:
    new_shape = tuple([x if x != -1 and x is not None else s for s,x in zip(self.shape, argfix(shape, *args))])
    return mlops.Expand.apply(self, shape=new_shape) if new_shape != self.shape else self
  def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
  def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
    if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
    return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
  def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
    if all(x is None or x == (0,0) for x in arg): return self
    ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
    return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)

  # ***** movement hlops *****

  # Supported Indexing Implementations:
  # 1. Int indexing (no copy)
  # - for all dims where there's int, shrink -> reshape
  # - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
  # - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
  # - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
  # 2. Slice indexing (no copy)
  # - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
  # - first shrink the Tensor to X.shrink(((start, end),))
  # - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
  # - flip where dim value is negative
  # - pad 0's on dims such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
  # - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
  # - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
  # 3. None indexing (no copy)
  # - reshape (inject) a dim at the dim where there's None
  # 4. Tensor indexing (copy)
  # - use Tensor.arange == tensor_index to create a mask
  # - apply mask to self by mask * self for dims where index is a tensor
  # - (mask * self).sum(dim) to reduce to correct shape
  # Tiny Things:
  # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
  # - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
  # - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
  # 2. Bool indexing is not supported
  # 3. Out of bounds Tensor indexing results in 0
  # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are OOB
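  # - e.g. (illustrative): X = Tensor.arange(10); X[2:8:2] first shrinks to (2, 8), then the stride
  #   machinery above reduces it to shape (3,) -> [2, 4, 6]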
  def __getitem__(self, indices) -> Tensor:
    # 1. indices normalization and validation
    # treat internal tuples and lists as Tensors and standardize indices to list type
    if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
    elif isinstance(indices, (tuple, list)):
      indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
    else: indices = [indices]

    # turn scalar Tensors into const val for int indexing if possible
    indices = [self._to_const_val(i) if isinstance(i, Tensor) else i for i in indices]
    # move Tensor indices to the same device as self
    indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]

    # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
    ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
    fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
    num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
    indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_indices)

    # use Dict[type, List[dimension]] to track elements in indices
    type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)

    # record None for dimension injection later and filter None and record rest of indices
    type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
    indices_filtered = [v for v in indices if v is not None]
    for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)

    for index_type in type_dim:
      if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
    if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
    if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")

    # 2. basic indexing, uses only movement ops (no copy)
    # currently indices_filtered: Tuple[Union[slice, int, Tensor], ...]
    # turn indices in indices_filtered to Tuple[shrink_arg, strides]
    for dim in type_dim[int]:
      if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
        raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
      indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
    for dim in type_dim[slice]:
      if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
      s, e, st = index.indices(self.shape[dim])
      indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
    # record tensors and skip all Tensor dims for basic indexing
    tensor_index: List[Tensor] = []
    for dim in type_dim[Tensor]:
      tensor_index.append(index := indices_filtered[dim])
      if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
      indices_filtered[dim] = ((0, self.shape[dim]), 1)

    new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
    ret = self.shrink(new_slice).flip(tuple(i for i, s in enumerate(strides) if s < 0))
    if any(abs(s) != 1 for s in strides):
      strides = tuple(abs(s) for s in strides)
      ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
      ret = ret.reshape(tuple(flatten((sh // s, s) for s, sh in zip(strides, ret.shape))))
      ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])

    # inject 1 for dim where it's None and collapse dim for int
    new_shape = list(ret.shape)
    for dim in type_dim[None]: new_shape.insert(dim, 1)
    for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)

    ret = ret.reshape(new_shape)
    assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"

    # 3. advanced indexing (copy)
    if type_dim[Tensor]:
      # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
      def calc_dim(tensor_dim:int) -> int:
        return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d)

      # track tensor_dim and tensor_index using a dict
      # calc_dim to get dim and use that to normalize the negative tensor indices
      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor],tensor_index)}

      # compute sum_dim, arange, and idx
      max_idx_dim, first_dim, last_dim = max(i.ndim for i in idx.values()), min(idx.keys()), max(idx.keys())
      sum_dim = tuple(d if n==0 else d+max_idx_dim-n for n,d in enumerate(idx.keys()))
      arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(ret.shape[d], *[1]*(ret.ndim+max_idx_dim-n-sd-1)) \
        for n,(sd,d) in enumerate(zip(sum_dim, idx.keys()))]
      reshaped_idx = [i.reshape(i.shape + (1,)*(ret.ndim - first_dim - (n or 1))) for n,i in enumerate(idx.values())]
      ret = ret.reshape(ret.shape[:first_dim+1] + (1,)*max_idx_dim + ret.shape[first_dim+1:])

      # iteratively eq -> mul -> sum fancy index
      try:
        for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
      except AssertionError as exc: raise IndexError("cannot broadcast indices") from exc

      # special permute case
      if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
        ret_dims = list(range(ret.ndim))
        ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
    return ret

  def __setitem__(self,indices,v): return self.__getitem__(indices).assign(v)

  # NOTE: using slice is discouraged and things should migrate to pad and shrink
  def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
    arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
    padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
    return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))

  def gather(self:Tensor, idx:Tensor, dim:int) -> Tensor:
    assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
    assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
    dim = self._resolve_dim(dim)
    idx = idx.to(self.device).transpose(ax1=dim, ax2=0).unsqueeze(-1)
    permarg = list(range(self.ndim))
    permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
    return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(
      tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)

  def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
    dim = self._resolve_dim(dim)
    assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
    catargs = [self, *args]
    cat_dims = [s.shape[dim] for s in catargs]
    cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
    slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
    for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
    return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])

  @staticmethod
  def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor:
    unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors]
    # checks for shapes and number of dimensions delegated to cat
    return unsqueezed_tensors[0].cat(*unsqueezed_tensors[1:], dim=dim)

  def repeat(self, repeats:Sequence[int]) -> Tensor:
    base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
    new_shape = [x for b in base_shape for x in [1, b]]
    expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
    final_shape = [r*s for r,s in zip(repeats, base_shape)]
    return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)

  def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
    if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
      raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
    return dim + self.ndim+outer if dim < 0 else dim

  def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
    dim = self._resolve_dim(dim)
    if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
    assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
    return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])

  def chunk(self, num:int, dim:int=0) -> List[Tensor]:
    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
    assert num > 0, f"expect num to be greater than 0, got: {num}"
    dim = self._resolve_dim(dim)
    return list(self.split(math.ceil(self.shape[dim]/num) if self.shape[dim] else [0]*num, dim=dim))

  def squeeze(self, dim:Optional[int]=None) -> Tensor:
    if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
    dim = self._resolve_dim(dim)
    return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])

  def unsqueeze(self, dim:int) -> Tensor:
    dim = self._resolve_dim(dim, outer=True)
    return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

  # (padding_left, padding_right, padding_top, padding_bottom)
  def pad2d(self, padding:Sequence[int], value:float=0) -> Tensor:
    slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
    return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)

  @property
  def T(self) -> Tensor: return self.transpose()
  def transpose(self, ax1=1, ax2=0) -> Tensor:
    order = list(range(self.ndim))
    order[ax1], order[ax2] = order[ax2], order[ax1]
    return self.permute(order)
  def flatten(self, start_dim=0, end_dim=-1):
    start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
  def unflatten(self, dim:int, sizes:Tuple[int,...]):
    dim = self._resolve_dim(dim)
    return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])

  # ***** reduce ops *****

  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
    axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
    axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_)
    shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
    ret = fxn.apply(self, axis=axis_)
    return ret if keepdim else ret.reshape(shape=shape)

  def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
    if acc_dtype is None: acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
      least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
      least_upper_dtype(self.dtype, dtypes.float)
    # cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
    output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
    return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype)

  def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
  def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))

  def mean(self, axis=None, keepdim=False):
    assert all_int(self.shape), "does not support symbolic shape"
    out = self.sum(axis=axis, keepdim=keepdim)
    return out.div(prod(self.shape) / prod(out.shape)) if 0 not in out.shape else out
  def var(self, axis=None, keepdim=False, correction=1):
    assert all_int(self.shape), "does not support symbolic shape"
    square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
    return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
  def std(self, axis=None, keepdim=False, correction=1): return self.var(axis, keepdim, correction).sqrt()

  def _softmax(self, axis):
    if len(self.shape) == 0:
      assert axis in [-1, 0], f"{axis=} out of range of [-1, 0]"
      axis = None
    m = self - self.max(axis=axis, keepdim=True)
    e = m.exp()
    return m, e, e.sum(axis=axis, keepdim=True)

  def softmax(self, axis=-1):
    _, e, ss = self._softmax(axis)
    return e.div(ss)

  def log_softmax(self, axis=-1):
    m, _, ss = self._softmax(axis)
    return m - ss.log()

  def argmax(self, axis=None, keepdim=False):
    # NOTE: return the first index if there are multiple occurrences of the maximum values
    if axis is None:
      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
      return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
    axis = self._resolve_dim(axis)
    m = self == self.max(axis=axis, keepdim=True)
    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
    return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
  def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)

  @staticmethod
  def einsum(formula:str, *raw_xs) -> Tensor:
    xs:Tuple[Tensor] = argfix(*raw_xs)
    formula = formula.replace(" ", "")
    inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
    inputs = [x for x in cast(str,inputs_str).split(',')]
    assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"

    # map the value of each letter in the formula
    letter_val = sorted(merge_dicts([{letter:dim for letter, dim in zip(letters, tensor.shape)} for letters, tensor in zip(inputs, xs)]).items())

    xs_:List[Tensor] = []
    lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
    for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
      # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
      xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))

    rhs_order, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
    # sum over all axes that's not in the output, then permute to the output order
    return functools.reduce(lambda a,b:a*b, xs_) \
      .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)
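  # Usage sketch (illustrative): einsum covers permutes, contractions, and reductions, e.g.
  #   Tensor.einsum("ij->ji", a)          # transpose
  #   Tensor.einsum("ij,jk->ik", a, b)    # matrix multiply
  #   Tensor.einsum("ij->", a)            # sum over every element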

  # ***** processing ops *****

  def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
    assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
    assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
    s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
    assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
    noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
    if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
      o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
      # repeats such that we don't need padding
      xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
      # slice by dilation
      xup = xup.slice(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)]).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
      # handle stride
      xup = xup.slice(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
      xup = xup.slice(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
      # permute to move reduce to the end
      return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
    # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
    o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
    xup = self.slice(noop_ + [(0,o*s) for o,s in zip(o_, s_)])
    xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
    xup = xup.slice(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
    return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])

  # NOTE: these work for more than 2D
  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(
    make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(
    make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
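  # Illustrative shapes: with the default stride equal to kernel_size,
  #   Tensor.rand(1, 1, 4, 4).max_pool2d(kernel_size=(2, 2)).shape == (1, 1, 2, 2)
  # and avg_pool2d over (2, 2) likewise halves each of the last two dims.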

  def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
    HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
    x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing)
    stride = make_pair(stride, len(HW))
    if any(s>1 for s in stride):
      x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
      x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
      x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
      x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(
      zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
    return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)

  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
    (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
    assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
    if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
    padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) # noqa: E501

    # conv2d is a pooling op (with padding)
    x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
    rcout, oyx = cout//groups, x.shape[2:-len(HW)]
    if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
      # normal conv
      x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501

      # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
      return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))

    HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
    winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
    winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
    winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time

    # todo: stride == dilation
    # use padding to round up to 4x4 output tiles
    # (bs, cin_, tyx, HWI)
    d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
    # move HW to the front: # (HWI, bs, cin_, tyx)
    d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
    tyx = d.shape[-len(HWI):] # dim of tiling

    g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front

    # compute 6x6 winograd tiles: GgGt, BtdB
    # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
    gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
    # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
    dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)

    # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
    ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))

    # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
    ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
    # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
    ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx]))

    return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
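  # Usage sketch (illustrative): weight is laid out as (cout, cin//groups, *HW), so with
  #   x = Tensor.rand(1, 3, 8, 8); w = Tensor.rand(4, 3, 3, 3)
  #   x.conv2d(w).shape == (1, 4, 6, 6)   # stride 1, no padding: 8 - 3 + 1 = 6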

  def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
    n1, n2 = len(self.shape), len(w.shape)
    assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
    assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
    x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
    w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
    return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))

  def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
    return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
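  # Illustrative: a (2, 3) tensor dotted with a (3, 4) tensor gives (2, 4); 1-D operands follow the
  # usual matmul rules, so Tensor.rand(3).dot(Tensor.rand(3, 4)).shape == (4,), and a @ b is a.matmul(b).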

  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
    pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0)
    return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1)
  def cumsum(self, axis:int=0) -> Tensor:
    # TODO: someday the optimizer will find this on its own
    # for now this is a two stage cumsum
    SPLIT = 256
    if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
    ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
    ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
    base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
    base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
    def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
    return fix(ret) + fix(base_add)

  @staticmethod
  def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
    assert all_int((r,c)), "does not support symbolic"
    if r == 0: return Tensor.zeros((r, c), **kwargs)
    return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
  def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
  def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)

  # ***** mlops (unary) *****

  def logical_not(self): return mlops.Eq.apply(*self._broadcasted(False))
  def neg(self): return mlops.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
  def contiguous(self): return mlops.Contiguous.apply(self)
  def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
  def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
  def log2(self): return self.log()/math.log(2)
  def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype)))
  def exp2(self): return mlops.Exp.apply(self*math.log(2))
  def relu(self): return mlops.Relu.apply(self)
  def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
  def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype)))
  def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
  def rsqrt(self): return self.reciprocal().sqrt()
  def cos(self): return ((math.pi/2)-self).sin()
  def tan(self): return self.sin() / self.cos()

  # ***** math functions (unary) *****

  def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).cast(self.dtype)
  def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b)
  def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b)
  def round(self: Tensor) -> Tensor:
    return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())

  def square(self): return self*self
  def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
  def abs(self): return self.relu() + (-self).relu()
  def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
  def reciprocal(self): return mlops.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))

  # ***** activation functions (unary) *****

  def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
  def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
  def swish(self): return self * self.sigmoid()
  def silu(self): return self.swish() # The SiLU function is also known as the swish function.
  def relu6(self): return self.relu() - (self-6).relu()
  def hardswish(self): return self * (self+3).relu6() * (1/6)
  def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
  def sinh(self): return (self.exp() - self.neg().exp()) / 2
  def cosh(self): return (self.exp() + self.neg().exp()) / 2
  def atanh(self): return ((1 + self)/(1 - self)).log() / 2
  def asinh(self): return (self + (self.square() + 1).sqrt()).log()
  def acosh(self): return (self + (self.square() - 1).sqrt()).log()
  def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
  def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
  def quick_gelu(self): return self * (self * 1.702).sigmoid()
  def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
  def mish(self): return self * self.softplus().tanh()
  def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
  def softsign(self): return self / (1 + self.abs())

  # ***** broadcasted elementwise mlops *****

  def _broadcasted(self, y:Union[Tensor, Scalar], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
    x: Tensor = self
    if not isinstance(y, Tensor):
      # make y a Tensor
      assert isinstance(y, (float, int, bool)), f"{type(y)=}, {y=}"
      if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
      else: y_dtype = dtypes.from_py(y)
      y = Tensor(cast_scalar(y, y_dtype), self.device, y_dtype, requires_grad=False)

    if match_dtype:
      output_dtype = least_upper_dtype(x.dtype, y.dtype)
      x, y = x.cast(output_dtype), y.cast(output_dtype)

    if reverse: x, y = y, x

    # left pad shape with 1s
    if len(y.shape) < len(x.shape): y = y.reshape((1,) * (len(x.shape) - len(y.shape)) + y.shape)
    elif len(x.shape) < len(y.shape): x = x.reshape((1,) * (len(y.shape) - len(x.shape)) + x.shape)

    broadcasted_shape = tuple(0 if xi==0 or yi==0 else max(xi, yi) for xi, yi in zip(x.shape, y.shape))
    return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
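  # Illustrative: numpy-style broadcasting as implemented above; shapes are right-aligned, padded
  # with 1s on the left, then expanded, so (3, 1) against (4,) broadcasts both sides to (3, 4),
  # and a python scalar is first wrapped in a 0-d Tensor with a matching dtype.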
|
|
|
|
def _to_const_val(self, x:Union[Tensor, Scalar]) -> Union[Tensor, Scalar]:
|
|
# TODO: update with multi
|
|
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
|
|
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
|
|
|
|
def add(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
|
|
x = self._to_const_val(x)
|
|
return mlops.Add.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else self
|
|
def sub(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
|
|
x = self._to_const_val(x)
|
|
return mlops.Sub.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else (-self if reverse else self)
|
|
def mul(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
|
|
x = self._to_const_val(x)
|
|
if not isinstance(x, Tensor) and x == 0.0: return mlops.Zero.apply(self)
|
|
if not isinstance(x, Tensor) and x == -1.0: return -self
|
|
return mlops.Mul.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x != 1.0 else self
|
|
def div(self, x:Union[Tensor, Scalar], reverse=False, upcast=True) -> Tensor:
|
|
x = self._to_const_val(x)
|
|
if not isinstance(x, Tensor) and not reverse and x != 0 and upcast: return self.mul(1/x)
|
|
if (isinstance(x, Tensor) and dtypes.is_float(x.dtype)) or not upcast: return mlops.Div.apply(*self._broadcasted(x, reverse))
|
|
return mlops.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse))
|
|
def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
|
|
|
|

  def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
    x = self._to_const_val(x)
    if not isinstance(x, Tensor) and not reverse:
      # simple pow identities
      if x < 0: return self.reciprocal().pow(-x)
      if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1)
      if x == 0.5: return self.sqrt()
    if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
    ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
    # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
    sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
    # we only need to correct the sign if the base is negative
    base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
    # we need 0 to be positive so we need to correct base_sign when the base is 0
    base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
    # inject nan if the base is negative and the power is not an integer
    to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else \
      int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign
    inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
    return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
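
  # Illustrative sketch (editorial, not from the original source) of the general exp/log path above:
  #   Tensor([-2.0]) ** 5    -> exp(5 * log(2)) = 32, then negated since the base is negative and the power is odd
  #   Tensor([-2.0]) ** 1.5  -> nan is injected since a negative base with a non-integer power has no real result
  # small integer and 0.5 powers never reach this path; they are handled by the identities at the top of pow.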

  def maximum(self, x:Union[Tensor, Scalar]) -> Tensor:
    return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
  def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x))

  def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]):
    if isinstance(input_, Tensor): input_, other = input_._broadcasted(other)
    elif isinstance(other, Tensor): other, input_ = other._broadcasted(input_)
    x_,y = self._broadcasted(input_, match_dtype=False)
    x,z = x_._broadcasted(other, match_dtype=False)
    return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))
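
  # Illustrative sketch (editorial, not from the original source): where is an elementwise select on a boolean
  # mask, and maximum/minimum are built from the same primitive, e.g.
  #   Tensor([True, False, True]).where(1, -1)   # expected [1, -1, 1]
  #   Tensor([1.0, 5.0]).maximum(3)              # expected [3.0, 5.0]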

  # ***** op wrappers (wasted lines to make the typechecker happy) *****

  def __neg__(self) -> Tensor: return self.neg()

  def __add__(self, x) -> Tensor: return self.add(x)
  def __sub__(self, x) -> Tensor: return self.sub(x)
  def __mul__(self, x) -> Tensor: return self.mul(x)
  def __pow__(self, x) -> Tensor: return self.pow(x)
  def __truediv__(self, x) -> Tensor: return self.div(x)
  def __matmul__(self, x) -> Tensor: return self.matmul(x)
  def __xor__(self, x) -> Tensor: return self.xor(x)

  def __radd__(self, x) -> Tensor: return self.add(x, True)
  def __rsub__(self, x) -> Tensor: return self.sub(x, True)
  def __rmul__(self, x) -> Tensor: return self.mul(x, True)
  def __rpow__(self, x) -> Tensor: return self.pow(x, True)
  def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
  def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
  def __rxor__(self, x) -> Tensor: return self.xor(x, True)

  def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
  def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
  def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
  def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
  def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
  def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
  def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))

  def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
  def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
  def __ge__(self, x) -> Tensor: return (self<x).logical_not()
  def __le__(self, x) -> Tensor: return (self>x).logical_not()
  def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True))  # type: ignore[override]
  def __ne__(self, x) -> Tensor: return (self==x).logical_not()  # type: ignore[override]
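
  # Editorial note (not from the original source): the comparison operators return boolean tensors, and the
  # in-place operators are sugar for assign, e.g.
  #   a = Tensor([1.0, 2.0]); a += 1   # equivalent to a.assign(a.add(1))
  #   Tensor([1, 2]) < 2               # expected dtype bool, values [True, False]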

  # ***** functional nn ops *****

  def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
    x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
    return x.add(bias) if bias is not None else x

  def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
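
  # Illustrative sketch (editorial, not from the original source): a tiny MLP expressed with the two helpers
  # above; w1 and w2 are hypothetical weight tensors.
  #   w1, w2 = Tensor.rand(784, 128), Tensor.rand(128, 10)
  #   out = x.sequential([lambda t: t.linear(w1).relu(), lambda t: t.linear(w2)])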

  def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
    y = (self - self.mean(axis, keepdim=True))
    return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
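
  # Illustrative note (editorial, not from the original source): layernorm normalizes over the last axis by
  # default, computing (x - mean) / sqrt(var + eps) with a biased variance estimate, e.g.
  #   Tensor([[1.0, 2.0, 3.0]]).layernorm()   # expected roughly [-1.22, 0.0, 1.22]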

  def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
    axis_ = argfix(axis)
    shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
    x = self - mean.reshape(shape)
    if weight is not None: x = x * weight.reshape(shape)
    ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
    return (ret + bias.reshape(shape)) if bias is not None else ret
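
  # Editorial note (not from the original source): the caller is expected to pass invstd = (var + eps).rsqrt();
  # batchnorm itself only applies the affine normalization along the chosen axis, e.g. (hypothetical call)
  #   x.batchnorm(weight, bias, running_mean, (running_var + 1e-5).rsqrt())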

  def dropout(self, p=0.5) -> Tensor:
    if not Tensor.training or p == 0: return self
    return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p))

  def one_hot(self, num_classes:int) -> Tensor:
    return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
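
  # Illustrative sketch (editorial, not from the original source):
  #   Tensor([0, 2]).one_hot(3)   # expected [[1, 0, 0], [0, 0, 1]]
  # dropout is inverted dropout: it is active only while Tensor.training is True, and kept activations are
  # scaled by 1/(1-p) so the expected value is unchanged.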

  def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
                                   dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
    # NOTE: it works if key, value have symbolic shape
    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
    if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
    if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
    qk = self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1])
    return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
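
  # Illustrative sketch (editorial, not from the original source): query/key/value of shape
  # (batch, heads, seq, head_dim) produce an output of the same shape; the causal path masks future positions.
  #   q = k = v = Tensor.rand(1, 8, 128, 64)                        # hypothetical inputs
  #   out = q.scaled_dot_product_attention(k, v, is_causal=True)    # expected shape (1, 8, 128, 64)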

  def binary_crossentropy(self, y:Tensor) -> Tensor:
    return (-y*self.log() - (1-y)*(1-self).log()).mean()

  def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
    return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
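
  # Editorial note (not from the original source): binary_crossentropy expects probabilities in (0, 1), e.g.
  # after a sigmoid, while the _logits variant applies the numerically stable form
  # max(x, 0) - x*y + log(1 + exp(-|x|)) to raw scores, so logits.sigmoid().binary_crossentropy(y) and
  # logits.binary_crossentropy_logits(y) should agree up to floating point error.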

  def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
    assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
    # NOTE: self is a logits input
    log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
    y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
    smoothing = -1 * label_smoothing * (log_probs.mean(-1) * loss_mask).sum() / loss_mask.sum()
    return (1 - label_smoothing) * (log_probs * y).sum() / loss_mask.sum() + smoothing
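
  # Illustrative sketch (editorial, not from the original source): self holds logits of shape (N, C) and Y holds
  # integer class indices of shape (N,); entries equal to ignore_index are excluded from the mean.
  #   logits = Tensor.rand(4, 10)            # hypothetical logits
  #   labels = Tensor([3, 1, -1, 7])         # the -1 entry is ignored with the default ignore_index
  #   loss = logits.sparse_categorical_crossentropy(labels)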

  # ***** cast ops *****

  def llvm_bf16_cast(self, dtype:DType):
    # hack for devices that don't support bfloat16
    assert self.dtype == dtypes.bfloat16
    return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
  def cast(self, dtype:DType) -> Tensor: return self if self.dtype == dtype else mlops.Cast.apply(self, dtype=dtype)
  def bitcast(self, dtype:DType) -> Tensor:
    assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
    return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
  def float(self) -> Tensor: return self.cast(dtypes.float32)
  def half(self) -> Tensor: return self.cast(dtypes.float16)
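
  # Illustrative note (editorial, not from the original source): cast converts values, while bitcast
  # reinterprets the raw bits of a same-sized dtype, e.g.
  #   Tensor([1.0]).cast(dtypes.int32)       # expected value 1
  #   Tensor([1.0]).bitcast(dtypes.uint32)   # expected value 0x3f800000, the IEEE-754 bits of 1.0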

  # ***** convenience stuff *****

  @property
  def ndim(self) -> int: return len(self.shape)
  def numel(self) -> sint: return prod(self.shape)
  def element_size(self) -> int: return self.dtype.itemsize
  def nbytes(self) -> int: return self.numel() * self.element_size()
  def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)

# register functions to move between devices
for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))

if IMAGE:
  # if IMAGE>0 we install these replacement functions in Tensor (hack!)
  from tinygrad.features.image import image_conv2d, image_dot
  setattr(Tensor, "conv2d", image_conv2d)
  setattr(Tensor, "dot", image_dot)
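
# Editorial note (not from the original source): the loop above registers a lowercase per-device helper on
# Tensor for each backend name in Device._devices, so (assuming a build that exposes the CLANG backend)
# t.clang() is expected to be shorthand for t.to("CLANG").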

# TODO: eventually remove this
def custom_random(out:Buffer):
  Tensor._seed += 1
  if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
  rng = np.random.default_rng(Tensor._seed)
  if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
  else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
  out.copyin(rng_np_buffer.data)
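
# Editorial note (not from the original source): custom_random fills a raw device Buffer with numpy-generated
# values on the host, bumping Tensor._seed on each call; as the TODO suggests, it appears to be the legacy
# fallback behind Tensor.rand when the THREEFRY path is not enabled.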