tensor.py touchups (#3667)

* tensor.py touchups

* put back
George Hotz
2024-03-09 16:12:20 -08:00
committed by GitHub
parent 69ca7f7bf9
commit 1143c62519


@@ -1,10 +1,9 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools
+import time, math, itertools, functools
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
 from collections import defaultdict
-from functools import partialmethod, reduce
 import numpy as np
 from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar
@@ -52,14 +51,16 @@ def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str,
           for k in range(len(mat[0]))] for dim in range(dims)]
 
 # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
-def _apply_winograd_matrix(mat, t:Tensor, dims:int):
+def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
   # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
   t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:])  # add output dims
   # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
   matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
   # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
-  return sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
+  ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
+  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
+  return ret
 
 class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
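
The kron(mat, mat) comment above is the whole trick: applying mat along each of the first dims axes of t equals a single Kronecker-product matmul on the flattened leading block. A minimal numpy sketch (not tinygrad code; the 3x2 matrix and shapes are made up for illustration):

import numpy as np

mat = np.array([[1., 0.], [1., 1.], [0., 1.]])  # hypothetical 3x2 transform, not the real winograd matrix
t = np.random.randn(2, 2, 5)                    # dims=2 leading axes, trailing axis stands in for the batch

out = np.einsum("ai,bj,ijk->abk", mat, mat, t)  # mat applied along axis 0, then along axis 1
out_kron = (np.kron(mat, mat) @ t.reshape(4, 5)).reshape(3, 3, 5)
assert np.allclose(out, out_kron)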
@@ -110,8 +111,7 @@ class Tensor:
     else:
       self.lazydata = data if data.device == device else data.copy_to_device(device)
 
-  def __repr__(self):
-    return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
+  def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
 
   # Python has a non moving GC, so this should be okay
   def __hash__(self): return id(self)
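
The id-based __hash__ is needed because Tensor.__eq__ is overloaded to be elementwise, so hashing cannot go through equality; the non-moving GC comment is what makes id() a stable hash value. A stand-in sketch (plain Python, hypothetical class T):

class T:
  def __eq__(self, other): return "elementwise"  # overloaded, does not return a plain bool
  def __hash__(self): return id(self)            # identity hash, like Tensor

a, b = T(), T()
d = {a: 1, b: 2}
assert d[a] == 1 and d[b] == 2  # lookups resolve by identity; __eq__ is never consulted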
@@ -146,8 +146,7 @@ class Tensor:
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
-    if isinstance(self.lazydata, MultiLazyBuffer):
-      assert self.lazydata.axis == x.lazydata.axis
+    assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
     assert not x.requires_grad  # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
     if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
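
The collapsed assertion is the material implication "MultiLazyBuffer implies matching axis": assert not A or B raises in exactly the same cases as the old if A: assert B. A quick truth-table check (plain Python):

for A in (False, True):
  for B in (False, True):
    old_raises = A and not B       # old form: if A: assert B
    new_raises = not (not A or B)  # new form: assert not A or B
    assert old_raises == new_raises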
@@ -495,7 +494,7 @@ class Tensor:
     cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
     slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
     for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
-    return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
+    return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
 
   @staticmethod
   def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor:
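
cat works by zero-padding every operand so it occupies only its own slice of the output, then summing the padded tensors with reduce. A numpy stand-in sketch of the same strategy (hypothetical helper, not the tinygrad API):

import functools
import numpy as np

def cat(arrays, dim=0):
  total = sum(a.shape[dim] for a in arrays)
  padded, offset = [], 0
  for a in arrays:
    pad = [(0, 0)] * a.ndim
    pad[dim] = (offset, total - offset - a.shape[dim])  # zeros before and after, like slc above
    padded.append(np.pad(a, pad))
    offset += a.shape[dim]
  return functools.reduce(np.ndarray.__add__, padded)

assert (cat([np.array([1, 2]), np.array([3, 4, 5])]) == np.array([1, 2, 3, 4, 5])).all()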
@@ -630,7 +629,8 @@ class Tensor:
     rhs_order, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
     # sum over all axes that's not in the output, then permute to the output order
-    return reduce(lambda a,b:a*b, xs_).sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)
+    return functools.reduce(lambda a,b:a*b, xs_) \
+      .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)
 
   # ***** processing ops *****
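
The einsum strategy here is: broadcast-multiply all inputs over the union of their axes, sum out every axis missing from the output, then permute into output order. A numpy sketch for "ik,kj->ij" (hypothetical shapes; the permute happens to be the identity in this example):

import functools
import numpy as np

a, b = np.random.randn(2, 3), np.random.randn(3, 4)
xs_ = [a[:, :, None], b[None, :, :]]                         # both broadcast over axes (i, k, j)
out = functools.reduce(lambda x, y: x * y, xs_).sum(axis=1)  # multiply, then sum out k
assert np.allclose(out, np.einsum("ik,kj->ij", a, b))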
@@ -865,7 +865,7 @@ class Tensor:
     if not isinstance(x, Tensor) and not reverse:
       # simple pow identities
       if x < 0: return self.reciprocal().pow(-x)
-      if x in [3,2,1,0]: return reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1)
+      if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1)
       if x == 0.5: return self.sqrt()
     if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
     ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
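
The integer-power branch builds x**n by repeated multiplication starting from a tensor of ones (mlops.Zero.apply(self)+1 is an all-ones tensor). The same reduce over plain floats (hypothetical helper):

import functools

def int_pow(x, n):
  return functools.reduce(lambda acc, _: acc * x, range(n), 1.0)  # 1 * x * x * ... n times

assert int_pow(2.0, 3) == 8.0 and int_pow(7.0, 0) == 1.0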
@@ -933,7 +933,7 @@ class Tensor:
     x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
     return x.add(bias) if bias is not None else x
 
-  def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)
+  def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
 
   def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
     y = (self - self.mean(axis, keepdim=True))
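
sequential is plain left-to-right function application; the reduce is equivalent to this loop (stand-in sketch):

def sequential(x, ll):
  for f in ll: x = f(x)  # apply each callable to the running value
  return x

assert sequential(3, [lambda v: v + 1, lambda v: v * 2]) == 8  # (3+1)*2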
@@ -1001,7 +1001,7 @@ class Tensor:
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
 
 # register functions to move between devices
-for device in Device._devices: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
+for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
 
 if IMAGE:
   # if IMAGE>0 we install these replacement functions in Tensor (hack!)
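
functools.partialmethod pre-binds the device argument so each registered method forwards to Tensor.to, e.g. t.cpu() becomes t.to("CPU"). A self-contained sketch with a stand-in class (the device names are illustrative):

import functools

class T:
  def to(self, device): return f"moved to {device}"

for device in ["CPU", "CUDA"]:  # stand-in for Device._devices
  setattr(T, device.lower(), functools.partialmethod(T.to, device))

assert T().cpu() == "moved to CPU" and T().cuda() == "moved to CUDA"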