conv2d and conv_transpose2d to mixin (#15838)

* conv2d and conv_transpose2d to mixin

* cleanup
This commit is contained in:
chenyu
2026-04-20 18:10:06 -04:00
committed by GitHub
parent b8d3bf8970
commit cabc347066
3 changed files with 95 additions and 64 deletions

View File

@@ -71,6 +71,35 @@ class TestTensorUOpStack(unittest.TestCase):
def test_stack_3tensors(self):
  # stacking a tensor with two copies of itself along the leading axis
  _check(self, _t(2, 3), lambda x: x.stack(x, x, dim=0))
def test_stack_new_last(self):
  # dim=-1 inserts the new axis at the end instead of the front
  _check(self, _t(2, 3), lambda x: x.stack(x, dim=-1))
class TestTensorUOpConv2d(unittest.TestCase):
  """Check that conv2d/conv_transpose2d agree between the Tensor and raw-uop paths."""
  # _check calls the lambda with both a Tensor and its uop wrapper; this helper picks
  # the matching representation of an auxiliary tensor (weight or bias) for either case
  def _as(self, x, t): return t if isinstance(x, Tensor) else t.uop

  def test_conv2d_basic(self):
    weight = _t(1, 1, 2, 2).float()
    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv2d(self._as(x, weight)))
  def test_conv2d_padded(self):
    weight = _t(1, 1, 2, 2).float()
    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv2d(self._as(x, weight), padding=1))
  def test_conv2d_negative_padding(self):
    # negative padding crops the input instead of extending it
    weight = _t(1, 1, 3, 3).float()
    _check(self, _t(1, 1, 5, 5).float(), lambda x: x.conv2d(self._as(x, weight), padding=(-1,-1,-1,-1)))
  def test_conv2d_multichannel_bias(self):
    weight, bias = _t(4, 2, 3, 3).float(), _t(4).float()
    _check(self, _t(2, 2, 5, 5).float(), lambda x: x.conv2d(self._as(x, weight), self._as(x, bias)))
  def test_conv2d_stride_dilation(self):
    weight = _t(2, 2, 2, 2).float()
    _check(self, _t(1, 2, 6, 6).float(), lambda x: x.conv2d(self._as(x, weight), stride=2, dilation=2))
  def test_conv2d_groups(self):
    weight = _t(4, 1, 2, 2).float()
    _check(self, _t(1, 4, 4, 4).float(), lambda x: x.conv2d(self._as(x, weight), groups=4))
  def test_conv2d_3d(self):
    # conv2d is not limited to two spatial dims
    weight = _t(1, 1, 2, 2, 2).float()
    _check(self, _t(1, 1, 3, 3, 3).float(), lambda x: x.conv2d(self._as(x, weight)))
  def test_conv_transpose2d_basic(self):
    weight = _t(1, 1, 2, 2).float()
    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv_transpose2d(self._as(x, weight)))
  def test_conv_transpose2d_stride(self):
    weight = _t(1, 1, 2, 2).float()
    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv_transpose2d(self._as(x, weight), stride=2))
class TestTensorUOpEinsum(unittest.TestCase):
def test_einsum_dot(self):
  # full contraction of a tensor with itself ("ij,ij->" yields a scalar)
  _check(self, _t(2, 3), lambda x: type(x).einsum("ij,ij->", x, x))
def test_einsum_transpose(self):
  # "ij->ji" swaps the two axes
  _check(self, _t(2, 3), lambda x: type(x).einsum("ij->ji", x))

View File

@@ -6,7 +6,7 @@ from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
from tinygrad.dtype import DTypeLike, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import argfix, flatten, prod, round_up
from tinygrad.helpers import argfix, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up
ReductionStr = Literal["mean", "sum", "none"]
@@ -417,6 +417,61 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
return x.add(bias) if bias is not None else x
def conv2d(self, weight:Self, bias:Self|None=None, groups=1, stride=1, dilation=1, padding:int|Sequence[int]=0,
           dtype:DTypeLike|None=None) -> Self:
  """
  Applies a convolution over a tensor with a given `weight` and optional `bias`.

  `weight` has shape (out_channels, in_channels//groups, *kernel_dims); `bias`, if given, has shape (out_channels,).
  `padding` may be an int (same value on every side), a tuple with one value per spatial dimension,
  or a tuple with an explicit before/after value for each spatial dimension.
  Works for any number of spatial dimensions, not just 2. `dtype` optionally sets the accumulation dtype.
  """
  (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
  # normalize any accepted padding form into a flat (before, after) list per spatial dim
  padding_ = resolve_pool_pads(padding, len(HW))
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
    f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
  # conv2d is a pooling op (with padding, possibly negative — _pad_constant handles the shrink)
  x = self._pad_constant(((0,0),)*(self.ndim-len(HW)) + flat_to_grouped(padding_), 0.0)._pool(HW, stride, dilation)
  # rcout = output channels per group; oyx = output spatial dims produced by _pool
  rcout, oyx = cout//groups, x.shape[2:-len(HW)]
  # broadcast the pooled input against the output channels and move cin next to the kernel dims
  x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW)\
    .permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
  # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
  ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW))\
    .sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx)
  # bias is broadcast over batch and all spatial dims
  return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
