Files
tinygrad/tinygrad/tensor.py
David Hou fc11808a79 initialize Tensor grad same type as self (#3613)
* initialize Tensor grad same type as self

* also test different default float

* check dtype + try/finally

* don't test_gradient_dtype if f16 is not supported

* fix bad merge

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-03-22 20:33:18 -04:00

1053 lines
63 KiB
Python

# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
from collections import defaultdict
import numpy as np
from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps
from tinygrad.device import Buffer, Device
from tinygrad.shape.symbolic import sint
from tinygrad.realize import run_schedule, create_schedule
# **** start with two base classes, Tensor and Function ****
class Function:
  """Base class of differentiable ops: subclasses implement forward/backward on LazyBuffers; `apply` wraps the result in a Tensor."""
  def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor):
    self.device = device
    self.needs_input_grad = [t.requires_grad for t in tensors]
    # tri-state: True if any input needs grad, otherwise None if any input is undecided (None), otherwise False
    if any(self.needs_input_grad): self.requires_grad = True
    elif None in self.needs_input_grad: self.requires_grad = None
    else: self.requires_grad = False
    # inputs are only kept around when a backward pass may need them
    if self.requires_grad: self.parents = tensors
  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
  def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
  @classmethod
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
    """Run the op's forward on the inputs' lazydata and return a new Tensor; the ctx is recorded unless grads are off."""
    ctx = fxn(x[0].device, *x)
    ret = Tensor.__new__(Tensor)
    ret.lazydata = ctx.forward(*[t.lazydata for t in x], **kwargs)
    ret.requires_grad, ret.grad = ctx.requires_grad, None
    # used by autograd engine
    ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None
    return ret
import tinygrad.mlops as mlops
def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
  """Create a LoadOps buffer on a single device, or one per device wrapped in a MultiLazyBuffer (axis=None) for a device tuple."""
  if not isinstance(device, str):
    per_device = [LazyBuffer.loadop(op, shape, dtype, dev, arg, src) for dev in device]
    return MultiLazyBuffer(per_device, None)
  return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
def _fromcpu(x: np.ndarray) -> LazyBuffer:
  """Wrap a numpy array as an EMPTY LazyBuffer on the "EXT" device whose Buffer is backed by the array's memory."""
  buf_dtype = dtypes.from_np(x.dtype)
  if x.size == 0:
    backing = Buffer("EXT", 0, buf_dtype, (memoryview(bytearray()), None))
  else:
    # a C-contiguous view of the array is exposed as a flat memoryview; the array itself is passed alongside it
    backing = Buffer("EXT", prod(x.shape), buf_dtype, (flat_mv(np.require(x, requirements='C').data), x))
  return LazyBuffer.loadop(LoadOps.EMPTY, x.shape, buf_dtype, "EXT", _buf=backing)
def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
  """For each dim in range(dims), build one Tensor per column k of `mat`: the column's entries as constant
  fills of shape `shp` (with that dim set to 1), concatenated along that dim."""
  result: List[List[Tensor]] = []
  for dim in range(dims):
    fill_shape = shp[:dim] + (1,) + shp[dim+1:]
    columns = []
    for k in range(len(mat[0])):
      columns.append(Tensor.cat(*[Tensor.full(fill_shape, float(row[k]), device=device) for row in mat], dim=dim))
    result.append(columns)
  return result
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
  """Apply the small constant matrix `mat` to each of the first `dims` dims of `t` (roughly kron(mat, ..., mat) @ t).

  mat: nested sequence of numbers (rows x columns). Each of the first `dims` dims of the result has size len(mat).
  NOTE(review): t.shape[:dims] are indexed with range(len(mat[0])), so presumably each equals len(mat[0]) — confirm.
  Returns a Tensor of shape (len(mat),)*dims + t.shape[dims:].
  """
  # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
  # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
  t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
  # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
  matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
  # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
  ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
  # sum() starts from the int 0, so this narrows the result type for the checker
  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
  return ret
class Tensor:
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
# _ctx may be deleted (backward() does `del t0._ctx`)
__deletable__ = ('_ctx',)
# global training flag, toggled by Tensor.train
training: ClassVar[bool] = False
class train(ContextDecorator):
  """Context manager / decorator: set Tensor.training to `mode` on entry, restore the previous value on exit."""
  def __init__(self, mode:bool = True): self.mode = mode
  def __enter__(self):
    previous = Tensor.training
    self.prev = previous
    Tensor.training = self.mode
  def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
# global no-grad flag: while set, Function.apply does not record an autograd context
no_grad: ClassVar[bool] = False
class inference_mode(ContextDecorator):
  """Context manager / decorator: set Tensor.no_grad to `mode` on entry, restore the previous value on exit."""
  def __init__(self, mode:bool = True): self.mode = mode
  def __enter__(self):
    previous = Tensor.no_grad
    self.prev = previous
    Tensor.no_grad = self.mode
  def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
