mirror of https://github.com/tinygrad/tinygrad.git
@@ -1,10 +1,9 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools
+import time, math, itertools, functools
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
 from collections import defaultdict
-from functools import partialmethod, reduce
 import numpy as np
 
 from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar
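
Reviewer note on this hunk: replacing "from functools import partialmethod, reduce" with a plain "import functools" is behavior-identical; the module-qualified calls below just make the origin of reduce/partialmethod obvious at each call site. A minimal standalone sketch of the equivalence (plain Python, not tinygrad code):

    import functools
    from functools import reduce

    nums = [1, 2, 3, 4]
    # the bare name and the module-qualified name are the same function object
    assert reduce is functools.reduce
    assert functools.reduce(lambda a, b: a + b, nums) == 10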
@@ -52,14 +51,16 @@ def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str,
           for k in range(len(mat[0]))] for dim in range(dims)]
 
 # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
-def _apply_winograd_matrix(mat, t:Tensor, dims:int):
+def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
   # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
   t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:])  # add output dims
   # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
   matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
   # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
-  return sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
+  ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
+  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
+  return ret
 
 class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
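
Reviewer note: the "roughly kron(mat, mat) @ t" comment refers to the standard identity that applying a matrix along each axis of a tile equals one big Kronecker-product matrix acting on the flattened tile. A small numpy sketch of that identity (B is an arbitrary stand-in, not the actual winograd constants):

    import numpy as np

    B = np.arange(9.0).reshape(3, 3)        # stand-in transform matrix
    t = np.arange(9.0, 18.0).reshape(3, 3)  # 2D tile, i.e. dims=2

    per_axis = B @ t @ B.T                  # apply B along each of the two axes
    via_kron = (np.kron(B, B) @ t.reshape(-1)).reshape(3, 3)
    assert np.allclose(per_axis, via_kron)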
@@ -110,8 +111,7 @@ class Tensor:
     else:
       self.lazydata = data if data.device == device else data.copy_to_device(device)
 
-  def __repr__(self):
-    return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
+  def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
 
   # Python has a non moving GC, so this should be okay
   def __hash__(self): return id(self)
@@ -146,8 +146,7 @@ class Tensor:
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
-    if isinstance(self.lazydata, MultiLazyBuffer):
-      assert self.lazydata.axis == x.lazydata.axis
+    assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
     assert not x.requires_grad  # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
     if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
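
Reviewer note: the one-line assert is the implication form of the old two-line guard: assert not A or B fails exactly when A holds and B does not, the same as if A: assert B, and it now carries a message. A quick truth-table check (plain Python):

    for a in (False, True):
        for b in (False, True):
            fires_old = a and not b       # old form: if a: assert b
            fires_new = not (not a or b)  # new form: assert not a or b
            assert fires_old == fires_new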
@@ -495,7 +494,7 @@ class Tensor:
     cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
     slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
     for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
-    return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
+    return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
 
   @staticmethod
   def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor:
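
Reviewer note: cat works by zero-padding every argument out to the full output shape at its own offset along dim, then summing, which is why the last step is a reduce over Tensor.__add__. A 1D numpy sketch of the same idea (arrays and offsets are illustrative):

    import functools
    import numpy as np

    a, b = np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0])
    total = len(a) + len(b)
    # pad each array with zeros so it occupies its own slice of the output
    padded = [np.pad(a, (0, total - len(a))), np.pad(b, (len(a), 0))]
    out = functools.reduce(np.add, padded)
    assert np.array_equal(out, np.array([1.0, 2.0, 3.0, 4.0, 5.0]))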
@@ -630,7 +629,8 @@ class Tensor:
 
     rhs_order, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
     # sum over all axes that's not in the output, then permute to the output order
-    return reduce(lambda a,b:a*b, xs_).sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)
+    return functools.reduce(lambda a,b:a*b, xs_) \
+      .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)
 
   # ***** processing ops *****
 
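
Reviewer note: the einsum body is multiply-then-sum: broadcast-multiply all operands, sum out every axis not named in the output, then permute into output order. A numpy sketch of that core for ordinary matmul (shapes chosen for illustration):

    import functools
    import numpy as np

    x = np.arange(6.0).reshape(2, 3, 1)    # operand "ij", contracted axis in the middle
    y = np.arange(12.0).reshape(1, 3, 4)   # operand "jk", broadcast against x
    prod = functools.reduce(lambda a, b: a * b, [x, y])  # elementwise product, shape (2, 3, 4)
    out = prod.sum(axis=1)                 # sum out the contracted axis j
    assert np.allclose(out, np.einsum("ij,jk->ik", x[:, :, 0], y[0]))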
@@ -865,7 +865,7 @@ class Tensor:
     if not isinstance(x, Tensor) and not reverse:
       # simple pow identities
       if x < 0: return self.reciprocal().pow(-x)
-      if x in [3,2,1,0]: return reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1)
+      if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1)
       if x == 0.5: return self.sqrt()
     if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
     ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
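
Reviewer note: for small non-negative integer exponents, pow unrolls into repeated multiplies, with mlops.Zero.apply(self)+1 (a tensor of ones) as the reduce seed. A scalar sketch of the same fold (int_pow is a hypothetical helper, not tinygrad API):

    import functools

    def int_pow(x: float, n: int) -> float:
        # seed with 1.0 and multiply by x once per step, like the x in [3,2,1,0] branch
        return functools.reduce(lambda acc, _: acc * x, range(n), 1.0)

    assert int_pow(5.0, 3) == 125.0 and int_pow(5.0, 0) == 1.0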
@@ -933,7 +933,7 @@ class Tensor:
     x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
     return x.add(bias) if bias is not None else x
 
-  def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)
+  def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
 
   def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
     y = (self - self.mean(axis, keepdim=True))
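
Reviewer note: sequential is a left fold over a list of callables, threading the running value through each layer in order. A plain-Python sketch of the same fold (this standalone sequential is a stand-in, not the method itself):

    import functools

    def sequential(x, layers):
        # feed the accumulated value to each callable in turn
        return functools.reduce(lambda acc, f: f(acc), layers, x)

    assert sequential(3, [lambda v: v + 1, lambda v: v * 2]) == 8  # (3+1)*2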
@@ -1001,7 +1001,7 @@ class Tensor:
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
 
 # register functions to move between devices
-for device in Device._devices: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
+for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
 
 if IMAGE:
   # if IMAGE>0 we install these replacement functions in Tensor (hack!)
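
Reviewer note: this loop works because functools.partialmethod is a descriptor, so after setattr, t.cpu() resolves to t.to("CPU") and so on for each registered device. A minimal sketch with a made-up class (Thing and its to method are illustrative, not tinygrad code):

    import functools

    class Thing:
        def to(self, device: str) -> str:
            return f"moved to {device}"

    # attach a .cpu() shortcut the same way the loop above attaches one per device
    setattr(Thing, "cpu", functools.partialmethod(Thing.to, "CPU"))
    assert Thing().cpu() == "moved to CPU"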