def conv_transpose2d(self, weight:Self, bias:Self|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Self:
  """
  Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.

  This function supports three different types of `padding`

  1. `int` (single value):
    Applies the same padding value uniformly to all spatial dimensions.

  2. `tuple[int, ...]` (length = number of spatial dimensions):
    Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.

  3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
    Specifies explicit padding for each side of each spatial dimension in the form
    `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

  NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.

  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  w = Tensor.ones(1, 1, 2, 2)
  print(t.conv_transpose2d(w).numpy())
  ```
  """
  # transposed conv = forward conv with per-group in/out channels swapped and the kernel spatially flipped
  x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
  HW = weight.shape[2:]
  # normalize padding into (before, after) pairs per spatial dim
  padding = flat_to_grouped(resolve_pool_pads(padding, len(HW)))
  stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)]
  if any(s>1 for s in stride):
    # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
    x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
    x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
    x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
    x = x.shrink_to(None, None, *[k-(s-1) for k,s in zip(x.shape[2:], stride)])
  # convert transposed-conv padding into the equivalent forward-conv padding (output_padding only on the "after" side)
  padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
  # the actual compute is delegated to the regular conv2d
  return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
def layernorm(self, axis:int|tuple[int,...]=-1, eps:float=1e-5) -> Self:
"""
Applies Layer Normalization over a mini-batch of inputs.

View File

@@ -1652,51 +1652,6 @@ class Tensor(OpMixin):
ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2).where(self.reshape(bs,c,1,-1), 0)).sum(3)
return ret.reshape(bs,c,*output_size)
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
           dtype:DTypeLike|None=None) -> Tensor:
  """
  Applies a convolution over a tensor with a given `weight` and optional `bias`.

  This function supports three different types of `padding`

  1. `int` (single value):
    Applies the same padding value uniformly to all spatial dimensions.

  2. `tuple[int, ...]` (length = number of spatial dimensions):
    Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.

  3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
    Specifies explicit padding for each side of each spatial dimension in the form
    `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

  NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  w = Tensor.ones(1, 1, 2, 2)
  print(t.conv2d(w).numpy())
  ```
  """
  # env-flag fast path: image-typed conv implementation
  if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
  HW = weight.shape[2:]
  # env-flag fast path: winograd applies only to unstrided, undilated all-3 kernels
  if WINO and all(x == 3 for x in HW) and stride == 1 and dilation == 1: return self._conv2d_winograd(weight, bias, groups, padding, dtype)
  (bs,cin_), (cout,cin) = self.shape[:2], weight.shape[:2]
  # normalize any accepted padding form into a flat (before, after) list per spatial dim
  padding_ = resolve_pool_pads(padding, len(HW))
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
    f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
  # conv2d is a pooling op (with padding)
  x = self.pad(padding_)._pool(HW, stride, dilation)   # (bs, groups*cin, oy, ox, H, W)
  # rcout = output channels per group; oyx = output spatial dims produced by _pool
  rcout, oyx = cout//groups, x.shape[2:-len(HW)]
  # broadcast the pooled input against the output channels and move cin next to the kernel dims
  x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW)\
    .permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
  # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
  ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW))\
    .sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx)
  # bias is broadcast over batch and all spatial dims
  return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
# TODO: winograd can be a rewrite rule like split_reduceop
def _conv2d_winograd(self, weight:Tensor, bias:Tensor|None, groups:int, padding:int|Sequence[int], dtype:DTypeLike|None) -> Tensor:
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
@@ -1736,9 +1691,10 @@ class Tensor(OpMixin):
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
def conv_transpose2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|Sequence[int]=0,
dtype:DTypeLike|None=None) -> Tensor:
"""
Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
Applies a convolution over a tensor with a given `weight` and optional `bias`.
This function supports three different types of `padding`
@@ -1752,32 +1708,23 @@ class Tensor(OpMixin):
Specifies explicit padding for each side of each spatial dimension in the form
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(1, 1, 3, 3)
w = Tensor.ones(1, 1, 2, 2)
print(t.conv_transpose2d(w).numpy())
print(t.conv2d(w).numpy())
```
"""
x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
HW = weight.shape[2:]
padding = flat_to_grouped(resolve_pool_pads(padding, len(HW)))
stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)]
if any(s>1 for s in stride):
# handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
x = x.shrink_to(None, None, *[k-(s-1) for k,s in zip(x.shape[2:], stride)])
padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
if WINO and all(x == 3 for x in weight.shape[2:]) and stride == dilation == 1: return self._conv2d_winograd(weight, bias, groups, padding, dtype)
return super().conv2d(weight, bias, groups, stride, dilation, padding, dtype)
def dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
if IMAGE: return self.image_dot(w, dtype)
return super().dot(w, dtype=dtype)
return super().dot(w, dtype)
def cummax(self, axis:int=0) -> tuple[Tensor, Tensor]:
"""