def __init__(self, data:Union[None, Scalar, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer],
             device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
  """Create a Tensor from a Python scalar, a (nested) list or tuple, a numpy array, bytes, None, or a (Multi)LazyBuffer.

  device: a device name, or a tuple/list of names, which yields a MultiLazyBuffer spread over those devices.
  dtype: requested dtype; when None it is inferred from `data`.
  requires_grad: None (undecided, the default), True to track gradients, False to never track them.
  Raises RuntimeError when `data` has an unsupported type.
  """
  assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
  # tensors have gradients, buffers do not
  self.grad: Optional[Tensor] = None
  # NOTE: this can be in three states. False and None: no gradient, True: gradient
  # None (the default) will be updated to True if it's put in an optimizer
  self.requires_grad: Optional[bool] = requires_grad
  # internal variables used for autograd graph construction
  self._ctx: Optional[Function] = None
  if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
  elif isinstance(data, get_args(Scalar)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
  # raw bytes are always read as uint8
  elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
  elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
  # tuples are accepted exactly like lists, matching the `Tuple` in the type hint
  elif isinstance(data, (list, tuple)):
    # infer dtype from the flattened elements: all bool -> bool, all int -> default_int, otherwise (or empty) default_float
    if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
    elif d and all_int(d): dtype = dtype or dtypes.default_int
    else: dtype = dtype or dtypes.default_float
    # bfloat16 is built from a float32 numpy array and then cast
    if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
    else: data = _fromcpu(np.array(data, dtype.np))
  elif isinstance(data, np.ndarray):
    # 0-d arrays become a CONST load; others are copied in, converted to `dtype` when it has a numpy equivalent
    # NOTE(review): when dtype.np is None the array keeps its own dtype and the requested `dtype` is not applied
    if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
    else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
  # data is a LazyBuffer, but it might be on the wrong device
  if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
  if isinstance(device, tuple):
    # TODO: what if it's a MultiLazyBuffer on other devices?
    self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data
  else:
    self.lazydata = data if data.device == device else data.copy_to_device(device)
def __repr__(self):
  """Show the lazydata, device, and the grad's lazydata (None when there is no grad)."""
  grad_lb = None if self.grad is None else self.grad.lazydata
  return f"<Tensor {self.lazydata!r} on {self.device} with grad {grad_lb!r}>"
def __hash__(self):
  # identity hash: Python has a non moving GC, so id() stays fixed for the object's lifetime
  return id(self)
def __bool__(self):
  """Truthiness of a Tensor is not defined; always raises TypeError."""
  raise TypeError("__bool__ on Tensor is not defined")
@property
def device(self) -> Union[str, Tuple[str, ...]]:
  """Device name, or tuple of names for a multi-device tensor, as reported by the underlying lazydata."""
  lb = self.lazydata
  return lb.device
@property
def shape(self) -> Tuple[sint, ...]:
  """Shape of the tensor, as reported by the underlying lazydata (entries may be symbolic)."""
  lb = self.lazydata
  return lb.shape
@property
def dtype(self) -> DType:
  """Element dtype, as reported by the underlying lazydata."""
  lb = self.lazydata
  return lb.dtype
# ***** data handlers ****
@staticmethod
def corealize(lst:Iterable[Tensor]):
  """Schedule and run the realization of all tensors in `lst` together (multi-device tensors contribute every per-device buffer)."""
  buffers = []
  for t in lst:
    lb = t.lazydata
    buffers.extend(lb.lbs if isinstance(lb, MultiLazyBuffer) else [lb])
  run_schedule(create_schedule(buffers))
def realize(self) -> Tensor:
  """Realize this tensor's data and return self so calls can be chained."""
  Tensor.corealize((self,))
  return self
def replace(self, x:Tensor) -> Tensor:
  """Take over `x`'s lazydata (possibly a different device/dtype) and return self.

  Requires that x does not require grad, that self has no autograd context, and that the shapes match.
  """
  src_requires_grad, own_ctx = x.requires_grad, getattr(self, '_ctx', None)
  assert not src_requires_grad and own_ctx is None
  assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
  self.lazydata = x.lazydata
  return self
def assign(self, x) -> Tensor:
  """Overwrite this tensor's contents with `x` (a Tensor, or data convertible via Tensor(x, ...)) and return self.

  DISK tensors are written directly through the realized buffer. Otherwise x must match self's shape, device and
  dtype, and must not require grad; if self is not realized yet it simply takes over x's lazydata via replace().
  """
  # TODO: this is a hack for writing to DISK. remove with working assign
  if isinstance(self.device, str) and self.device.startswith("DISK"):
    # NOTE(review): this path skips the shape/device/dtype checks below
    if x.__class__ is not Tensor: x = Tensor(x, device="EXT", dtype=self.dtype)
    self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
    return self
  if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
  if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
  if self.lazydata is x.lazydata: return self # a self assign is a NOOP
  # NOTE(review): the original note here said "we allow cross device assign", but the device assert below
  # rejects mismatched devices — confirm the intended behavior
  assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
  assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
  assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
  assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
  assert not x.requires_grad # self requires_grad is okay?
  # nothing realized yet: just swap in x's graph instead of emitting an assign
  if not self.lazydata.is_realized(): return self.replace(x)
  self.lazydata = self.lazydata.assign(x.lazydata)
  return self
def detach(self) -> Tensor:
  """Return a new Tensor sharing this one's lazydata, outside the autograd graph (requires_grad=False)."""
  return Tensor(self.lazydata, self.device, requires_grad=False)
def _data(self) -> memoryview:
  """Raw bytes of the realized, contiguous tensor; multi-device tensors are read from their first device. Empty for zero-size shapes."""
  if 0 in self.shape: return memoryview(bytearray(0))
  src = self.to(self.device[0]) if not isinstance(self.device, str) else self  # deal with multitensor
  realized = src.cast(src.dtype.scalar()).contiguous().realize().lazydata.base.realized
  return cast(Buffer, realized).as_buffer()
def data(self) -> memoryview:
  """Tensor contents as a memoryview typed by dtype.fmt and shaped like the tensor (0-d tensors give shape (1,))."""
  assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
  view_shape = self.shape if len(self.shape) else (1,)
  return self._data().cast(self.dtype.fmt, view_shape)
def item(self) -> Scalar:
  """Return the only element of a one-element tensor as a Python scalar."""
  assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
  assert self.numel() == 1, "must have one element for item"
  typed = self._data().cast(self.dtype.fmt)
  return typed[0]
def numpy(self) -> np.ndarray:
  """Tensor data as a numpy array of the same shape, built with np.frombuffer over the realized bytes.
  bfloat16 tensors are first converted with self.float()."""
  if self.dtype == dtypes.bfloat16: return self.float().numpy()
  np_dtype = self.dtype.np
  assert np_dtype is not None, f"no np dtype for {self.dtype}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
  raw = self._data()
  return np.frombuffer(raw, dtype=np_dtype).reshape(self.shape)
def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
  """Return this tensor on `device` (self if already there). A tuple/list of devices shards instead;
  the grad (moved too) and the autograd context are carried over to the new Tensor."""
  if isinstance(device, (tuple, list)): device = tuple(Device.canonicalize(d) for d in device)
  else: device = Device.canonicalize(device)
  if device is None or device == self.device: return self
  if not isinstance(device, str): return self.shard(device)
  moved = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
  if self.grad is not None: moved.grad = self.grad.to(device)
  if hasattr(self, '_ctx'): moved._ctx = self._ctx
  return moved
def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
  """Move this tensor (and its grad, when both old and new grads exist) to `device` in place.

  Returns self, consistent with shard_, so in-place moves can be chained; previously this returned None.
  """
  real = self.to(device)
  # TODO: is this assign?
  if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
  self.lazydata = real.lazydata
  return self
def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
  """Distribute a single-device tensor over `devices` via MultiLazyBuffer.from_sharded along `axis`
  (negative axes count from the end). Keeps requires_grad."""
  assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
  canonical_devices = tuple(map(Device.canonicalize, devices))
  if axis is not None and axis < 0: axis = axis + len(self.shape)
  mlb = MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis)
  return Tensor(mlb, device=canonical_devices, requires_grad=self.requires_grad)
def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
  """In-place shard: replace this tensor's lazydata with the sharded version and return self."""
  sharded = self.shard(devices, axis)
  self.lazydata = sharded.lazydata
  return self
# ***** creation llop entrypoint *****
@staticmethod
def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
  """Build a Tensor from a LoadOps op: one LazyBuffer, or one per device (MultiLazyBuffer, axis=None) for a device tuple.
  The buffers use `dtype`, defaulting to dtypes.default_float."""
  buf_dtype = dtype or dtypes.default_float
  if not isinstance(device, tuple):
    return Tensor(LazyBuffer.loadop(op, shape, buf_dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
  per_device = [LazyBuffer.loadop(op, shape, buf_dtype, Device.canonicalize(d), arg) for d in device]
  return Tensor(MultiLazyBuffer(per_device, None), device, dtype, **kwargs)
@staticmethod
def empty(*shape, **kwargs):
  """Tensor of the given shape from a LoadOps.EMPTY load (no fill values are written here)."""
  dims = argfix(*shape)
  return Tensor._loadop(LoadOps.EMPTY, dims, **kwargs)
# global RNG state: the seed (defaults to the current time) and a uint32 counter tensor that rand() creates on first use
_seed: int = int(time.time())
_rng_counter: Optional[Tensor] = None
@staticmethod
def manual_seed(seed=0):
  """Reset the RNG: set the seed and restart the counter at 0."""
  counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
  Tensor._seed, Tensor._rng_counter = seed, counter
@staticmethod
def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
  """Random Tensor of `shape`.

  Without THREEFRY: a LoadOps.CUSTOM load using custom_random (bfloat16 is generated as float, then cast).
  With THREEFRY: a Threefry-2x32-style counter-based generator (5 x 4 = 20 rounds) keyed by Tensor._seed,
  producing values in [0, 1) from the top 24 bits of each word; Tensor._rng_counter advances by the element count.
  On the threefry path only kwargs["requires_grad"] is applied to the result (except for zero-size shapes).
  """
  if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
  if not THREEFRY.value:
    if dtype == dtypes.bfloat16:
      return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
    return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
  # threefry
  if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
  # per-element counters continue from the global counter; padded to an even length so they split into two halves
  counts = (Tensor.arange(num, device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize().pad(((0,num%2),))
  Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
  rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
  # key schedule: [0, seed ^ 0 ^ parity constant, seed]
  ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
  x = [(c := counts.chunk(2))[0] + ks[-1], c[1] + ks[0]]
  for i in range(5):
    # mix step: x0 += x1; x1 = rotl(x1, r) ^ x0, written with multiply/divide (presumably relying on uint32 wraparound)
    for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] * (2 ** r)) + (x[1].div(2 ** (32 - r), upcast=False)))
    # key injection after every 4 rounds
    x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
  # keep the top 24 bits and scale to [0, 1)
  out = x[0].cat(x[1])[:num].div(2 ** 8, upcast=False).cast(dtypes.float32).div(2 ** 24)
  out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
  out.requires_grad = kwargs.get("requires_grad")
  return out.contiguous()
# ***** creation helper functions *****
@staticmethod
def full(shape:Tuple[sint, ...], fill_value:Scalar, **kwargs):
  """Tensor of `shape` filled with `fill_value`, built by expanding a scalar Tensor (not materialized)."""
  new_shape = argfix(shape)
  scalar = Tensor(fill_value, **kwargs)
  return scalar.reshape((1,) * len(new_shape)).expand(new_shape)
@staticmethod
def zeros(*shape, **kwargs):
  """Tensor of the given shape filled with 0.0 (dtype from kwargs, else inferred from the scalar)."""
  dims = argfix(*shape)
  return Tensor.full(dims, 0.0, **kwargs)
@staticmethod
def ones(*shape, **kwargs):
  """Tensor of the given shape filled with 1.0 (dtype from kwargs, else inferred from the scalar)."""
  dims = argfix(*shape)
  return Tensor.full(dims, 1.0, **kwargs)
@staticmethod
def arange(start, stop=None, step=1, **kwargs):
  """Values from `start` (inclusive) to `stop` (exclusive) spaced by `step`; `arange(n)` counts from 0.

  dtype defaults to dtypes.default_float if any of start/stop/step is a float, else dtypes.default_int.
  An empty range (e.g. start >= stop with a positive step) returns an empty 1-D Tensor instead of
  building a negative-length shape. Remaining kwargs are forwarded to the Tensor constructors.
  """
  if stop is None: stop, start = start, 0
  assert all(isinstance(s, (int, float)) for s in (start, stop, step)), "symbolic arange not supported"
  dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
  # element count; <= 0 means the range is empty
  if (output_len := math.ceil((stop-start)/step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
  # cumsum of `step` gives step, 2*step, ...; shifting by (start - step) yields start, start+step, ...
  return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
@staticmethod
def eye(dim:int, **kwargs):
  """dim x dim identity matrix: a column of ones padded with `dim` zeros per row, flattened,
  cut to dim*dim elements and reshaped so each 1 lands on the diagonal."""
  padded_rows = Tensor.ones((dim,1), **kwargs).pad((None, (0,dim)))
  diagonal_flat = padded_rows.flatten().shrink(((0, dim*dim),))
  return diagonal_flat.reshape(dim, dim)
def full_like(self, fill_value:Scalar, **kwargs):
  """Tensor with self's shape filled with `fill_value`; dtype and device default to self's."""
  dtype = kwargs.pop("dtype", self.dtype)
  device = kwargs.pop("device", self.device)
  return Tensor.full(self.shape, fill_value, dtype=dtype, device=device, **kwargs)
def zeros_like(self, **kwargs):
  """Zeros with self's shape (dtype/device default to self's, see full_like)."""
  fill = 0
  return self.full_like(fill, **kwargs)
def ones_like(self, **kwargs):
  """Ones with self's shape (dtype/device default to self's, see full_like)."""
  fill = 1
  return self.full_like(fill, **kwargs)
# ***** rng hlops *****
@staticmethod
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
  """Standard normal samples via the Box-Muller transform of two uniform draws; dtype defaults to dtypes.default_float."""
  # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
  uniforms = Tensor.rand((2, *argfix(*shape)), **kwargs)
  angle_cos = uniforms[0].mul(2*math.pi).cos()
  radius = (1 - uniforms[1]).log().mul(-2).sqrt()
  return angle_cos.mul(radius).cast(dtype or dtypes.default_float)
@staticmethod
def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
  """int32 values in [low, high), from uniform samples truncated by the int32 cast (assumes rand() yields [0, 1))."""
  return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
  """Normal samples with the given mean and standard deviation (scaled and shifted randn)."""
  samples = Tensor.randn(*shape, **kwargs)
  return std * samples + mean
@staticmethod
def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
  """Uniform samples between low and high; the scaled sample is cast to dtype (default dtypes.default_float) before adding low."""
  dtype = kwargs.pop("dtype", dtypes.default_float)
  scaled = (high - low) * Tensor.rand(*shape, **kwargs)
  return scaled.cast(dtype) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor:
  """Uniform in [-1, 1) scaled by 1/sqrt(number of elements)."""
  numel = prod(argfix(*shape))
  return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(numel**-0.5)
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor:
  """Glorot/Xavier uniform init: uniform in [-1, 1) scaled by sqrt(6 / (shape[0] + prod(shape[1:])))."""
  dims = argfix(*shape)
  limit = (6/(dims[0]+prod(dims[1:])))**0.5
  return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(limit)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
  """Kaiming/He uniform init with negative slope `a`; fan_in is prod(shape[1:])."""
  fan_in = prod(argfix(*shape)[1:])
  gain = math.sqrt(2.0 / (1 + a ** 2))
  bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)
  return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
  """Kaiming/He normal init with negative slope `a`; std = sqrt(2 / (1 + a^2)) / sqrt(prod(shape[1:]))."""
  fan_in = prod(argfix(*shape)[1:])
  std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(fan_in)
  return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
  """Sample category indices from the (unnormalized) weights in self by inverse-CDF sampling.

  self is 1-D, or 2-D with one distribution per row. Returns int32 indices shaped (num_samples,) for 1-D input,
  else (rows, num_samples). Without replacement only num_samples == 1 is supported.
  """
  assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
  assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
  weight = self if self.ndim != 1 else self.unsqueeze(0)
  cw = weight.cumsum(1).float()
  cdf = cw / cw[:, -1].unsqueeze(1)  # each row's running sum normalized to end at 1
  unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
  # a sample's index is the number of CDF entries it is >= to
  indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
  result = indices if self.ndim != 1 else indices.squeeze(0)
  return result.cast(dtypes.int32)
# ***** toposort and backward pass *****
def deepwalk(self):
  """Return the autograd nodes reachable from self, parents before children, ending with self.

  Only tensors carrying a truthy `_ctx` are returned (leaves are skipped; an empty list if self has none).
  Uses an explicit stack instead of recursive generators, so very deep graphs no longer hit Python's
  recursion limit. The order matches a recursive post-order DFS visiting `_ctx.parents` in order.
  """
  visited = {self}
  if not getattr(self, "_ctx", None): return []
  order = []
  # each entry: a node plus an iterator over its not-yet-examined parents
  stack = [(self, iter(self._ctx.parents))]
  while stack:
    node, parents = stack[-1]
    for p in parents:
      if p in visited: continue
      visited.add(p)
      if getattr(p, "_ctx", None):
        # finish p before looking at node's remaining parents
        stack.append((p, iter(p._ctx.parents)))
        break
    else:
      # every parent handled: node is done (post-order)
      stack.pop()
      order.append(node)
  return order
def backward(self) -> Tensor:
  """Backpropagate from this scalar tensor, accumulating .grad on every graph tensor whose requires_grad is truthy.

  Seeds self.grad with 1.0 (self's dtype and device), processes deepwalk() in reverse, and deletes each visited
  node's _ctx. Raises AssertionError for a non-scalar self and RuntimeError if a visited node has no grad.
  Returns self.
  """
  assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"
  # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
  # this is "implicit gradient creation"
  self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
  for t0 in reversed(self.deepwalk()):
    if t0.grad is None: raise RuntimeError("tensor has no grad")
    grads = t0._ctx.backward(t0.grad.lazydata)
    # a single-parent op returns one grad rather than a sequence; normalize to a list of Tensors (None kept as None)
    grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
      for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
    for t, g in zip(t0._ctx.parents, grads):
      if g is not None and t.requires_grad:
        assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
        t.grad = g if t.grad is None else (t.grad + g)
    # drop the graph edge once this node's grads are propagated
    del t0._ctx
  return self
# ***** movement mlops *****
def reshape(self, shape, *args) -> Tensor:
  """Reshape to `shape` (varargs or a sequence); -1 is inferred from the element count, None keeps that dim's size.
  Returns self unchanged if the shape is the same."""
  requested = argfix(shape, *args)
  resolved = []
  for i, s in enumerate(requested):
    if s == -1: resolved.append(-prod(self.shape) // prod(requested))  # prod(requested) is negative because of the -1
    elif s is None: resolved.append(self.shape[i])
    else: resolved.append(s)
  new_shape = tuple(resolved)
  return self if new_shape == self.shape else mlops.Reshape.apply(self, shape=new_shape)
def expand(self, shape, *args) -> Tensor:
  """Broadcast to `shape`; a -1 or None entry keeps that dim's current size. Returns self if nothing changes."""
  # NOTE(review): zip() truncates, so entries beyond self.ndim are silently ignored — confirm callers never pass extra dims
  targets = argfix(shape, *args)
  new_shape = tuple(x if x != -1 and x is not None else s for s, x in zip(self.shape, targets))
  return self if new_shape == self.shape else mlops.Expand.apply(self, shape=new_shape)
def permute(self, order, *args) -> Tensor:
  """Reorder dims by `order` (given as varargs or a single sequence)."""
  new_order = argfix(order, *args)
  return mlops.Permute.apply(self, order=new_order)
def flip(self, axis, *args) -> Tensor:
  """Reverse the tensor along the given axes (negative axes count from the end)."""
  axes = [ax + len(self.shape) if ax < 0 else ax for ax in argfix(axis, *args)]
  return mlops.Flip.apply(self, axis=axes)
def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
  """Keep the [start, end) window of each dim; a None entry keeps the whole dim. Returns self if no dim is cut."""
  if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
  window = tuple((0, s) if x is None else x for x, s in zip(arg, self.shape))
  return mlops.Shrink.apply(self, arg=window)
def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
  """Pad each dim by (before, after) (None means no padding) with `value`. Returns self if nothing is padded."""
  if all(x is None or x == (0,0) for x in arg): return self
  narg = tuple((0,0) if x is None else x for x in arg)
  ret = mlops.Pad.apply(self, arg=narg)
  if 0 == value: return ret
  # pad a ones tensor the same way: its padded cells are 0, so where(0, value) adds `value` only in the padding
  return ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
# ***** movement hlops *****
# Supported Indexing Implementations:
# 1. Int indexing (no copy)
# - for all dims where there's int, shrink -> reshape
# - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
# - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
# - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
# 2. Slice indexing (no copy)
# - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
# - first shrink the Tensor to X.shrink(((start, end),))
# - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
# - flip where dim value is negative
# - pad 0's on dims such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
# - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
# - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
# 3. None indexing (no copy)
# - reshape (inject) a dim at the dim where there's None
# 4. Tensor indexing (copy)
# - use Tensor.arange == tensor_index to create a mask
# - apply mask to self by mask * self for dims where index is a tensor
# - (mask * self).sum(dim) to reduce to correct shape
# Tiny Things:
# 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
# - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
# - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
# 2. Bool indexing is not supported
# 3. Out of bounds Tensor indexing results in 0
# - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are OOB
def __getitem__(self, indices) -> Tensor:
  """Index with ints, slices, None, Ellipsis, int Tensors, or lists/tuples of these (see the notes above for the supported forms).

  Basic indexing (int/slice/None) uses only movement ops; Tensor indices build one-hot masks (copy).
  Raises IndexError for unsupported index types, more than one Ellipsis, too many indices, out-of-bounds ints,
  non-int Tensor indices, or unbroadcastable Tensor indices; ValueError for a zero slice step.
  """
  # 1. indices normalization and validation
  # treat internal tuples and lists as Tensors and standardize indices to list type
  if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
  elif isinstance(indices, (tuple, list)):
    indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
  else: indices = [indices]
  # turn scalar Tensors into const val for int indexing if possible
  indices = [self._to_const_val(i) if isinstance(i, Tensor) else i for i in indices]
  # move Tensor indices to the same device as self
  indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
  # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
  ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
  fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
  num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
  indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_indices)
  # use Dict[type, List[dimension]] to track elements in indices
  type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
  # record None for dimension injection later and filter None and record rest of indices
  type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
  indices_filtered = [v for v in indices if v is not None]
  for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
  for index_type in type_dim:
    if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
  if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
  if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
  # 2. basic indexing, uses only movement ops (no copy)
  # currently indices_filtered: Tuple[Union[slice, int, Tensor], ...]
  # turn indices in indices_filtered to Tuple[shrink_arg, strides]
  for dim in type_dim[int]:
    if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
      raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
    indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
  for dim in type_dim[slice]:
    if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
    s, e, st = index.indices(self.shape[dim])
    # empty window when the step points away from the range; negative steps shrink to (e+1, s+1) and are flipped below
    indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
  # record tensors and skip all Tensor dims for basic indexing
  tensor_index: List[Tensor] = []
  for dim in type_dim[Tensor]:
    tensor_index.append(index := indices_filtered[dim])
    if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
    indices_filtered[dim] = ((0, self.shape[dim]), 1)
  new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
  ret = self.shrink(new_slice).flip(tuple(i for i, s in enumerate(strides) if s < 0))
  # strides other than +-1: pad to a multiple of the stride, split each dim into (n, stride), keep index 0 of each group
  if any(abs(s) != 1 for s in strides):
    strides = tuple(abs(s) for s in strides)
    ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
    ret = ret.reshape(tuple(flatten((sh // s, s) for s, sh in zip(strides, ret.shape))))
    ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
  # inject 1 for dim where it's None and collapse dim for int
  new_shape = list(ret.shape)
  for dim in type_dim[None]: new_shape.insert(dim, 1)
  for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
  ret = ret.reshape(new_shape)
  assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
  # 3. advanced indexing (copy)
  if type_dim[Tensor]:
    # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
    def calc_dim(tensor_dim:int) -> int:
      return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d)
    # track tensor_dim and tensor_index using a dict
    # calc_dim to get dim and use that to normalize the negative tensor indices
    idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor],tensor_index)}
    # compute sum_dim, arange, and idx
    max_idx_dim, first_dim, last_dim = max(i.ndim for i in idx.values()), min(idx.keys()), max(idx.keys())
    sum_dim = tuple(d if n==0 else d+max_idx_dim-n for n,d in enumerate(idx.keys()))
    arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(ret.shape[d], *[1]*(ret.ndim+max_idx_dim-n-sd-1)) \
      for n,(sd,d) in enumerate(zip(sum_dim, idx.keys()))]
    reshaped_idx = [i.reshape(i.shape + (1,)*(ret.ndim - first_dim - (n or 1))) for n,i in enumerate(idx.values())]
    ret = ret.reshape(ret.shape[:first_dim+1] + (1,)*max_idx_dim + ret.shape[first_dim+1:])
    # iteratively eq -> mul -> sum fancy index
    try:
      for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
    except AssertionError as exc: raise IndexError("cannot broadcast indices") from exc
    # special permute case
    if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
      ret_dims = list(range(ret.ndim))
      ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
  return ret
def __setitem__(self,indices,v):
  """Assign `v` into the view selected by `indices` (__getitem__ followed by assign)."""
  # NOTE(review): assign() on an unrealized view only replaces the view's lazydata, so self may not see the write — confirm
  view = self.__getitem__(indices)
  return view.assign(v)
# NOTE: using slice is discouraged and things should migrate to pad and shrink
def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
  """Take the [lo, hi) window of each dim (None = whole dim); bounds may lie outside the tensor, and those cells get `value`."""
  bounds = tuple((0, s) if a is None else a for s, a in zip(self.shape, arg))
  padding = tuple((max(0, -lo), max(0, hi - s)) for s, (lo, hi) in zip(self.shape, bounds))
  # after padding by `before` on the left, every window shifts right by that amount
  window = tuple((lo + before, hi + before) for (lo, hi), (before, _) in zip(bounds, padding))
  return self.pad(padding, value=value).shrink(window)
def gather(self:Tensor, idx:Tensor, dim:int) -> Tensor:
  """Gather values of self along `dim` at the positions in `idx` (torch.gather semantics),
  e.g. for 3-D and dim == 0: out[i][j][k] = self[idx[i][j][k]][j][k].

  idx must have as many dims as self and idx.shape[d] <= self.shape[d] for every d != dim;
  idx.shape[dim] may be any size (only the other dims are shrunk to idx's size below).
  The result has idx's shape; negative or out-of-range index values match no position and give 0.
  """
  assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
  dim = self._resolve_dim(dim)
  assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, idx.shape)) if d != dim), \
    "all dims of idx.shape except dim must be smaller than or equal to self.shape"
  # move idx's gather dim to the front and add a trailing axis to compare against arange
  idx = idx.to(self.device).transpose(ax1=dim, ax2=0).unsqueeze(-1)
  # permute self into the same dim order as the transposed idx, with the gather dim moved last
  permarg = list(range(self.ndim))
  permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
  # one-hot (idx == arange) times self, summed over the gather dim, then undo the transpose
  return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(
    tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """Concatenate self and `args` along `dim`; all other dims must match.

  Each input is zero-padded out to the full concatenated length at its own offset, and the padded tensors are summed.
  """
  dim = self._resolve_dim(dim)
  assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
  catargs = [self, *args]
  sizes = [t.shape[dim] for t in catargs]
  offsets = [0, *itertools.accumulate(sizes)]
  total = offsets[-1]
  padded = []
  for t, size, before in zip(catargs, sizes, offsets[:-1]):
    pad_arg: List[Optional[Tuple[sint, sint]]] = [None] * len(self.shape)
    pad_arg[dim] = (before, total - before - size)
    padded.append(t.pad(tuple(pad_arg)))
  return functools.reduce(Tensor.__add__, padded)
@staticmethod
def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor:
unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors]
# checks for shapes and number of dimensions delegated to cat
return unsqueezed_tensors[0].cat(*unsqueezed_tensors[1:], dim=dim)
def repeat(self, repeats:Sequence[int]) -> Tensor:
  """Tile the tensor: dim i of the result is repeats[i] copies of that dim (torch.Tensor.repeat semantics).

  `repeats` may be longer than self.ndim, in which case leading size-1 dims are added first.
  Raises AssertionError if `repeats` has fewer entries than self has dims; zip() would otherwise silently drop dims.
  """
  assert len(repeats) >= self.ndim, f"repeats must have at least {self.ndim} entries (one per dim), got {len(repeats)}"
  base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
  # put a size-1 axis before every dim, broadcast it to the repeat count, then merge each (repeat, size) pair
  new_shape = [x for b in base_shape for x in [1, b]]
  expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
  final_shape = [r*s for r,s in zip(repeats, base_shape)]
  return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
  """Map a possibly-negative `dim` to an index, raising IndexError when out of range.

  outer=True allows one extra position (for inserting a new dim). The valid range is at least [-1, 0], even for 0-d tensors.
  """
  bound = max(1, self.ndim + outer)
  if dim < -bound or dim >= bound:
    raise IndexError(f"{dim=} out of range {[-bound, bound-1]}")
  return dim + self.ndim + outer if dim < 0 else dim
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
  """Split along `dim` into pieces of the given sizes; an int means pieces of that size, with the last possibly smaller."""
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  dim = self._resolve_dim(dim)
  length = self.shape[dim]
  if isinstance(sizes, int): sizes = [min(sizes, length-i) for i in range(0, max(1, length), max(1, sizes))]
  assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
  pieces = []
  start = 0
  for size in sizes:
    pieces.append(self[tuple([slice(None)]*dim + [slice(start, start + size)])])
    start += size
  return tuple(pieces)
