mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Readability for unreadable functions (#1610)
* cleaned * typing * typing * if format * if format * mypy * update argmax * argmax more readable * More stable def pad * lint
This commit is contained in:
@@ -5,7 +5,7 @@ from functools import partialmethod, reduce
|
||||
from itertools import accumulate
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
|
||||
from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf
|
||||
from math import ceil, pi, prod, sqrt, log, cos, copysign
|
||||
|
||||
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
|
||||
from tinygrad.lazy import Device, LazyBuffer
|
||||
@@ -167,8 +167,7 @@ class Tensor:
|
||||
def ones_like(tensor, **kwargs):
  """Delegate to `Tensor.full_like` with a fill value of 1.

  `tensor` and every keyword argument are forwarded unchanged; whatever
  `Tensor.full_like` returns is returned as-is.
  """
  return Tensor.full_like(tensor, 1, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def eye(dim:int, **kwargs):
  """Build a (dim, dim) tensor from a single 1 using only pad/reshape/expand/shrink.

  `kwargs` are forwarded to the `Tensor` constructor.
  """
  side = dim + 1
  # [1] zero-padded on the right to length dim+1, viewed as one row.
  row = Tensor([1], **kwargs).pad(((0, dim),)).reshape(1, side)
  # Repeating that row dim times and flattening puts a 1 at every multiple of dim+1,
  # i.e. at flat positions k*dim + k (assuming row-major reshape) - the diagonal of a (dim, dim) grid.
  flat = row.expand(dim, side).reshape(dim * side)
  # Keep the first dim*dim entries and fold them into the square result.
  return flat.shrink(((0, dim * dim),)).reshape(dim, dim)
|
||||
def eye(dim:int, **kwargs):
  """Build a (dim, dim) tensor from a column of ones using only pad/reshape/shrink.

  `kwargs` are forwarded to `Tensor.full`.
  """
  column = Tensor.full((dim, 1), 1, **kwargs)
  # Each row becomes [1, 0, ..., 0] of length dim+1.
  padded = column.pad(((0, 0), (0, dim)))
  # Flattened, the ones sit at every multiple of dim+1, i.e. at flat positions k*dim + k
  # (assuming row-major reshape) - the diagonal of a (dim, dim) grid.
  flat = padded.reshape(dim * (dim + 1))
  # Keep the first dim*dim entries and fold them into the square result.
  return flat.shrink(((0, dim * dim),)).reshape(dim, dim)
|
||||
|
||||
# ***** rng hlops *****
|
||||
|
||||
@@ -243,8 +242,7 @@ class Tensor:
|
||||
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor:
  """Apply `mlops.Shrink` with `arg`, or return `self` when the request is a no-op.

  The request is a no-op when every pair in `arg` equals `(0, size)` for the
  matching entry of `self.shape` (pairs and sizes are matched up with `zip`).
  """
  if all(bounds == (0, size) for bounds, size in zip(arg, self.shape)): return self
  return mlops.Shrink.apply(self, arg=arg)
|
||||
# pad: pads `self` according to `arg` and makes the padded region hold `value` (default 0).
# NOTE(review): this span reads like a rendered diff whose +/- markers were lost - the `isinf`
# line and the first `return` look like the previous body, the last `return` like its
# replacement. Confirm against the commit before treating all of them as live code.
def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
|
||||
# mlops.Pad is skipped entirely when every pair in `arg` is (0, 0).
ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
|
||||
# Infinite `value` gets its own formula: a signed 1 divided by a padded constant-`value` tensor.
# Presumably 1/inf gives 0 inside the original region and 1/0 gives +/-inf in the padding - TODO confirm
# how the backend handles the division by zero.
if isinf(value): return ret + copysign(1,value)/mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg)
|
||||
# `value - Pad(full(value))`: assuming Pad fills with 0, this is 0 inside the original region and
# `value` in the padding, so adding it to `ret` only changes the padding.
return ret if 0 == value else ret + (value - mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg))
|
||||
# Pads a ones tensor to get a mask, then `.where(0, value)` turns the mask into the additive correction -
# presumably 0 where the mask is set and `value` elsewhere; TODO confirm `where` semantics.
# No arithmetic is done on `value` here, which would explain why this form has no `isinf` branch.
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
|
||||
|
||||
# ***** movement hlops *****
|
||||
|
||||
@@ -378,19 +376,15 @@ class Tensor:
|
||||
return first.cat(*unsqueezed_tensors, dim=dim)
|
||||
|
||||
# repeat: produces a tensor whose size along dimension i is `repeats[i]` times the (left-padded) input size,
# via reshape -> expand -> reshape.
# NOTE(review): this span reads like a rendered diff whose +/- markers were lost - the first five
# statements look like the previous body and the next three like their replacement, with the
# `final_shape` and `return` lines shared. Confirm against the commit.
def repeat(self, repeats):
|
||||
base_shape = self.shape
|
||||
# More repeat counts than dimensions: prepend size-1 dimensions so the two line up.
if len(repeats) > self.ndim:
|
||||
base_shape = (1,) * (len(repeats) - self.ndim) + base_shape
|
||||
# Put a size-1 dimension in front of every existing dimension: (a, b) -> (1, a, 1, b).
new_shape = [x for i in range(len(base_shape)) for x in [1, base_shape[i]]]
|
||||
# Target for expand: each inserted 1 grows to its repeat count: (r0, a, r1, b).
expand_shape = [x for r,s in zip(repeats, base_shape) for x in [r,s]]
|
||||
# Single-statement form of the left-padding: `(1,) * n` is empty for n <= 0, so no `if` is needed.
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
|
||||
# Same interleaving as above, iterating the shape directly instead of by index.
new_shape = [x for b in base_shape for x in [1, b]]
|
||||
# Same (r, s) interleaving as above, flattening each zipped pair.
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
|
||||
# Merge each (repeat, size) pair into one dimension of size repeat*size.
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
||||
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
|
||||
|
||||
# TODO: make this nicer with syntactic sugar in slice
|
||||
# chunk: cuts the tensor along `dim` into consecutive pieces of ceil(shape[dim] / num) elements and
# returns them as a list. Fewer than `num` pieces come back when that step covers the axis early, and
# the final slice may run past the end of the axis (presumably clamped by Tensor indexing - TODO confirm).
# NOTE(review): this span reads like a rendered diff whose +/- markers were lost - the first `def`
# with its loop looks like the previous body, the second `def` like its replacement, and the
# trailing `return` is shared by both. Confirm against the commit.
def chunk(self, num, dim):
|
||||
# One all-`slice(None)` index list per output piece; the piece count is ceil(size / step).
slice_params = [[slice(None) for s in self.shape] for _ in range(ceil(self.shape[dim]/ceil(self.shape[dim]/num)))]
|
||||
# Overwrite the `dim` entry of each index list with that piece's range.
for i, k in enumerate(range(0, self.shape[dim], ceil(self.shape[dim]/num))):
|
||||
slice_params[i][dim] = slice(k, k + ceil(self.shape[dim]/num))
|
||||
def chunk(self, num:int, dim:int) -> List[Tensor]:
|
||||
# The right-hand side is evaluated before assignment, so `step` uses the caller's `dim`;
# only the list multiplication below needs `dim` to be non-negative.
dim, step = dim + self.ndim if dim < 0 else dim, ceil(self.shape[dim]/num)
|
||||
# Only the leading `dim` axes get an explicit slice(None); trailing axes are left out of the index.
slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
|
||||
return [self[tuple(sl)] for sl in slice_params]
|
||||
|
||||
def squeeze(self, dim=None):
|
||||
@@ -450,10 +444,12 @@ class Tensor:
|
||||
return m - ss.log()
|
||||
|
||||
# argmax: index of the maximum, over the whole tensor when `axis` is None, otherwise along `axis`.
# Technique used throughout: multiply the "equals the maximum" mask by a descending counter
# (assuming Tensor.arange(n-1, -1, -1) counts n-1 .. 0 like range). The largest product belongs to the
# earliest maximal position p and equals n-1-p, so `n - max - 1` recovers p.
# NOTE(review): this span reads like a rendered diff whose +/- markers were lost - the first four
# statements look like the previous body and the following six like their replacement, with the
# final `return` shared. Confirm against the commit.
def argmax(self, axis=None, keepdim=False):
|
||||
# Whole-tensor case in one expression: flatten the mask, then apply the counter trick.
if axis is None: return prod(self.shape) - ((self == self.max(axis)).flatten() * Tensor.arange(prod(self.shape)-1,-1,-1)).max() - 1
|
||||
axis = axis + self.ndim if axis < 0 else axis
|
||||
# Without keepdim the reduced maximum is unsqueezed so it has a size-1 `axis` to compare against `self`.
m = self == (self.max(axis=axis, keepdim=keepdim) if keepdim else self.max(axis=axis, keepdim=keepdim).unsqueeze(axis))
|
||||
# Counter reshaped to (1, ..., n, ..., 1) so it lines up with `axis`.
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(*[1]*axis, self.shape[axis], *[1]*(self.ndim-(axis+1)))
|
||||
if axis is None:
|
||||
# Whole-tensor case: the counter is reshaped to the tensor's shape instead of flattening the mask.
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1).reshape(self.shape)
|
||||
return prod(self.shape) - idx.max() - 1
|
||||
axis = axis + len(self.shape) if axis < 0 else axis
|
||||
# Always reduce with keepdim=True here so the comparison needs no unsqueeze.
m = self == self.max(axis=axis, keepdim=True)
|
||||
# Counter reshaped to (n, 1, ..., 1); the leading dimensions are presumably supplied by broadcasting.
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
|
||||
# The caller's `keepdim` is only applied in this final reduction.
return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
|
||||
def argmin(self, axis=None, keepdim=False):
  """Return the argmax of the negated tensor, passing `axis` and `keepdim` through by keyword."""
  negated = -self
  return negated.argmax(axis=axis, keepdim=keepdim)
