mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
conv2d and conv_transpose2d to mixin (#15838)
* conv2d and conv_transpose2d to mixin * cleanup
This commit is contained in:
@@ -71,6 +71,35 @@ class TestTensorUOpStack(unittest.TestCase):
|
||||
def test_stack_3tensors(self):
  # stacking a tensor with two copies of itself along a new leading dim
  _check(self, _t(2, 3), lambda x: x.stack(x, x, dim=0))
|
||||
def test_stack_new_last(self):
  # dim=-1 stacks along a newly created trailing dimension
  _check(self, _t(2, 3), lambda x: x.stack(x, dim=-1))
|
||||
|
||||
class TestTensorUOpConv2d(unittest.TestCase):
  """conv2d/conv_transpose2d produce matching results whether called on a Tensor or on its uop."""

  @staticmethod
  def _like(x, t):
    # pick the Tensor itself or its underlying uop, matching whatever type x is
    return t if isinstance(x, Tensor) else t.uop

  def test_conv2d_basic(self):
    w = _t(1, 1, 2, 2).float()
    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv2d(self._like(x, w)))

  def test_conv2d_padded(self):
    w = _t(1, 1, 2, 2).float()
    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv2d(self._like(x, w), padding=1))

  def test_conv2d_negative_padding(self):
    w = _t(1, 1, 3, 3).float()
    _check(self, _t(1, 1, 5, 5).float(), lambda x: x.conv2d(self._like(x, w), padding=(-1, -1, -1, -1)))

  def test_conv2d_multichannel_bias(self):
    w, b = _t(4, 2, 3, 3).float(), _t(4).float()
    _check(self, _t(2, 2, 5, 5).float(), lambda x: x.conv2d(self._like(x, w), self._like(x, b)))

  def test_conv2d_stride_dilation(self):
    w = _t(2, 2, 2, 2).float()
    _check(self, _t(1, 2, 6, 6).float(), lambda x: x.conv2d(self._like(x, w), stride=2, dilation=2))

  def test_conv2d_groups(self):
    w = _t(4, 1, 2, 2).float()
    _check(self, _t(1, 4, 4, 4).float(), lambda x: x.conv2d(self._like(x, w), groups=4))

  def test_conv2d_3d(self):
    # conv2d is not limited to 2 spatial dims: 3d kernel over a 3d input
    w = _t(1, 1, 2, 2, 2).float()
    _check(self, _t(1, 1, 3, 3, 3).float(), lambda x: x.conv2d(self._like(x, w)))

  def test_conv_transpose2d_basic(self):
    w = _t(1, 1, 2, 2).float()
    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv_transpose2d(self._like(x, w)))

  def test_conv_transpose2d_stride(self):
    w = _t(1, 1, 2, 2).float()
    _check(self, _t(1, 1, 3, 3).float(), lambda x: x.conv_transpose2d(self._like(x, w), stride=2))
|
||||
|
||||
class TestTensorUOpEinsum(unittest.TestCase):
  """einsum agrees between the Tensor and uop paths (type(x) dispatches to the right class)."""

  def test_einsum_dot(self):
    _check(self, _t(2, 3), lambda x: type(x).einsum("ij,ij->", x, x))

  def test_einsum_transpose(self):
    _check(self, _t(2, 3), lambda x: type(x).einsum("ij->ji", x))
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.mixin.reduce import ReduceMixin
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
|
||||
from tinygrad.dtype import DTypeLike, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||
from tinygrad.helpers import argfix, flatten, prod, round_up
|
||||
from tinygrad.helpers import argfix, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up
|
||||
|
||||
ReductionStr = Literal["mean", "sum", "none"]
|
||||
|
||||
@@ -417,6 +417,61 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
|
||||
return x.add(bias) if bias is not None else x
|
||||
|
||||
def conv2d(self, weight:Self, bias:Self|None=None, groups=1, stride=1, dilation=1, padding:int|Sequence[int]=0,
           dtype:DTypeLike|None=None) -> Self:
  """
  N-dimensional convolution of `self` (bs, groups*cin, *spatial) with `weight` (cout, cin, *kernel).

  `padding` may be an int, one value per spatial dim, or an explicit per-side sequence
  (resolved by `resolve_pool_pads`); negative padding shrinks the input. If `bias` is given
  it is broadcast over the output channels and added. `dtype` sets the accumulation dtype.
  """
  (bs, cin_), (cout, cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
  pads = resolve_pool_pads(padding, len(HW))
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
    f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
  # conv2d is a pooling op (with padding, possibly negative — _pad_constant handles the shrink)
  x = self._pad_constant(((0,0),)*(self.ndim-len(HW)) + flat_to_grouped(pads), 0.0)._pool(HW, stride, dilation)
  rcout, oyx = cout//groups, x.shape[2:-len(HW)]
  # move rcout in front of the output-spatial dims and cin in front of the kernel dims
  perm = (0, 1, 3, *range(4, 4+len(oyx)), 2, *range(4+len(oyx), 4+len(oyx)+len(HW)))
  x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(*perm)
  # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW), then reduce over cin and the kernel dims
  ret = (x * weight.reshape(1, groups, rcout, *[1]*len(oyx), cin, *HW))\
    .sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx)
  return ret if bias is None else ret.add(bias.reshape(1, -1, *[1]*len(HW)))