def chunk(self, num:int, dim:int=0) -> List[Tensor]:
  """Split along `dim` into pieces of size ceil(len/num), which can give fewer than `num` pieces; a zero-length dim gives `num` empty pieces."""
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  assert num > 0, f"expect num to be greater than 0, got: {num}"
  dim = self._resolve_dim(dim)
  length = self.shape[dim]
  sizes = math.ceil(length/num) if length else [0]*num
  return list(self.split(sizes, dim=dim))
def squeeze(self, dim:Optional[int]=None) -> Tensor:
  """Drop size-1 dims: all of them when `dim` is None, otherwise only `dim` if its size is 1 (0-d tensors are returned as is)."""
  if dim is None: return self.reshape(tuple(size for size in self.shape if size != 1))
  dim = self._resolve_dim(dim)
  if self.ndim == 0 or self.shape[dim] != 1: return self
  return self.reshape(self.shape[:dim] + self.shape[dim+1:])
def unsqueeze(self, dim:int) -> Tensor:
dim = self._resolve_dim(dim, outer=True)
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
def pad2d(self, padding:Sequence[int], value:float=0) -> Tensor:
  """Pad the trailing dims, filling with `value`.

  `padding` is (padding_left, padding_right, padding_top, padding_bottom, ...): pairs applied from the last dim
  backwards. Negative amounts crop.
  """
  pairs = zip(padding[::2], padding[1::2], self.shape[::-1])
  trailing = [(-before, size+after) for before, after, size in pairs][::-1]
  leading = [(0, s) for s in self.shape[:-(len(padding)//2)]]
  return self.slice(leading + trailing, value=value)
@property
def T(self) -> Tensor: return self.transpose()
def transpose(self, ax1=1, ax2=0) -> Tensor:
order = list(range(self.ndim))
order[ax1], order[ax2] = order[ax2], order[ax1]
return self.permute(order)
def flatten(self, start_dim=0, end_dim=-1):
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
def unflatten(self, dim:int, sizes:Tuple[int,...]):
dim = self._resolve_dim(dim)
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
# ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_)
shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
ret = fxn.apply(self, axis=axis_)
return ret if keepdim else ret.reshape(shape=shape)
def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
if acc_dtype is None: acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
least_upper_dtype(self.dtype, dtypes.float)
# cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype)
def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
def mean(self, axis=None, keepdim=False):
assert all_int(self.shape), "does not support symbolic shape"
out = self.sum(axis=axis, keepdim=keepdim)
return out.div(prod(self.shape) / prod(out.shape)) if 0 not in out.shape else out
def var(self, axis=None, keepdim=False, correction=1):
assert all_int(self.shape), "does not support symbolic shape"
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
def std(self, axis=None, keepdim=False, correction=1): return self.var(axis, keepdim, correction).sqrt()
def _softmax(self, axis):
if len(self.shape) == 0:
assert axis in [-1, 0], f"{axis=} out of range of [-1, 0]"
axis = None
m = self - self.max(axis=axis, keepdim=True)
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)
def softmax(self, axis=-1):
_, e, ss = self._softmax(axis)
return e.div(ss)
def log_softmax(self, axis=-1):
m, _, ss = self._softmax(axis)
return m - ss.log()
def argmax(self, axis=None, keepdim=False):
# NOTE: return the first index if there are multiple occurrences of the maximum values
if axis is None:
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
axis = self._resolve_dim(axis)
m = self == self.max(axis=axis, keepdim=True)
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
@staticmethod
def einsum(formula:str, *raw_xs) -> Tensor:
xs:Tuple[Tensor] = argfix(*raw_xs)
formula = formula.replace(" ", "")
inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
inputs = [x for x in cast(str,inputs_str).split(',')]
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
# map the value of each letter in the formula
letter_val = sorted(merge_dicts([{letter:dim for letter, dim in zip(letters, tensor.shape)} for letters, tensor in zip(inputs, xs)]).items())
xs_:List[Tensor] = []
lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
rhs_order, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
# sum over all axes that's not in the output, then permute to the output order
return functools.reduce(lambda a,b:a*b, xs_) \
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)
# ***** processing ops *****
  def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
    # Extract sliding windows of size k_ over the last len(k_) dims. No padding is applied here.
    # Result shape is (*leading dims, *output sizes o_, *kernel sizes k_). The kernel axes are moved to the end so
    # callers can reduce them (avg_pool2d / max_pool2d) or contract them with weights (conv2d).
    # stride / dilation may be ints (broadcast per pooled dim via make_pair) or per-dim tuples.
    assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
    assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
    s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
    assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
    # noop_: one None per leading (non-pooled) dim, presumably meaning "leave this dim as is" to slice/reshape (TODO confirm)
    # i_: input sizes of the pooled dims
    noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
    # overlapping windows (kernel larger than stride) or any dilation: repeat-based construction
    if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
      # output size per pooled dim: (i - d*(k-1) - 1)//s + 1
      o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
      # repeats such that we don't need padding
      xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
      # slice by dilation
      # viewing the repeated data as (k, i+d) makes row j the input shifted by j*d, i.e. the j-th dilated kernel tap
      xup = xup.slice(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)]).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
      # handle stride
      # split each row into o groups of s and keep the first element of each group: [j, p] -> input[p*s + j*d]
      xup = xup.slice(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
      xup = xup.slice(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
      # permute to move reduce to the end
      return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
    # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
    # Non-overlapping windows (k <= s) without dilation: cut each pooled dim into o blocks of length s and keep the
    # first k of every block. o = (i+(s-k))//s, i.e. (i-k)//s + 1.
    o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
    xup = self.slice(noop_ + [(0,o*s) for o,s in zip(o_, s_)])
    xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
    xup = xup.slice(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
    # (o, k) pairs -> all o axes first, then all k axes
    return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
# NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(
make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(
make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
  def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
    # Transposed convolution, computed as a regular conv2d over a zero-stuffed input with a regrouped, spatially
    # flipped kernel.
    # NOTE(review): assumes torch's conv_transpose weight layout (cin, cout//groups, *HW) -- confirm against callers.
    # HW: kernel spatial sizes. trailing: the spatial axes of the weight after the unflatten below adds one leading axis.
    HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
    # (groups, weight.shape[0]//groups, weight.shape[1], *HW) -> swap the two channel axes -> flip every spatial axis
    x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing)
    stride = make_pair(stride, len(HW))
    if any(s>1 for s in stride):
      # Insert s-1 zeros between neighbouring input elements along each spatial dim:
      # k -> (k,1) -> pad to (k,s) -> k*s -> drop the trailing s-1 zeros -> (k-1)*s+1
      x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
      x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
      x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
      x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
    # Pad each spatial dim by (k-1)*d - p on both sides, with output_padding added to the far side. The list is
    # reversed because conv2d's explicit padding is ordered last dim first (pad2d convention).
    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(
      zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
    # flatten(end_dim=1) merges the group axis with the per-group output-channel axis for conv2d
    return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
    # Grouped convolution over any number of spatial dims (len(HW)), despite the name.
    # self is (bs, cin_, *spatial) and weight is (cout, cin, *HW), with cin_ == groups*cin.
    # padding may be:
    #   - an int (every side);
    #   - 2*len(HW) values in pad2d order (last dim first, (before, after) pairs);
    #   - len(HW) per-dim values given first dim first, which are duplicated and reversed into pad2d order.
    (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
    assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
    if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
    padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) # noqa: E501
    # conv2d is a pooling op (with padding)
    x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
    # rcout: output channels per group; oyx: output spatial sizes
    rcout, oyx = cout//groups, x.shape[2:-len(HW)]
    # the Winograd path below only handles 3x3 kernels with stride 1 and dilation 1, and only when WINO is enabled
    if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
      # normal conv
      x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
      # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
      # multiply by the weights, sum over the trailing cin and HW axes, then fold groups*rcout into cout
      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
      return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
    HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
    winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
    winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
    winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time
    # todo: stride == dilation
    # use padding to round up to 4x4 output tiles
    # (extra after-padding per dim makes dim + total padding - 2 a multiple of 4, then 6-wide windows with stride 4)
    # (bs, cin_, tyx, HWI)
    d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
    # move HW to the front: # (HWI, bs, cin_, tyx)
    d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
    tyx = d.shape[-len(HWI):] # dim of tiling
    g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
    # compute 6x6 winograd tiles: GgGt, BtdB
    # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
    gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
    # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
    dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
    # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
    ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
    # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
    ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
    # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
    ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx]))
    return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
n1, n2 = len(self.shape), len(w.shape)
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
    # Single-stage cumulative sum along `axis`.
    # The axis (length n) is moved last and left-padded with n-1 zeros. Every length-n window (stride 1) is then
    # summed, so window j holds x[0] + ... + x[j].
    # With _first_zero, n zeros are padded instead. That gives n+1 outputs starting with 0 (the exclusive prefix
    # sums followed by the total).
    pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0)
    return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1)
  def cumsum(self, axis:int=0) -> Tensor:
    # TODO: someday the optimizer will find this on its own
    # for now this is a two stage cumsum
    # Short axes use _cumsum directly. Longer ones go through these steps:
    #   1. left-pad to a multiple of SPLIT and cut into blocks of SPLIT;
    #   2. prefix-sum within each block;
    #   3. add the running total of all preceding blocks.
    SPLIT = 256
    if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
    ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
    ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
    # exclusive prefix sum of the block totals (each block's last element), broadcast across its block
    base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
    base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
    # merge blocks back into one axis, drop the left padding, and restore the original axis position
    def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
    return fix(ret) + fix(base_add)
