From 8eeb77a905aff0108e365dfdc41d36c75b5790a2 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Mon, 20 Apr 2026 14:03:35 -0400
Subject: [PATCH] flat_to_grouped and resolve_pool_pads to helpers (#15834)

---
 tinygrad/helpers.py |  6 ++++++
 tinygrad/tensor.py  | 29 +++++++++++------------------
 2 files changed, 17 insertions(+), 18 deletions(-)

diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 157066359f..2f4d4cb978 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -45,6 +45,12 @@ def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in
 def fully_flatten(l):
   if not (hasattr(l, "__len__") and hasattr(l, "__getitem__")) or isinstance(l, str): return [l]
   return [l[()]] if hasattr(l, "shape") and l.shape == () else [x for li in l for x in fully_flatten(li)]
+# `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
+def flat_to_grouped(padding:Sequence[T]) -> tuple[tuple[T, T], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
+def resolve_pool_pads(padding:int|Sequence[int], dims:int) -> Sequence[int]:
+  if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
+    raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} with {dims=}.")
+  return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def _is_balanced(s:str) -> bool: return (d := 0, all((d := d + (c == '(') - (c == ')')) >= 0 for c in s))[1] and d == 0
 def strip_parens(fst:str) -> str: return fst[1:-1] if fst[:1]=='(' and fst[-1:]==')' and _is_balanced(fst[1:-1]) else fst
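For review convenience, a standalone sketch of what the two moved helpers compute. The defs are restated from the hunk above (with the T typevar spelled out) so the snippet runs on its own; the sample values are illustrative, not from the patch:

from typing import Sequence, TypeVar
T = TypeVar("T")

# restated from the helpers.py hunk above, made self-contained
def flat_to_grouped(padding:Sequence[T]) -> tuple[tuple[T, T], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
def resolve_pool_pads(padding:int|Sequence[int], dims:int) -> Sequence[int]:
  if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
    raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} with {dims=}.")
  return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])

# flat (left, right, top, bottom) groups innermost-dim-last into (..., (top, bottom), (left, right))
assert flat_to_grouped((1, 2, 3, 4)) == ((3, 4), (1, 2))
# an int broadcasts to 2*dims sides; a per-dim sequence is doubled per side and reversed to flat order
assert resolve_pool_pads(1, 2) == [1, 1, 1, 1]
assert resolve_pool_pads([1, 2], 2) == [2, 2, 1, 1]
assert resolve_pool_pads([1, 2, 3, 4], 2) == [1, 2, 3, 4]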
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7e27048e85..5c09e56fe6 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -6,8 +6,9 @@ from typing import Any, Callable, ClassVar, Sequence, cast, get_args, Literal, P
 if TYPE_CHECKING: import numpy
 from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_float, least_upper_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid, InvalidType
-from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch
-from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile, suppress_finalizing, disable_gc
+from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch, flat_to_grouped
+from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
+from tinygrad.helpers import suppress_finalizing, disable_gc
 from tinygrad.gradient import compute_gradient
 from tinygrad.mixin import OpMixin, ReductionStr
 from tinygrad.uop.ops import smax, UOp, Ops, sint, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
@@ -86,9 +87,6 @@ def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, .
   # select from values for each True element in mask else select from target
   return mask.where(values, target)
 
-# `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
-def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
-
 class Tensor(OpMixin):
   """
   A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
@@ -1104,7 +1102,7 @@ class Tensor(OpMixin):
     # normalize to grouped format
     if all(isinstance(p, (int,UOp)) for p in padding):
       if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
-      pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
+      pX = ((0,0),)*(self.ndim - len(padding)//2) + flat_to_grouped(cast(Sequence[sint], padding))
     else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding))
     if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
     # dispatch
@@ -1519,14 +1517,9 @@ class Tensor(OpMixin):
 
   # ***** processing ops *****
 
-  def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
-    if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
-      raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
-    return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
-
   def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:int|tuple[int, ...], d_:int|tuple[int, ...]) -> list[int]:
     (d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
-    pads, grouped_pads = list(pads), _flat_to_grouped(pads)
+    pads, grouped_pads = list(pads), flat_to_grouped(pads)
     # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
     o_ = [ceildiv(i+pB+pA - (d*(k-1)+1), s) + 1 for i,d,k,s,(pB,pA) in zip(i_,d_,k_,s_,grouped_pads)]
     for dim,(o,i,s,k,d,(pB,pA)) in enumerate(zip(o_,i_,s_,k_,d_,grouped_pads)):
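As context for the relationship-15 comment in the hunk above, a minimal sketch (not part of the patch; pool_out is a hypothetical name) of the per-dim output length that _apply_ceil_mode works back from:

import math

# output length of one pooled dim: i input size, k kernel, s stride, d dilation,
# pB/pA pads before/after (https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15)
def pool_out(i:int, k:int, s:int, d:int, pB:int, pA:int, ceil_mode:bool) -> int:
  span = i + pB + pA - (d*(k-1)+1)  # room left after placing the first (dilated) window
  return (math.ceil(span / s) if ceil_mode else span // s) + 1

assert pool_out(10, 3, 2, 1, 0, 0, ceil_mode=False) == 4  # floor mode drops the partial window
assert pool_out(10, 3, 2, 1, 0, 0, ceil_mode=True) == 5   # ceil mode keeps it, hence the extra pads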
@@ -1576,7 +1569,7 @@ class Tensor(OpMixin):
     """
     axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
     def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
-    reg_pads = self._resolve_pool_pads(padding, len(k_))
+    reg_pads = resolve_pool_pads(padding, len(k_))
     ceil_pads = self._apply_ceil_mode(reg_pads, k_, stride if stride is not None else k_, dilation)
     if not count_include_pad:
       pads = ceil_pads if ceil_mode else reg_pads
@@ -1618,7 +1611,7 @@ class Tensor(OpMixin):
     ```
     """
     axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
-    pads = self._resolve_pool_pads(padding, len(k_))
+    pads = resolve_pool_pads(padding, len(k_))
     if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
     pooled = self.pad(pads, value=self.dtype.min)._pool(k_, stride if stride is not None else k_, dilation)
     if not return_indices: return pooled.max(axis)
@@ -1652,7 +1645,7 @@ class Tensor(OpMixin):
     bs,c,*spatial_shape = self.shape
     if output_size is None:
       k_,d_,s_ = (make_tuple(x, len(spatial_shape)) for x in (kernel_size, dilation, stride if stride is not None else kernel_size))
-      p_ = _flat_to_grouped(self._resolve_pool_pads(padding, len(spatial_shape)))
+      p_ = flat_to_grouped(resolve_pool_pads(padding, len(spatial_shape)))
       # https://arxiv.org/pdf/1603.07285 inverse of relationship 15 in section 5.1.
       output_size = tuple((i-1)*s - (pB+pA) + (d*(k-1)+1) for i,k,d,s,(pA,pB) in zip(spatial_shape,k_,d_,s_,p_))
     else: output_size = output_size[-len(spatial_shape):]
@@ -1688,7 +1681,7 @@ class Tensor(OpMixin):
     """
     if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
     (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
-    padding_ = self._resolve_pool_pads(padding, len(HW))
+    padding_ = resolve_pool_pads(padding, len(HW))
     assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
       f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
 
@@ -1765,7 +1758,7 @@ class Tensor(OpMixin):
     """
     x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
     HW = weight.shape[2:]
-    padding = _flat_to_grouped(self._resolve_pool_pads(padding, len(HW)))
+    padding = flat_to_grouped(resolve_pool_pads(padding, len(HW)))
     stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)]
     if any(s>1 for s in stride):
       # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
@@ -2495,7 +2488,7 @@ class Tensor(OpMixin):
     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
     x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
 
-    padding_neg, padding_pos = [min(0, p) for p in self._resolve_pool_pads(padding, 2)], [max(0, p) for p in self._resolve_pool_pads(padding, 2)]
+    padding_neg, padding_pos = [min(0, p) for p in resolve_pool_pads(padding, 2)], [max(0, p) for p in resolve_pool_pads(padding, 2)]
 
     x = x.pad(padding_neg)
     iy, ix = x.shape[2:]
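One behavior note on the pad() rewrite: the old code appended (0,0) pairs to the flat pads before grouping, while the new code prepends grouped (0,0) pairs afterwards. A small check (not part of the patch) that the two agree, since grouping reverses dimension order:

def flat_to_grouped(padding): return tuple(zip(padding[-2::-2], padding[::-2]))

ndim, flat = 4, (1, 2, 3, 4)  # flat pads cover only the last two dims of a 4-d tensor
old = flat_to_grouped(flat + (0, 0)*(ndim - len(flat)//2))     # pad the flat form, then group
new = ((0, 0),)*(ndim - len(flat)//2) + flat_to_grouped(flat)  # group, then prepend zero groups
assert old == new == ((0, 0), (0, 0), (3, 4), (1, 2))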