|
||||
|
||||
def conv_transpose2d(self, weight:Self, bias:Self|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Self:
  """
  Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.

  This function supports three different types of `padding`:

  1. `int` (single value): the same padding is applied uniformly to all spatial dimensions.
  2. `tuple[int, ...]` (length = number of spatial dimensions): a distinct padding value per
     spatial dimension, in the form `(padding_height, padding_width, ...)`.
  3. `tuple[int, ...]` (length = 2 * number of spatial dimensions): explicit padding for each
     side of each spatial dimension, in the form
     `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

  NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.

  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  w = Tensor.ones(1, 1, 2, 2)
  print(t.conv_transpose2d(w).numpy())
  ```
  """
  HW = weight.shape[2:]
  # swap in/out channels within each group and flip the kernel along every spatial axis
  x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
  padding = flat_to_grouped(resolve_pool_pads(padding, len(HW)))
  stride, dilation, output_padding = [make_tuple(v, len(HW)) for v in (stride, dilation, output_padding)]
  if any(s > 1 for s in stride):
    # handle strides by inserting s-1 zeros between input elements:
    # (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
    x = x.reshape(None, None, *flatten((k, 1) for k in x.shape[2:]))
    x = x.pad((None, None, *flatten((None, (0, s-1)) for s in stride)))
    x = x.reshape(None, None, *[k*s for k, s in zip(x.shape[2::2], stride)])
    x = x.shrink_to(None, None, *[k-(s-1) for k, s in zip(x.shape[2:], stride)])
  # transposed conv == regular conv with flipped kernel and complementary padding per side
  padding = flatten(((k-1)*d - pB, (k-1)*d - pA + op)
                    for k, d, (pB, pA), op in reversed(list(zip(HW, dilation, padding, output_padding))))
  return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
def layernorm(self, axis:int|tuple[int,...]=-1, eps:float=1e-5) -> Self:
|
||||
"""
|
||||
Applies Layer Normalization over a mini-batch of inputs.
|
||||
|
||||
@@ -1652,51 +1652,6 @@ class Tensor(OpMixin):
|
||||
ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2).where(self.reshape(bs,c,1,-1), 0)).sum(3)
|
||||
return ret.reshape(bs,c,*output_size)
|
||||
|
||||
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
           dtype:DTypeLike|None=None) -> Tensor:
  """
  Applies a convolution over a tensor with a given `weight` and optional `bias`.

  This function supports three different types of `padding`:

  1. `int` (single value): the same padding is applied uniformly to all spatial dimensions.
  2. `tuple[int, ...]` (length = number of spatial dimensions): a distinct padding value per
     spatial dimension, in the form `(padding_height, padding_width, ...)`.
  3. `tuple[int, ...]` (length = 2 * number of spatial dimensions): explicit padding for each
     side of each spatial dimension, in the form
     `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

  NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  w = Tensor.ones(1, 1, 2, 2)
  print(t.conv2d(w).numpy())
  ```
  """
  if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
  HW = weight.shape[2:]
  # 3x3 stride-1 undilated kernels may take the winograd fast path when WINO is set
  if WINO and all(k == 3 for k in HW) and stride == 1 and dilation == 1:
    return self._conv2d_winograd(weight, bias, groups, padding, dtype)
  (bs, cin_), (cout, cin) = self.shape[:2], weight.shape[:2]
  pads = resolve_pool_pads(padding, len(HW))
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
    f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"

  # conv2d is a pooling op (with padding)
  x = self.pad(pads)._pool(HW, stride, dilation)  # (bs, groups*cin, oy, ox, H, W)
  rcout, oyx = cout//groups, x.shape[2:-len(HW)]
  # move rcout in front of the output-spatial dims and cin in front of the kernel dims
  perm = (0, 1, 3, *range(4, 4+len(oyx)), 2, *range(4+len(oyx), 4+len(oyx)+len(HW)))
  x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(*perm)
  # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW), then reduce over cin and the kernel dims
  ret = (x * weight.reshape(1, groups, rcout, *[1]*len(oyx), cin, *HW))\
    .sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx)
  return ret if bias is None else ret.add(bias.reshape(1, -1, *[1]*len(HW)))
|
||||
|
||||
# TODO: winograd can be a rewrite rule like split_reduceop
|
||||
def _conv2d_winograd(self, weight:Tensor, bias:Tensor|None, groups:int, padding:int|Sequence[int], dtype:DTypeLike|None) -> Tensor:
|
||||
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
||||
@@ -1736,9 +1691,10 @@ class Tensor(OpMixin):
|
||||
|
||||
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
|
||||
|
||||
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|Sequence[int]=0,
           dtype:DTypeLike|None=None) -> Tensor:
  """
  Applies a convolution over a tensor with a given `weight` and optional `bias`.

  This function supports three different types of `padding`:

  1. `int` (single value): the same padding is applied uniformly to all spatial dimensions.
  2. `tuple[int, ...]` (length = number of spatial dimensions): a distinct padding value per
     spatial dimension, in the form `(padding_height, padding_width, ...)`.
  3. `tuple[int, ...]` (length = 2 * number of spatial dimensions): explicit padding for each
     side of each spatial dimension, in the form
     `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

  NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  w = Tensor.ones(1, 1, 2, 2)
  print(t.conv2d(w).numpy())
  ```
  """
  # Tensor-level override: route to the IMAGE / winograd fast paths, otherwise
  # fall through to the generic mixin implementation.
  if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
  if WINO and all(x == 3 for x in weight.shape[2:]) and stride == dilation == 1: return self._conv2d_winograd(weight, bias, groups, padding, dtype)
  return super().conv2d(weight, bias, groups, stride, dilation, padding, dtype)
|
||||
|
||||
def dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
  # Tensor-level override: take the IMAGE fast path when enabled, else defer to the mixin.
  if IMAGE: return self.image_dot(w, dtype)
  return super().dot(w, dtype)
|
||||
|
||||
def cummax(self, axis:int=0) -> tuple[Tensor, Tensor]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user