|
||||
|
||||
@@ -523,10 +519,7 @@ class Tensor:
|
||||
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
|
||||
return (x*w).sum(-1)
|
||||
|
||||
def cumsum(self, axis=0):
  """Cumulative sum along `axis`, computed as a convolution with a kernel of ones.

  A negative `axis` counts from the last dimension.
  """
  if axis < 0: axis = axis + self.ndim
  last = self.ndim - 1
  # Move `axis` to the end; every other dimension keeps its relative order.
  order = [i for i in range(self.ndim) if i != axis] + [axis]
  x = self.permute(*order)
  length = self.shape[axis]
  kernel = Tensor.ones(1, 1, 1, length, dtype=self.dtype, device=self.device)
  # The padding tuple is non-zero only in its first slot (length-1) - presumably the left edge of the last
  # dimension, so output position j sums inputs 0..j; TODO confirm conv2d's padding order.
  summed = x.reshape(1, 1, -1, length).conv2d(kernel, padding=(length - 1, 0, 0, 0))
  # Restore the permuted shape, then move the last dimension back to position `axis`.
  return summed.reshape(*x.shape).permute(*range(axis), last, *range(axis, last))
|
||||
def cumsum(self, axis:int=0) -> Tensor:
  """Cumulative sum along `axis`, computed by summing sliding windows.

  `axis` is swapped with the last dimension, the tensor is padded with
  `shape[axis] - 1` extra entries via `pad2d`, windows of length `shape[axis]`
  are taken with `_pool` and summed, and the swap is undone.
  """
  moved = self.transpose(axis, -1)
  length = self.shape[axis]
  # Presumably the padding goes in front of the last dimension, so window j ends at element j - TODO confirm pad2d's order.
  windows = moved.pad2d((length - 1, 0))._pool((length,))
  return windows.sum(-1).transpose(axis, -1)
|
||||
|
||||
# ***** mlops (unary) *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user