Files
tinygrad/tinygrad/tensor.py
Patrick Tsai 971d7f5d7c O(n) arange attempt (#3530)
* It works?

* Clamp correctly

* Refactor

* Make code better

* Undo some stuff

* First step to trying to make floats work

* Floats work in Python op but not metal because int div is different

Python integerdivision was implemented as // which rounds towards
negative infinity, but C integer division rounds towards 0 so there
is an off-by-1 division error

* arange does cumsum with ints and then multiplies by step

This is so loop optimization can remain int only

* Undo a lot of symbolic changes

* Final check

* Cleanup

* There can be multiple phis

* Fix multiple phi op removal

* const sets dtype correctly

* Fix bugs

* Fix a couple bugs and add loop vars to resolve

* missed one

* Don't trim too many ops

* Fix symbolic test

* Use ones instead of full

* Delete test

* Lint passes

* max node error

* Small updates to loop logic

* Remove unnecessary changes

* We are getting somewhere

* Simple case

* Fix

* rm, prn

* Better

* If NumNode doesn't work then continue

* clamp is needed for arange(256)

* Move everything into the optim fn

* Replace correctly

* Order optimizations better

* Delete

* mypy

* Test for simplification

* Rename

* Fix test

* update test description

* Undo more

* Cleanup

* No replaced_ops map

* Fix lint

* AssertionError

* back again

* Reinstate assertion

* Return true and make diff not as big

* Bigger range for test

* Change cumsum impl

* fix bug

* make big cumsum work

* lint

* Undo cumsum 2-stage removal

* No while helper

* optional min/max clamping

* floats work

* rm giant arange test

* fix python cast None

* Check phi parents

* one phi allowed per where

* Fix one phi per where

* Rework iteration

* Delete assertions

* convert to int

* Try mul -1 instead of neg for hip..?

* Remove one phi per where requirements

* one accum only

* Lint

* should simplify a loop at a time

* Don't get rid of loop explcitly

* Need to iterate backwards

* lint

* unary neg

* Make optim work for onnx and sum_pad_collapse

* Better message

* filter alu ops correctly

* Fix the limiter

* lint and simplify

* Add it back

* off by one error

* test wheres and phis

* test max ops and non-if stuff

* <=

* cast_scalar

* Oops

* Change test

* Pass loop uops instead of a modified map

* Cut param transfer between linearizer and uops

* Fix issues

* Fix lint

* fix efficientnet python 3.8 invalid syntax

* distinct vars in seen_vars

* accurate var names

---------

Co-authored-by: Patrick Tsai <patosai@users.noreply.github.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2024-03-11 16:09:20 -07:00

1020 lines
61 KiB
Python

# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
from collections import defaultdict
import numpy as np
from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar
from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Buffer
from tinygrad.shape.symbolic import sint
from tinygrad.realize import run_schedule, create_schedule
# **** start with two base classes, Tensor and Function ****
class Function:
def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor):
self.device = device
self.needs_input_grad = [t.requires_grad for t in tensors]
self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
if self.requires_grad: self.parents = tensors
def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
@classmethod
def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
ctx = fxn(x[0].device, *x)
ret = Tensor.__new__(Tensor)
ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
return ret
import tinygrad.mlops as mlops
def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Optional[LazyBuffer]=None):
if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
def _fromcpu(x: np.ndarray) -> LazyBuffer:
return LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "EXT",
_buf=Buffer("EXT", 0, dtypes.from_np(x.dtype), (memoryview(bytearray()), None)) if x.size == 0 else \
Buffer("EXT", prod(x.shape), dtypes.from_np(x.dtype), (flat_mv(np.require(x, requirements='C').data), x)))
def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
for k in range(len(mat[0]))] for dim in range(dims)]
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
# multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
# due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
# precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
# multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
return ret
class Tensor:
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
__deletable__ = ('_ctx',)
training: ClassVar[bool] = False
class train(ContextDecorator):
def __init__(self, mode:bool = True): self.mode = mode
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
no_grad: ClassVar[bool] = False
class inference_mode(ContextDecorator):
def __init__(self, mode:bool = True): self.mode = mode
def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
def __init__(self, data:Union[None, Scalar, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer],
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
# tensors have gradients, buffers do not
self.grad: Optional[Tensor] = None
# NOTE: this can be in three states. False and None: no gradient, True: gradient
# None (the default) will be updated to True if it's put in an optimizer
self.requires_grad: Optional[bool] = requires_grad
# internal variables used for autograd graph construction
self._ctx: Optional[Function] = None
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, get_args(Scalar)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
elif isinstance(data, list):
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
elif d and all_int(d): dtype = dtype or dtypes.default_int
else: dtype = dtype or dtypes.default_float
# NOTE: cast at the end for the dtypes that do not have a numpy dtype
data = _fromcpu(np.array(data, dtype.np)).cast(dtype)
elif isinstance(data, np.ndarray):
if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
# data is a LazyBuffer, but it might be on the wrong device
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
if isinstance(device, tuple):
# TODO: what if it's a MultiLazyBuffer on other devices?
self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data
else:
self.lazydata = data if data.device == device else data.copy_to_device(device)
def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
# Python has a non moving GC, so this should be okay
def __hash__(self): return id(self)
def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")
@property
def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device
@property
def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape
@property
def dtype(self) -> DType: return self.lazydata.dtype
# ***** data handlers ****
@staticmethod
def corealize(lst:Iterable[Tensor]):
run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
def realize(self) -> Tensor:
Tensor.corealize([self])
return self
def assign(self, x) -> Tensor:
# TODO: this is a hack for writing to DISK. remove with working assign
if isinstance(self.device, str) and self.device.startswith("DISK"):
if x.__class__ is not Tensor: x = Tensor(x, device="EXT", dtype=self.dtype)
self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
return self
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
# NOTE: we allow cross device assign
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
assert not x.requires_grad # self requires_grad is okay?
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
if isinstance(self.lazydata, MultiLazyBuffer):
for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
else:
if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
self.lazydata = x.lazydata
return self
def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
def _data(self) -> memoryview:
if 0 in self.shape: return memoryview(bytearray(0))
t = self if isinstance(self.device, str) else self.to(self.device[0]) # deal with multitensor
return cast(Buffer, t.cast(t.dtype.scalar()).contiguous().realize().lazydata.base.realized).as_buffer()
def data(self) -> memoryview:
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
return self._data().cast(self.dtype.fmt, self.shape if len(self.shape) else (1,))
def item(self) -> Scalar:
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
assert self.numel() == 1, "must have one element for item"
return self._data().cast(self.dtype.fmt)[0]
def numpy(self) -> np.ndarray:
assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
if device is None or device == self.device: return self
if not isinstance(device, str): return self.shard(device)
ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
if self.grad is not None: ret.grad = self.grad.to(device)
if hasattr(self, '_ctx'): ret._ctx = self._ctx
return ret
def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
real = self.to(device)
# TODO: is this assign?
if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
self.lazydata = real.lazydata
def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
canonical_devices = tuple(Device.canonicalize(x) for x in devices)
if axis is not None and axis < 0: axis += len(self.shape)
return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
self.lazydata = self.shard(devices, axis).lazydata
return self
# ***** creation llop entrypoint *****
@staticmethod
def _loadop(op, shape, device:Optional[Union[Tuple[str], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
if isinstance(device, tuple):
return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
for d in device], None), device, dtype, **kwargs)
return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)
@staticmethod
def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
_seed: int = int(time.time())
@staticmethod
def manual_seed(seed=0): Tensor._seed = seed
@staticmethod
def rand(*shape, **kwargs): return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, **kwargs)
# ***** creation helper functions *****
@staticmethod
def full(shape:Tuple[sint, ...], fill_value:Scalar, **kwargs):
return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
@staticmethod
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs)
@staticmethod
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1.0, **kwargs)
@staticmethod
def arange(start, stop=None, step=1, **kwargs):
if stop is None: stop, start = start, 0
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
@staticmethod
def eye(dim:int, **kwargs):
return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)
def full_like(self, fill_value:Scalar, **kwargs):
return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
# ***** rng hlops *****
@staticmethod
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
src = Tensor.rand((2, *argfix(*shape)), **kwargs)
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
@staticmethod
def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
@staticmethod
def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
dtype = kwargs.pop("dtype", dtypes.default_float)
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor:
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
weight = self.unsqueeze(0) if self.ndim == 1 else self
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int)
# ***** toposort and backward pass *****
def deepwalk(self):
def _deepwalk(node, visited):
visited.add(node)
if getattr(node, "_ctx", None):
for i in node._ctx.parents:
if i not in visited: yield from _deepwalk(i, visited)
yield node
return list(_deepwalk(self, set()))
def backward(self) -> Tensor:
assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"
# fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
# this is "implicit gradient creation"
self.grad = Tensor(1.0, device=self.device, requires_grad=False)
for t0 in reversed(self.deepwalk()):
if t0.grad is None: raise RuntimeError("tensor has no grad")
grads = t0._ctx.backward(t0.grad.lazydata)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
for t, g in zip(t0._ctx.parents, grads):
if g is not None and t.requires_grad:
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
t.grad = g if t.grad is None else (t.grad + g)
del t0._ctx
return self
# ***** movement mlops *****
def reshape(self, shape, *args) -> Tensor:
new_shape = argfix(shape, *args)
new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
return mlops.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
def expand(self, shape, *args) -> Tensor:
new_shape = tuple([x if x != -1 and x is not None else s for s,x in zip(self.shape, argfix(shape, *args))])
return mlops.Expand.apply(self, shape=new_shape) if new_shape != self.shape else self
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
if all(x is None or x == (0,0) for x in arg): return self
ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
# ***** movement hlops *****
# Supported Indexing Implementations:
# 1. Int indexing (no copy)
# - for all dims where there's int, shrink -> reshape
# - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
# - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
# - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
# 2. Slice indexing (no copy)
# - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
# - first shrink the Tensor to X.shrink(((start, end),))
# - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
# - flip where dim value is negative
# - pad 0's on dims such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
# - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
# - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
# 3. None indexing (no copy)
# - reshape (inject) a dim at the dim where there's None
# 4. Tensor indexing (copy)
# - use Tensor.arange == tensor_index to create a mask
# - apply mask to self by mask * self for dims where index is a tensor
# - (mask * self).sum(dim) to reduce to correct shape
# Tiny Things:
# 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
# - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
# - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
# 2. Bool indexing is not supported
# 3. Out of bounds Tensor indexing results in 0
# - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are OOB
def __getitem__(self, indices) -> Tensor:
# 1. indices normalization and validation
# treat internal tuples and lists as Tensors and standardize indices to list type
if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
elif isinstance(indices, (tuple, list)):
indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
else: indices = [indices]
# turn scalar Tensors into const val for int indexing if possible
indices = [self._to_const_val(i) if isinstance(i, Tensor) else i for i in indices]
# move Tensor indices to the same device as self
indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
# filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_indices)
# use Dict[type, List[dimension]] to track elements in indices
type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
# record None for dimension injection later and filter None and record rest of indices
type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
indices_filtered = [v for v in indices if v is not None]
for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
for index_type in type_dim:
if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
# 2. basic indexing, uses only movement ops (no copy)
# currently indices_filtered: Tuple[Union[slice, int, Tensor], ...]
# turn indices in indices_filtered to Tuple[shrink_arg, strides]
for dim in type_dim[int]:
if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
for dim in type_dim[slice]:
if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
s, e, st = index.indices(self.shape[dim])
indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
# record tensors and skip all Tensor dims for basic indexing
tensor_index: List[Tensor] = []
for dim in type_dim[Tensor]:
tensor_index.append(index := indices_filtered[dim])
if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
indices_filtered[dim] = ((0, self.shape[dim]), 1)
new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
ret = self.shrink(new_slice).flip(tuple(i for i, s in enumerate(strides) if s < 0))
if any(abs(s) != 1 for s in strides):
strides = tuple(abs(s) for s in strides)
ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
ret = ret.reshape(tuple(flatten((sh // s, s) for s, sh in zip(strides, ret.shape))))
ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
# inject 1 for dim where it's None and collapse dim for int
new_shape = list(ret.shape)
for dim in type_dim[None]: new_shape.insert(dim, 1)
for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
ret = ret.reshape(new_shape)
assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
# 3. advanced indexing (copy)
if type_dim[Tensor]:
# calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
def calc_dim(tensor_dim:int) -> int:
return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d)
# track tensor_dim and tensor_index using a dict
# calc_dim to get dim and use that to normalize the negative tensor indices
idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor],tensor_index)}
# compute sum_dim, arange, and idx
max_idx_dim, first_dim, last_dim = max(i.ndim for i in idx.values()), min(idx.keys()), max(idx.keys())
sum_dim = tuple(d if n==0 else d+max_idx_dim-n for n,d in enumerate(idx.keys()))
arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(ret.shape[d], *[1]*(ret.ndim+max_idx_dim-n-sd-1)) \
for n,(sd,d) in enumerate(zip(sum_dim, idx.keys()))]
reshaped_idx = [i.reshape(i.shape + (1,)*(ret.ndim - first_dim - (n or 1))) for n,i in enumerate(idx.values())]
ret = ret.reshape(ret.shape[:first_dim+1] + (1,)*max_idx_dim + ret.shape[first_dim+1:])
# iteratively eq -> mul -> sum fancy index
try:
for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
except AssertionError as exc: raise IndexError("cannot broadcast indices") from exc
# special permute case
if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
ret_dims = list(range(ret.ndim))
ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
return ret
def __setitem__(self,indices,v): return self.__getitem__(indices).assign(v)
# NOTE: using slice is discouraged and things should migrate to pad and shrink
def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
def gather(self:Tensor, idx:Tensor, dim:int) -> Tensor:
assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
dim = self._resolve_dim(dim)
idx = idx.to(self.device).transpose(ax1=dim, ax2=0).unsqueeze(-1)
permarg = list(range(self.ndim))
permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(
tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
dim = self._resolve_dim(dim)
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
catargs = [self, *args]
cat_dims = [s.shape[dim] for s in catargs]
cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
@staticmethod
def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor:
unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors]
# checks for shapes and number of dimensions delegated to cat
return unsqueezed_tensors[0].cat(*unsqueezed_tensors[1:], dim=dim)
def repeat(self, repeats:Sequence[int]) -> Tensor:
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
new_shape = [x for b in base_shape for x in [1, b]]
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
final_shape = [r*s for r,s in zip(repeats, base_shape)]
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
return dim + self.ndim+outer if dim < 0 else dim
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
dim = self._resolve_dim(dim)
if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
def chunk(self, num:int, dim:int=0) -> List[Tensor]:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
assert num > 0, f"expect num to be greater than 0, got: {num}"
dim = self._resolve_dim(dim)
return list(self.split(math.ceil(self.shape[dim]/num) if self.shape[dim] else [0]*num, dim=dim))
def squeeze(self, dim:Optional[int]=None) -> Tensor:
if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
dim = self._resolve_dim(dim)
return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
def unsqueeze(self, dim:int) -> Tensor:
dim = self._resolve_dim(dim, outer=True)
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
# (padding_left, padding_right, padding_top, padding_bottom)
def pad2d(self, padding:Sequence[int], value:float=0) -> Tensor:
slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
@property
def T(self) -> Tensor: return self.transpose()
def transpose(self, ax1=1, ax2=0) -> Tensor:
order = list(range(self.ndim))
order[ax1], order[ax2] = order[ax2], order[ax1]
return self.permute(order)
def flatten(self, start_dim=0, end_dim=-1):
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
def unflatten(self, dim:int, sizes:Tuple[int,...]):
dim = self._resolve_dim(dim)
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
# ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_)
shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
ret = fxn.apply(self, axis=axis_)
return ret if keepdim else ret.reshape(shape=shape)
def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
if acc_dtype is None: acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
least_upper_dtype(self.dtype, dtypes.float)
# cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype)
def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
def mean(self, axis=None, keepdim=False):
assert all_int(self.shape), "does not support symbolic shape"
out = self.sum(axis=axis, keepdim=keepdim)
return out.div(prod(self.shape) / prod(out.shape)) if 0 not in out.shape else out
def var(self, axis=None, keepdim=False, correction=1):
assert all_int(self.shape), "does not support symbolic shape"
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
def std(self, axis=None, keepdim=False, correction=1): return self.var(axis, keepdim, correction).sqrt()
def _softmax(self, axis):
if len(self.shape) == 0:
assert axis in [-1, 0], f"{axis=} out of range of [-1, 0]"
axis = None
m = self - self.max(axis=axis, keepdim=True)
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)
def softmax(self, axis=-1):
_, e, ss = self._softmax(axis)
return e.div(ss)
def log_softmax(self, axis=-1):
m, _, ss = self._softmax(axis)
return m - ss.log()
def argmax(self, axis=None, keepdim=False):
if axis is None:
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
return (prod(self.shape) - idx.max() - 1).cast(dtypes.default_int)
axis = axis + len(self.shape) if axis < 0 else axis
m = self == self.max(axis=axis, keepdim=True)
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.default_int)
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
@staticmethod
def einsum(formula:str, *raw_xs) -> Tensor:
xs:Tuple[Tensor] = argfix(*raw_xs)
formula = formula.replace(" ", "")
inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
inputs = [x for x in cast(str,inputs_str).split(',')]
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
# map the value of each letter in the formula
letter_val = sorted(merge_dicts([{letter:dim for letter, dim in zip(letters, tensor.shape)} for letters, tensor in zip(inputs, xs)]).items())
xs_:List[Tensor] = []
lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
rhs_order, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
# sum over all axes that's not in the output, then permute to the output order
return functools.reduce(lambda a,b:a*b, xs_) \
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)
# ***** processing ops *****
def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
# repeats such that we don't need padding
xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
# slice by dilation
xup = xup.slice(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)]).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
# handle stride
xup = xup.slice(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
xup = xup.slice(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
# permute to move reduce to the end
return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
xup = self.slice(noop_ + [(0,o*s) for o,s in zip(o_, s_)])
xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
xup = xup.slice(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
# NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(
make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(
make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing)
stride = make_pair(stride, len(HW))
if any(s>1 for s in stride):
x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(
zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) # noqa: E501
# conv2d is a pooling op (with padding)
x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
# normal conv
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time
# todo: stride == dilation
# use padding to round up to 4x4 output tiles
# (bs, cin_, tyx, HWI)
d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
# move HW to the front: # (HWI, bs, cin_, tyx)
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
tyx = d.shape[-len(HWI):] # dim of tiling
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
# compute 6x6 winograd tiles: GgGt, BtdB
# (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
# (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
# interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
# merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx]))
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
n1, n2 = len(self.shape), len(w.shape)
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0)
return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1)
def cumsum(self, axis:int=0) -> Tensor:
# TODO: someday the optimizer will find this on it's own
# for now this is a two stage cumsum
SPLIT = 256
if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
return fix(ret) + fix(base_add)
@staticmethod
def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
assert all_int((r,c)), "does not support symbolic"
if r == 0: return Tensor.zeros((r, c), **kwargs)
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
# ***** mlops (unary) *****
def logical_not(self): return mlops.Eq.apply(*self._broadcasted(False))
def neg(self): return mlops.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
def contiguous(self): return mlops.Contiguous.apply(self)
def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
def log2(self): return self.log()/math.log(2)
def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype)))
def exp2(self): return mlops.Exp.apply(self*math.log(2))
def relu(self): return mlops.Relu.apply(self)
def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype)))
def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
def rsqrt(self): return self.reciprocal().sqrt()
def cos(self): return ((math.pi/2)-self).sin()
def tan(self): return self.sin() / self.cos()
# ***** math functions (unary) *****
def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).cast(self.dtype)
def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b)
def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b)
def round(self: Tensor) -> Tensor:
return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
def square(self): return self*self
def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
def abs(self): return self.relu() + (-self).relu()
def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
def reciprocal(self): return mlops.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
# ***** activation functions (unary) *****
def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
def swish(self): return self * self.sigmoid()
def silu(self): return self.swish() # The SiLU function is also known as the swish function.
def relu6(self): return self.relu() - (self-6).relu()
def hardswish(self): return self * (self+3).relu6() * (1/6)
def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
def sinh(self): return (self.exp() - self.neg().exp()) / 2
def cosh(self): return (self.exp() + self.neg().exp()) / 2
def atanh(self): return ((1 + self)/(1 - self)).log() / 2
def asinh(self): return (self + (self.square() + 1).sqrt()).log()
def acosh(self): return (self + (self.square() - 1).sqrt()).log()
def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
def quick_gelu(self): return self * (self * 1.702).sigmoid()
def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
def mish(self): return self * self.softplus().tanh()
def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
def softsign(self): return self / (1 + self.abs())
# ***** broadcasted elementwise mlops *****
def _broadcasted(self, y:Union[Tensor, Scalar], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
x: Tensor = self
if not isinstance(y, Tensor):
# make y a Tensor
assert isinstance(y, (float, int, bool)), f"{type(y)=}, {y=}"
if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
else: y_dtype = dtypes.from_py(y)
y = Tensor(cast_scalar(y, y_dtype), self.device, y_dtype, requires_grad=False)
if match_dtype:
output_dtype = least_upper_dtype(x.dtype, y.dtype)
x, y = x.cast(output_dtype), y.cast(output_dtype)
if reverse: x, y = y, x
# left pad shape with 1s
if len(y.shape) < len(x.shape): y = y.reshape((1,) * (len(x.shape) - len(y.shape)) + y.shape)
elif len(x.shape) < len(y.shape): x = x.reshape((1,) * (len(y.shape) - len(x.shape)) + x.shape)
broadcasted_shape = tuple(0 if xi==0 or yi==0 else max(xi, yi) for xi, yi in zip(x.shape, y.shape))
return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
def _to_const_val(self, x:Union[Tensor, Scalar]) -> Union[Tensor, Scalar]:
# TODO: update with multi
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
def add(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
return mlops.Add.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else self
def sub(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
return mlops.Sub.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else (-self if reverse else self)
def mul(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
if not isinstance(x, Tensor) and x == 0.0: return mlops.Zero.apply(self)
if not isinstance(x, Tensor) and x == -1.0: return -self
return mlops.Mul.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x != 1.0 else self
def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
if not isinstance(x, Tensor) and not reverse and x != 0: return self.mul(1/x)
if isinstance(x, Tensor) and dtypes.is_float(x.dtype): return mlops.Div.apply(*self._broadcasted(x, reverse))
return mlops.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse))
def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
if not isinstance(x, Tensor) and not reverse:
# simple pow identities
if x < 0: return self.reciprocal().pow(-x)
if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1)
if x == 0.5: return self.sqrt()
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
# correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
# we only need to correct the sign if the base is negative
base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
# we need 0 to be positive so we need to correct base_sign when the base is 0
base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
# inject nan if the base is negative and the power is not an integer
to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else \
int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign
inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
def maximum(self, x:Union[Tensor, Scalar]) -> Tensor:
return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x))
def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]):
if isinstance(input_, Tensor): input_, other = input_._broadcasted(other)
elif isinstance(other, Tensor): other, input_ = other._broadcasted(input_)
x_,y = self._broadcasted(input_, match_dtype=False)
x,z = x_._broadcasted(other, match_dtype=False)
return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))
# ***** op wrappers (wasted lines to make the typechecker happy) *****
def __neg__(self) -> Tensor: return self.neg()
def __add__(self, x) -> Tensor: return self.add(x)
def __sub__(self, x) -> Tensor: return self.sub(x)
def __mul__(self, x) -> Tensor: return self.mul(x)
def __pow__(self, x) -> Tensor: return self.pow(x)
def __truediv__(self, x) -> Tensor: return self.div(x)
def __matmul__(self, x) -> Tensor: return self.matmul(x)
def __xor__(self, x) -> Tensor: return self.xor(x)
def __radd__(self, x) -> Tensor: return self.add(x, True)
def __rsub__(self, x) -> Tensor: return self.sub(x, True)
def __rmul__(self, x) -> Tensor: return self.mul(x, True)
def __rpow__(self, x) -> Tensor: return self.pow(x, True)
def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
def __rxor__(self, x) -> Tensor: return self.xor(x, True)
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
def __ge__(self, x) -> Tensor: return (self<x).logical_not()
def __le__(self, x) -> Tensor: return (self>x).logical_not()
def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override]
# ***** functional nn ops *****
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
return x.add(bias) if bias is not None else x
def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
y = (self - self.mean(axis, keepdim=True))
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
axis_ = argfix(axis)
shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
x = self - mean.reshape(shape)
if weight is not None: x = x * weight.reshape(shape)
ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
return (ret + bias.reshape(shape)) if bias is not None else ret
def dropout(self, p=0.5) -> Tensor:
if not Tensor.training or p == 0: return self
return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p))
def one_hot(self, num_classes:int) -> Tensor:
return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
# NOTE: it works if key, value have symbolic shape
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
qk = self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1])
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
def binary_crossentropy(self, y:Tensor) -> Tensor:
return (-y*self.log() - (1-y)*(1-self).log()).mean()
def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
# NOTE: self is a logits input
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
smoothing = -1 * label_smoothing * (log_probs.mean(-1) * loss_mask).sum() / loss_mask.sum()
return (1 - label_smoothing) * (log_probs * y).sum() / loss_mask.sum() + smoothing
# ***** cast ops *****
def cast(self, dtype:DType) -> Tensor:
if self.dtype == dtype: return self
# hack for devices that don't support bfloat16
if self.dtype == dtypes.bfloat16: return self.bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
return mlops.Cast.apply(self, dtype=dtype)
def bitcast(self, dtype:DType) -> Tensor:
assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
def float(self) -> Tensor: return self.cast(dtypes.float32)
def half(self) -> Tensor: return self.cast(dtypes.float16)
# ***** convenience stuff *****
@property
def ndim(self) -> int: return len(self.shape)
def numel(self) -> sint: return prod(self.shape)
def element_size(self) -> int: return self.dtype.itemsize
def nbytes(self) -> int: return self.numel() * self.element_size()
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
# register functions to move between devices
for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
if IMAGE:
# if IMAGE>0 we install these replacement functions in Tensor (hack!)
from tinygrad.features.image import image_conv2d, image_dot
setattr(Tensor, "conv2d", image_conv2d)
setattr(Tensor, "dot", image_dot)
# TODO: remove the custom op and replace with threefry
def custom_random(out:Buffer):
Tensor._seed += 1
if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
rng = np.random.default_rng(Tensor._seed)
if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
out.copyin(rng_np_buffer.data)