Readability for unreadable functions (#1610)

* cleaned

* typing

* typing

* if format

* if format

* mypy

* update argmax

* argmax more readable

* More stable def pad

* lint
This commit is contained in:
Umut Zengin
2023-08-22 17:09:08 +03:00
committed by GitHub
parent 86a32ffb1a
commit 1e93fd5449

View File

@@ -5,7 +5,7 @@ from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf
from math import ceil, pi, prod, sqrt, log, cos, copysign
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from tinygrad.lazy import Device, LazyBuffer
@@ -167,8 +167,7 @@ class Tensor:
def ones_like(tensor, **kwargs): return Tensor.full_like(tensor, 1, **kwargs)
@staticmethod
def eye(dim:int, **kwargs):
return Tensor([1], **kwargs).pad(((0,dim),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
# ***** rng hlops *****
@@ -243,8 +242,7 @@ class Tensor:
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
if isinf(value): return ret + copysign(1,value)/mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg)
return ret if 0 == value else ret + (value - mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg))
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
# ***** movement hlops *****
@@ -378,19 +376,15 @@ class Tensor:
return first.cat(*unsqueezed_tensors, dim=dim)
def repeat(self, repeats):
base_shape = self.shape
if len(repeats) > self.ndim:
base_shape = (1,) * (len(repeats) - self.ndim) + base_shape
new_shape = [x for i in range(len(base_shape)) for x in [1, base_shape[i]]]
expand_shape = [x for r,s in zip(repeats, base_shape) for x in [r,s]]
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
new_shape = [x for b in base_shape for x in [1, b]]
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
final_shape = [r*s for r,s in zip(repeats, base_shape)]
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
# TODO: make this nicer with syntactic sugar in slice
def chunk(self, num, dim):
slice_params = [[slice(None) for s in self.shape] for _ in range(ceil(self.shape[dim]/ceil(self.shape[dim]/num)))]
for i, k in enumerate(range(0, self.shape[dim], ceil(self.shape[dim]/num))):
slice_params[i][dim] = slice(k, k + ceil(self.shape[dim]/num))
def chunk(self, num:int, dim:int) -> List[Tensor]:
dim, step = dim + self.ndim if dim < 0 else dim, ceil(self.shape[dim]/num)
slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
return [self[tuple(sl)] for sl in slice_params]
def squeeze(self, dim=None):
@@ -450,10 +444,12 @@ class Tensor:
return m - ss.log()
def argmax(self, axis=None, keepdim=False):
if axis is None: return prod(self.shape) - ((self == self.max(axis)).flatten() * Tensor.arange(prod(self.shape)-1,-1,-1)).max() - 1
axis = axis + self.ndim if axis < 0 else axis
m = self == (self.max(axis=axis, keepdim=keepdim) if keepdim else self.max(axis=axis, keepdim=keepdim).unsqueeze(axis))
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(*[1]*axis, self.shape[axis], *[1]*(self.ndim-(axis+1)))
if axis is None:
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1).reshape(self.shape)
return prod(self.shape) - idx.max() - 1
axis = axis + len(self.shape) if axis < 0 else axis
m = self == self.max(axis=axis, keepdim=True)
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
@@ -523,10 +519,7 @@ class Tensor:
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
return (x*w).sum(-1)
def cumsum(self, axis=0):
axis = (axis + self.ndim) if axis < 0 else axis
x = self.permute(*(i for i in range(self.ndim) if i != axis), axis)
return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis], dtype=self.dtype, device=self.device), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1))
def cumsum(self, axis:int=0) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-1,0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
# ***** mlops (unary) *****