From cabc347066a0c2f99245bd513b851cb53dc0a0d2 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Mon, 20 Apr 2026 18:10:06 -0400
Subject: [PATCH] conv2d and conv_transpose2d to mixin (#15838)

* conv2d and conv_transpose2d to mixin

* cleanup
---
 test/null/test_tensor_uop_mixin.py | 29 ++++++++++++
 tinygrad/mixin/__init__.py         | 57 ++++++++++++++++++++++-
 tinygrad/tensor.py                 | 73 ++++--------------------
 3 files changed, 95 insertions(+), 64 deletions(-)

diff --git a/test/null/test_tensor_uop_mixin.py b/test/null/test_tensor_uop_mixin.py
index 8c98809bf3..fbfac34231 100644
--- a/test/null/test_tensor_uop_mixin.py
+++ b/test/null/test_tensor_uop_mixin.py
@@ -71,6 +71,35 @@ class TestTensorUOpStack(unittest.TestCase):
   def test_stack_3tensors(self): _check(self, _t(2, 3), lambda x: x.stack(x, x, dim=0))
   def test_stack_new_last(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=-1))
 
+class TestTensorUOpConv2d(unittest.TestCase):
+  def test_conv2d_basic(self):
+    w = _t(1, 1, 2, 2).float()
+    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv2d(w if isinstance(x, Tensor) else w.uop))
+  def test_conv2d_padded(self):
+    w = _t(1, 1, 2, 2).float()
+    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv2d(w if isinstance(x, Tensor) else w.uop, padding=1))
+  def test_conv2d_negative_padding(self):
+    w = _t(1, 1, 3, 3).float()
+    _check(self, _t(1, 1, 5, 5).float(), lambda x: x.conv2d(w if isinstance(x, Tensor) else w.uop, padding=(-1,-1,-1,-1)))
+  def test_conv2d_multichannel_bias(self):
+    w, b = _t(4, 2, 3, 3).float(), _t(4).float()
+    _check(self, _t(2, 2, 5, 5).float(), lambda x: x.conv2d(*(y if isinstance(x, Tensor) else y.uop for y in (w, b))))
+  def test_conv2d_stride_dilation(self):
+    w = _t(2, 2, 2, 2).float()
+    _check(self, _t(1, 2, 6, 6).float(), lambda x: x.conv2d(w if isinstance(x, Tensor) else w.uop, stride=2, dilation=2))
+  def test_conv2d_groups(self):
+    w = _t(4, 1, 2, 2).float()
+    _check(self, _t(1, 4, 4, 4).float(), lambda x: x.conv2d(w if isinstance(x, Tensor) else w.uop, groups=4))
+  def test_conv2d_3d(self):
+    w = _t(1, 1, 2, 2, 2).float()
+    _check(self, _t(1, 1, 3, 3, 3).float(), lambda x: x.conv2d(w if isinstance(x, Tensor) else w.uop))
+  def test_conv_transpose2d_basic(self):
+    w = _t(1, 1, 2, 2).float()
+    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv_transpose2d(w if isinstance(x, Tensor) else w.uop))
+  def test_conv_transpose2d_stride(self):
+    w = _t(1, 1, 2, 2).float()
+    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv_transpose2d(w if isinstance(x, Tensor) else w.uop, stride=2))
+
 class TestTensorUOpEinsum(unittest.TestCase):
   def test_einsum_dot(self): _check(self, _t(2, 3), lambda x: type(x).einsum("ij,ij->", x, x))
   def test_einsum_transpose(self): _check(self, _t(2, 3), lambda x: type(x).einsum("ij->ji", x))
diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py
index ab2efc9e78..c766e3c135 100644
--- a/tinygrad/mixin/__init__.py
+++ b/tinygrad/mixin/__init__.py
@@ -6,7 +6,7 @@ from tinygrad.mixin.reduce import ReduceMixin
 from tinygrad.uop import Ops
 from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
 from tinygrad.dtype import DTypeLike, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
-from tinygrad.helpers import argfix, flatten, prod, round_up
+from tinygrad.helpers import argfix, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up
 
 ReductionStr = Literal["mean", "sum", "none"]
 
@@ -417,6 +417,61 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
     x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
     return x.add(bias) if bias is not None else x
 
+  def conv2d(self, weight:Self, bias:Self|None=None, groups=1, stride=1, dilation=1, padding:int|Sequence[int]=0,
+             dtype:DTypeLike|None=None) -> Self:
+    (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
+    padding_ = resolve_pool_pads(padding, len(HW))
+    assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
+      f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
+    # conv2d is a pooling op (with padding, possibly negative — _pad_constant handles the shrink)
+    x = self._pad_constant(((0,0),)*(self.ndim-len(HW)) + flat_to_grouped(padding_), 0.0)._pool(HW, stride, dilation)
+    rcout, oyx = cout//groups, x.shape[2:-len(HW)]
+    x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW)\
+      .permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
+    # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
+    ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW))\
+      .sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx)
+    return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
+
+  def conv_transpose2d(self, weight:Self, bias:Self|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Self:
+    """
+    Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
+
+    This function supports three different types of `padding`
+
+    1. `int` (single value):
+      Applies the same padding value uniformly to all spatial dimensions.
+
+    2. `tuple[int, ...]` (length = number of spatial dimensions):
+      Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
+
+    3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
+      Specifies explicit padding for each side of each spatial dimension in the form
+      `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
+
+    NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
+
+    See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(9).reshape(1, 1, 3, 3)
+    w = Tensor.ones(1, 1, 2, 2)
+    print(t.conv_transpose2d(w).numpy())
+    ```
+    """
+    x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
+    HW = weight.shape[2:]
+    padding = flat_to_grouped(resolve_pool_pads(padding, len(HW)))
+    stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)]
+    if any(s>1 for s in stride):
+      # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
+      x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
+      x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
+      x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
+      x = x.shrink_to(None, None, *[k-(s-1) for k,s in zip(x.shape[2:], stride)])
+    padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
+    return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
+
   def layernorm(self, axis:int|tuple[int,...]=-1, eps:float=1e-5) -> Self:
     """
     Applies Layer Normalization over a mini-batch of inputs.
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 0078756c6c..5b98a68572 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1652,51 +1652,6 @@ class Tensor(OpMixin):
     ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2).where(self.reshape(bs,c,1,-1), 0)).sum(3)
     return ret.reshape(bs,c,*output_size)
 
-  def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
-             dtype:DTypeLike|None=None) -> Tensor:
-    """
-    Applies a convolution over a tensor with a given `weight` and optional `bias`.
-
-    This function supports three different types of `padding`
-
-    1. `int` (single value):
-      Applies the same padding value uniformly to all spatial dimensions.
-
-    2. `tuple[int, ...]` (length = number of spatial dimensions):
-      Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
-
-    3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
-      Specifies explicit padding for each side of each spatial dimension in the form
-      `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
-
-    NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
-
-    See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(9).reshape(1, 1, 3, 3)
-    w = Tensor.ones(1, 1, 2, 2)
-    print(t.conv2d(w).numpy())
-    ```
-    """
-    if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
-    HW = weight.shape[2:]
-    if WINO and all(x == 3 for x in HW) and stride == 1 and dilation == 1: return self._conv2d_winograd(weight, bias, groups, padding, dtype)
-    (bs,cin_), (cout,cin) = self.shape[:2], weight.shape[:2]
-    padding_ = resolve_pool_pads(padding, len(HW))
-    assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
-      f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
-
-    # conv2d is a pooling op (with padding)
-    x = self.pad(padding_)._pool(HW, stride, dilation)  # (bs, groups*cin, oy, ox, H, W)
-    rcout, oyx = cout//groups, x.shape[2:-len(HW)]
-    x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW)\
-      .permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
-    # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
-    ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW))\
-      .sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx)
-    return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
-
   # TODO: winograd can be a rewrite rule like split_reduceop
   def _conv2d_winograd(self, weight:Tensor, bias:Tensor|None, groups:int, padding:int|Sequence[int], dtype:DTypeLike|None) -> Tensor:
     (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
@@ -1736,9 +1691,10 @@ class Tensor(OpMixin):
 
     return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
 
-  def conv_transpose2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
+  def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|Sequence[int]=0,
+             dtype:DTypeLike|None=None) -> Tensor:
     """
-    Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
+    Applies a convolution over a tensor with a given `weight` and optional `bias`.
 
     This function supports three different types of `padding`
 
@@ -1752,32 +1708,23 @@
      Specifies explicit padding for each side of each spatial dimension in the form
      `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
 
-    NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
+    NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
 
-    See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
+    See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
 
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor.arange(9).reshape(1, 1, 3, 3)
     w = Tensor.ones(1, 1, 2, 2)
-    print(t.conv_transpose2d(w).numpy())
+    print(t.conv2d(w).numpy())
     ```
     """
-    x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
-    HW = weight.shape[2:]
-    padding = flat_to_grouped(resolve_pool_pads(padding, len(HW)))
-    stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)]
-    if any(s>1 for s in stride):
-      # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
-      x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
-      x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
-      x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
-      x = x.shrink_to(None, None, *[k-(s-1) for k,s in zip(x.shape[2:], stride)])
-    padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
-    return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
+    if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
+    if WINO and all(x == 3 for x in weight.shape[2:]) and stride == dilation == 1: return self._conv2d_winograd(weight, bias, groups, padding, dtype)
+    return super().conv2d(weight, bias, groups, stride, dilation, padding, dtype)
 
   def dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
     if IMAGE: return self.image_dot(w, dtype)
-    return super().dot(w, dtype=dtype)
+    return super().dot(w, dtype)
 
   def cummax(self, axis:int=0) -> tuple[Tensor, Tensor]:
     """
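
For reviewers who want to poke at the moved entry points, here is a minimal sketch (illustrative only, assuming a checkout with this patch applied; shapes follow the docstring example and the tests above):

```python
from tinygrad import Tensor

t = Tensor.arange(9).float().reshape(1, 1, 3, 3)  # (bs, cin, H, W)
w = Tensor.ones(1, 1, 2, 2)                       # (cout, cin//groups, kH, kW)

print(t.conv2d(w).shape)                      # (1, 1, 2, 2): no padding
print(t.conv2d(w, padding=1).shape)           # (1, 1, 4, 4): same pad value on every side
print(t.conv_transpose2d(w).shape)            # (1, 1, 4, 4): transposed conv grows the spatial dims
print(t.conv_transpose2d(w, stride=2).shape)  # (1, 1, 6, 6)
```

After this change, `Tensor.conv2d` keeps only the IMAGE and WINO fast paths in front of `super().conv2d`, while `conv_transpose2d` lives entirely in the mixin with no Tensor-level override.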
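The stride>1 path in the mixin's `conv_transpose2d` is the usual zero-interleaving trick, compressed into the inline comment `(k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))`. A plain-Python sketch of what that pipeline does to one spatial axis (hypothetical helper for illustration, not tinygrad API):

```python
def interleave_zeros(row: list[int], s: int) -> list[int]:
  # (k) -> (k,1) -> pad each element's group to width s: append s-1 zeros per element
  groups = [[v] + [0] * (s - 1) for v in row]
  # (k,s) -> (k*s): flatten back to one axis
  flat = [v for g in groups for v in g]
  # shrink to k*s-(s-1): drop the trailing zeros past the last real element
  return flat[:len(row) * s - (s - 1)]

print(interleave_zeros([1, 2, 3], 2))  # [1, 0, 2, 0, 3]
```

With the input upsampled this way and the weights flipped, a regular `conv2d` with padding adjusted to `(k-1)*d - p` (plus `output_padding` on one side) yields the transposed convolution, which is what the final two lines of the method compute.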
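The docstring NOTE about supporting any number of spatial dimensions is exercised by `test_conv2d_3d`; as a quick shape check (same assumptions as the sketch above), a 3d convolution goes through the same `conv2d` entry point:

```python
t3 = Tensor.ones(1, 1, 3, 3, 3)  # one batch, one channel, three spatial dims
w3 = Tensor.ones(1, 1, 2, 2, 2)
print(t3.conv2d(w3).shape)       # (1, 1, 2, 2, 2)
```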