diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 2f2242862c..bcd354aab5 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -129,22 +129,23 @@ class Kernel:
   @property
   def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
 
-  def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
-  def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
+  # TODO: these need more tests or it might silently be no-op
+  def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
+  def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
 
-  def upcasted_axis(self, i):
+  def upcasted_axis(self, i:int):
     return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
                     self.sts[i].real_strides()[self.shape_len-self.upcasted:],
                     [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
 
   # TODO: is there a better way to write this?
-  def acc_offsets(self, i) -> List[int]:
+  def acc_offsets(self, i:int) -> List[int]:
     if self.upcasted == 0: return [0]
     upcasted_i = self.upcasted_axis(i)
     acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
     return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
 
-  def get_upcast_dim(self, i) -> List[int]:
+  def get_upcast_dim(self, i:int) -> List[int]:
     should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
     return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 45bb7f9a73..be43462a00 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -287,7 +287,7 @@ class Tensor:
   # - There's a special case where a permute is needed at the end:
   #   - if first Tensor passed in (expand dims) is not at dim 0
   #   - and following Tensors does not follow consecutively to the end of fancy indexing's dims
-  def __getitem__(self, val): # val: Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
+  def __getitem__(self, val) -> Tensor: # val: Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
     def normalize_int(e, i, dim_sz):
       if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
       raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
@@ -362,7 +362,7 @@ class Tensor:
     padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
     return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
 
-  def gather(self: Tensor, idx: Tensor, dim: int):
+  def gather(self: Tensor, idx: Tensor, dim: int) -> Tensor:
     assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
     assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
     if dim < 0: dim += self.ndim
@@ -371,7 +371,7 @@ class Tensor:
     permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
     return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
 
-  def cat(self, *args, dim=0):
+  def cat(self, *args, dim=0) -> Tensor:
     dim = (dim + len(self.shape)) if dim < 0 else dim
     assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
     catargs = [self, *args]
@@ -383,13 +383,13 @@ class Tensor:
     return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
 
   @staticmethod
-  def stack(tensors, dim=0):
+  def stack(tensors, dim=0) -> Tensor:
     first = tensors[0].unsqueeze(dim)
     unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors[1:]]
     # checks for shapes and number of dimensions delegated to cat
     return first.cat(*unsqueezed_tensors, dim=dim)
 
-  def repeat(self, repeats):
+  def repeat(self, repeats) -> Tensor:
     base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
     new_shape = [x for b in base_shape for x in [1, b]]
     expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
@@ -402,19 +402,19 @@ class Tensor:
     slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
     return [self[tuple(sl)] for sl in slice_params]
 
-  def squeeze(self, dim=None):
+  def squeeze(self, dim=None) -> Tensor:
     if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
     if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior
     if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})")
     if dim < 0: dim += self.ndim
     return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim])
 
-  def unsqueeze(self, dim):
+  def unsqueeze(self, dim) -> Tensor:
     if dim < 0: dim = len(self.shape) + dim + 1
     return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
 
   # (padding_left, padding_right, padding_top, padding_bottom)
-  def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0):
+  def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0) -> Tensor:
     slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
     return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
 
@@ -473,9 +473,9 @@ class Tensor:
 
   # ***** processing ops *****
 
-  def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
+  def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
     assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
-    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
+    assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
     s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
     assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
     slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
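
A minimal usage sketch of the surface this diff annotates (assumes a tinygrad checkout with these changes applied; the example tensors and the expected outputs in comments are illustrative, not taken from the diff):

from tinygrad.tensor import Tensor

t = Tensor([[1, 2], [3, 4]])
idx = Tensor([[0, 0], [1, 0]])

# gather follows torch.gather semantics: for dim=0, out[i][j] = t[idx[i][j]][j]
print(t.gather(idx, dim=0).numpy())        # [[1. 2.] [3. 2.]]

# stack/cat/unsqueeze now declare Tensor returns, so chained calls type-check
print(Tensor.stack([t, t], dim=0).shape)   # (2, 2, 2)

# _pool (reached via avg_pool2d/max_pool2d) still requires concrete int shapes
# and kernel sizes: the widened Tuple[sint, ...] annotation admits symbolic
# ints at the type level, but the tightened assert rejects them at runtime
x = Tensor.ones(1, 1, 4, 4)
print(x.avg_pool2d(kernel_size=(2, 2)).shape)  # (1, 1, 2, 2)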