@staticmethod
def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
assert all_int((r,c)), "does not support symbolic"
if r == 0: return Tensor.zeros((r, c), **kwargs)
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
# ***** mlops (unary) *****
def logical_not(self): return mlops.Eq.apply(*self._broadcasted(False))
def neg(self): return mlops.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
def contiguous(self): return mlops.Contiguous.apply(self)
def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
def log2(self): return self.log()/math.log(2)
def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype)))
def exp2(self): return mlops.Exp.apply(self*math.log(2))
def relu(self): return mlops.Relu.apply(self)
def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype)))
def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
def rsqrt(self): return self.reciprocal().sqrt()
def cos(self): return ((math.pi/2)-self).sin()
def tan(self): return self.sin() / self.cos()
# ***** math functions (unary) *****
  # truncate toward zero by round-tripping through int32 (NOTE(review): values outside the int32 range are presumably not handled)
  def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).cast(self.dtype)
  # ceil/floor: move the truncated value by one where truncation went the wrong direction
  def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b)
  def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b)
  # round half to even (e.g. 2.5 -> 2, 3.5 -> 4, -2.5 -> -2). The branch depends on the parity of the truncated value and the sign.
  def round(self: Tensor) -> Tensor:
    return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
  def square(self): return self*self
  # clamp elementwise into [min_, max_]
  def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
  # |x| = relu(x) + relu(-x)
  def abs(self): return self.relu() + (-self).relu()
  # x / (|x| + 1e-12) in float32, cast back: gives -1/0/+1, but is inexact for |x| near 1e-12
  def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
  def reciprocal(self): return mlops.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
