mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
types (#2346)
This commit is contained in:
@@ -129,22 +129,23 @@ class Kernel:
|
||||
@property
|
||||
def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
|
||||
|
||||
def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
|
||||
def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
|
||||
# TODO: these need more tests or it might silently be no-op
|
||||
def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
|
||||
def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
|
||||
|
||||
def upcasted_axis(self, i):
|
||||
def upcasted_axis(self, i:int):
|
||||
return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
|
||||
self.sts[i].real_strides()[self.shape_len-self.upcasted:],
|
||||
[x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
|
||||
|
||||
# TODO: is there a better way to write this?
|
||||
def acc_offsets(self, i) -> List[int]:
|
||||
def acc_offsets(self, i:int) -> List[int]:
|
||||
if self.upcasted == 0: return [0]
|
||||
upcasted_i = self.upcasted_axis(i)
|
||||
acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
|
||||
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
|
||||
|
||||
def get_upcast_dim(self, i) -> List[int]:
|
||||
def get_upcast_dim(self, i:int) -> List[int]:
|
||||
should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
|
||||
return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]
|
||||
|
||||
|
||||
@@ -287,7 +287,7 @@ class Tensor:
|
||||
# - There's a special case where a permute is needed at the end:
|
||||
# - if first Tensor passed in (expand dims) is not at dim 0
|
||||
# - and following Tensors does not follow consecutively to the end of fancy indexing's dims
|
||||
def __getitem__(self, val): # val: Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
|
||||
def __getitem__(self, val) -> Tensor: # val: Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
|
||||
def normalize_int(e, i, dim_sz):
|
||||
if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
|
||||
raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
|
||||
@@ -362,7 +362,7 @@ class Tensor:
|
||||
padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
|
||||
return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
|
||||
|
||||
def gather(self: Tensor, idx: Tensor, dim: int):
|
||||
def gather(self: Tensor, idx: Tensor, dim: int) -> Tensor:
|
||||
assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
|
||||
assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
|
||||
if dim < 0: dim += self.ndim
|
||||
@@ -371,7 +371,7 @@ class Tensor:
|
||||
permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
|
||||
return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
|
||||
|
||||
def cat(self, *args, dim=0):
|
||||
def cat(self, *args, dim=0) -> Tensor:
|
||||
dim = (dim + len(self.shape)) if dim < 0 else dim
|
||||
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
|
||||
catargs = [self, *args]
|
||||
@@ -383,13 +383,13 @@ class Tensor:
|
||||
return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
|
||||
|
||||
@staticmethod
|
||||
def stack(tensors, dim=0):
|
||||
def stack(tensors, dim=0) -> Tensor:
|
||||
first = tensors[0].unsqueeze(dim)
|
||||
unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors[1:]]
|
||||
# checks for shapes and number of dimensions delegated to cat
|
||||
return first.cat(*unsqueezed_tensors, dim=dim)
|
||||
|
||||
def repeat(self, repeats):
|
||||
def repeat(self, repeats) -> Tensor:
|
||||
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
|
||||
new_shape = [x for b in base_shape for x in [1, b]]
|
||||
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
|
||||
@@ -402,19 +402,19 @@ class Tensor:
|
||||
slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
|
||||
return [self[tuple(sl)] for sl in slice_params]
|
||||
|
||||
def squeeze(self, dim=None):
|
||||
def squeeze(self, dim=None) -> Tensor:
|
||||
if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
|
||||
if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior
|
||||
if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})")
|
||||
if dim < 0: dim += self.ndim
|
||||
return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim])
|
||||
|
||||
def unsqueeze(self, dim):
|
||||
def unsqueeze(self, dim) -> Tensor:
|
||||
if dim < 0: dim = len(self.shape) + dim + 1
|
||||
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
||||
|
||||
# (padding_left, padding_right, padding_top, padding_bottom)
|
||||
def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0):
|
||||
def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0) -> Tensor:
|
||||
slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
|
||||
return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
|
||||
|
||||
@@ -473,9 +473,9 @@ class Tensor:
|
||||
|
||||
# ***** processing ops *****
|
||||
|
||||
def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
|
||||
def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
|
||||
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
|
||||
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
|
||||
assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
||||
slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
|
||||
|
||||
Reference in New Issue
Block a user