# ***** activation functions (unary) *****
def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
def swish(self): return self * self.sigmoid()
def silu(self): return self.swish() # The SiLU function is also known as the swish function.
def relu6(self): return self.relu() - (self-6).relu()
def hardswish(self): return self * (self+3).relu6() * (1/6)
def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
def sinh(self): return (self.exp() - self.neg().exp()) / 2
def cosh(self): return (self.exp() + self.neg().exp()) / 2
def atanh(self): return ((1 + self)/(1 - self)).log() / 2
def asinh(self): return (self + (self.square() + 1).sqrt()).log()
def acosh(self): return (self + (self.square() - 1).sqrt()).log()
def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
def quick_gelu(self): return self * (self * 1.702).sigmoid()
def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
def mish(self): return self * self.softplus().tanh()
def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
def softsign(self): return self / (1 + self.abs())
# ***** broadcasted elementwise mlops *****
  def _broadcasted(self, y:Union[Tensor, Scalar], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
    # Return (self, y), or (y, self) when `reverse`, as Tensors expanded to one common broadcast shape.
    # A python scalar y becomes a non-grad Tensor on self's device. It takes self's dtype when self is an image
    # dtype, a float dtype, or an int dtype with an int y; otherwise the dtype is inferred from the python value.
    # With match_dtype, both operands are cast to their least upper dtype.
    x: Tensor = self
    if not isinstance(y, Tensor):
      # make y a Tensor
      assert isinstance(y, (float, int, bool)), f"{type(y)=}, {y=}"
      if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
      else: y_dtype = dtypes.from_py(y)
      y = Tensor(cast_scalar(y, y_dtype), self.device, y_dtype, requires_grad=False)
    if match_dtype:
      output_dtype = least_upper_dtype(x.dtype, y.dtype)
      x, y = x.cast(output_dtype), y.cast(output_dtype)
    if reverse: x, y = y, x
    # left pad shape with 1s
    if len(y.shape) < len(x.shape): y = y.reshape((1,) * (len(x.shape) - len(y.shape)) + y.shape)
    elif len(x.shape) < len(y.shape): x = x.reshape((1,) * (len(y.shape) - len(x.shape)) + x.shape)
    # per dim: 0 if either side is 0, otherwise the larger size (a size-1 side is expanded to it)
    broadcasted_shape = tuple(0 if xi==0 or yi==0 else max(xi, yi) for xi, yi in zip(x.shape, y.shape))
    return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
  def _to_const_val(self, x:Union[Tensor, Scalar]) -> Union[Tensor, Scalar]:
    # Try to return x.lazydata.base.arg (presumably the python constant) instead of the Tensor x, so the arithmetic
    # fast paths (add 0, mul by 0/1/-1, ...) can apply. This requires x to be an unrealized contiguous-const
    # LazyBuffer with no grad, whose broadcast against self leaves self's shape unchanged.
    # Otherwise x is returned unchanged.
    # TODO: update with multi
    return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
      and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
def add(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
return mlops.Add.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else self
def sub(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
return mlops.Sub.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else (-self if reverse else self)
def mul(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
if not isinstance(x, Tensor) and x == 0.0: return mlops.Zero.apply(self)
if not isinstance(x, Tensor) and x == -1.0: return -self
return mlops.Mul.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x != 1.0 else self
def div(self, x:Union[Tensor, Scalar], reverse=False, upcast=True) -> Tensor:
x = self._to_const_val(x)
if not isinstance(x, Tensor) and not reverse and x != 0 and upcast: return self.mul(1/x)
if (isinstance(x, Tensor) and dtypes.is_float(x.dtype)) or not upcast: return mlops.Div.apply(*self._broadcasted(x, reverse))
return mlops.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse))
def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
  def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
    # self ** x, or x ** self when `reverse`.
    # General case: |base|**exp = exp(exp * log|base|), followed by sign and NaN fix-ups for negative bases.
    # Python-constant exponents take shortcuts first.
    # NOTE(review): with reverse and a constant base of exactly 0, math.log(abs(0)) below raises ValueError -- confirm intended.
    x = self._to_const_val(x)
    if not isinstance(x, Tensor) and not reverse:
      # simple pow identities
      if x < 0: return self.reciprocal().pow(-x)
      # exponents 0..3: repeated multiplication starting from a tensor of ones (Zero(self) + 1)
      if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1)
      if x == 0.5: return self.sqrt()
    # positive constant base with a tensor exponent: base**self = exp(self * ln(base))
    if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
    # ar: |base| ** exponent computed via exp/log
    ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
    # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
    sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
    # we only need to correct the sign if the base is negative
    # (sign(base) - 1) / -2 is 1 for a negative base and 0 for a positive one
    base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
    # we need 0 to be positive so we need to correct base_sign when the base is 0
    base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
    # inject nan if the base is negative and the power is not an integer
    # to_nan is ~1 when the exponent has a fractional part (scaled by 1e10 and clipped to 1) and the base is negative
    to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else \
      int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign
    # log(1 - 2*to_nan) + 1 gives NaN where to_nan == 1 and exactly 1 where to_nan == 0
    inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
    return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
def maximum(self, x:Union[Tensor, Scalar]) -> Tensor:
return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x))
def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]):
if isinstance(input_, Tensor): input_, other = input_._broadcasted(other)
elif isinstance(other, Tensor): other, input_ = other._broadcasted(input_)
x_,y = self._broadcasted(input_, match_dtype=False)
x,z = x_._broadcasted(other, match_dtype=False)
return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))
# ***** op wrappers (wasted lines to make the typechecker happy) *****
def __neg__(self) -> Tensor: return self.neg()
def __add__(self, x) -> Tensor: return self.add(x)
def __sub__(self, x) -> Tensor: return self.sub(x)
def __mul__(self, x) -> Tensor: return self.mul(x)
def __pow__(self, x) -> Tensor: return self.pow(x)
def __truediv__(self, x) -> Tensor: return self.div(x)
def __matmul__(self, x) -> Tensor: return self.matmul(x)
def __xor__(self, x) -> Tensor: return self.xor(x)
def __radd__(self, x) -> Tensor: return self.add(x, True)
def __rsub__(self, x) -> Tensor: return self.sub(x, True)
def __rmul__(self, x) -> Tensor: return self.mul(x, True)
def __rpow__(self, x) -> Tensor: return self.pow(x, True)
def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
def __rxor__(self, x) -> Tensor: return self.xor(x, True)
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
def __ge__(self, x) -> Tensor: return (self<x).logical_not()
def __le__(self, x) -> Tensor: return (self>x).logical_not()
def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override]
# ***** functional nn ops *****
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
return x.add(bias) if bias is not None else x
def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
y = (self - self.mean(axis, keepdim=True))
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
axis_ = argfix(axis)
shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
x = self - mean.reshape(shape)
if weight is not None: x = x * weight.reshape(shape)
ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
return (ret + bias.reshape(shape)) if bias is not None else ret
def dropout(self, p=0.5) -> Tensor:
if not Tensor.training or p == 0: return self
return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p))
def one_hot(self, num_classes:int) -> Tensor:
return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
# NOTE: it works if key, value have symbolic shape
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
qk = self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1])
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
def binary_crossentropy(self, y:Tensor) -> Tensor:
return (-y*self.log() - (1-y)*(1-self).log()).mean()
def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
# NOTE: self is a logits input
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
smoothing = -1 * label_smoothing * (log_probs.mean(-1) * loss_mask).sum() / loss_mask.sum()
return (1 - label_smoothing) * (log_probs * y).sum() / loss_mask.sum() + smoothing
# ***** cast ops *****
def llvm_bf16_cast(self, dtype:DType):
# hack for devices that don't support bfloat16
assert self.dtype == dtypes.bfloat16
return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
def cast(self, dtype:DType) -> Tensor: return self if self.dtype == dtype else mlops.Cast.apply(self, dtype=dtype)
def bitcast(self, dtype:DType) -> Tensor:
assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
def float(self) -> Tensor: return self.cast(dtypes.float32)
def half(self) -> Tensor: return self.cast(dtypes.float16)
# ***** convenience stuff *****
@property
def ndim(self) -> int: return len(self.shape)
def numel(self) -> sint: return prod(self.shape)
def element_size(self) -> int: return self.dtype.itemsize
def nbytes(self) -> int: return self.numel() * self.element_size()
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
# register functions to move between devices
# every name in Device._devices gets a Tensor method named after it in lowercase, equivalent to Tensor.to(<that device>)
for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
if IMAGE:
  # if IMAGE>0 we install these replacement functions in Tensor (hack!)
  # conv2d and dot are swapped for the image-dtype variants from tinygrad.features.image
  from tinygrad.features.image import image_conv2d, image_dot
  setattr(Tensor, "conv2d", image_conv2d)
  setattr(Tensor, "dot", image_dot)
# TODO: eventually remove this
def custom_random(out:Buffer):
  """
  Fill `out` with uniform random values in [0, 1), generated on the host with numpy and copied in.

  Each call increments the global Tensor._seed, so successive calls draw from different streams.
  float16 buffers get exact multiples of 1/2048 instead of a float32 draw (presumably so that rounding to half
  can never produce 1.0).
  """
  Tensor._seed += 1
  if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
  rng = np.random.default_rng(Tensor._seed)
  if out.dtype == dtypes.half:
    # `high` is exclusive in Generator.integers; 2048 yields k/2048 for k in [0, 2047], the whole grid below 1.0
    # (2047 silently dropped the top value 2047/2048)
    rng_np_buffer = (rng.integers(low=0, high=2048, size=out.size) / 2048).astype(np.half, copy=False)
  else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
  out.copyin(rng_np_buffer.data)