support both asymmetric and negative padding

George Hotz
2022-06-26 17:59:25 -07:00
parent 49c954b389
commit dffde3de5a
6 changed files with 37 additions and 12 deletions
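
The padding argument of Tensor.conv2d / get_conv_args now accepts an int, a (py, px) pair, or a four-tuple (px, px_, py, py_) of left/right/top/bottom pads, any of which may be negative (a crop rather than a pad). A minimal sketch of the parsing rule added to get_conv_args below; parse_padding is an illustrative name, not library code:

# Padding forms accepted after this change:
#   int p              -> p on all four sides (may be negative = crop)
#   (py, px)           -> py on top and bottom, px on left and right
#   (px, px_, py, py_) -> left, right, top, bottom, like torch.nn.functional.pad
def parse_padding(padding):
  if not isinstance(padding, int) and len(padding) == 4: px,px_,py,py_ = padding
  else:
    py,px = (padding, padding) if isinstance(padding, int) else padding
    py_,px_ = py, px
  return px, px_, py, py_

assert parse_padding(1) == (1, 1, 1, 1)
assert parse_padding((2, 1)) == (1, 1, 2, 2)        # (py, px) expands symmetrically
assert parse_padding((2, 0, 2, 1)) == (2, 0, 2, 1)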

test/test_ops.py

@@ -266,6 +266,27 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
       lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=1e-4)
+  def test_negative_padding_conv2d(self):
+    n,k = 10, 3
+    helper_test_op([(1,1,n,n), (1,1,k,k)],
+      lambda x,w: torch.nn.functional.conv2d(x[:, :, 1:-1, 1:-1],w).relu(),
+      lambda x,w: Tensor.conv2d(x,w,padding=-1).relu(), atol=1e-4)
+    helper_test_op([(1,1,n,n), (1,1,k,k)],
+      lambda x,w: torch.nn.functional.conv2d(x[:, :, 1:, 1:],w).relu(),
+      lambda x,w: Tensor.conv2d(x,w,padding=(-1,0,-1,0)).relu(), atol=1e-4)
+  def test_asymmetric_padding_conv2d(self):
+    for p in [(0,1,0,1), (2,1,2,1), (2,0,2,1)]:
+      with self.subTest(padding := p):
+        for n in [3,4]:
+          for k in [2]:
+            helper_test_op([(1,1,n,n), (1,1,k,k)],
+              lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(),
+              lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4)
+            helper_test_op([(1,1,n,n), (1,1,k,k)],
+              lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(),
+              lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4)
   def test_padded_conv2d(self):
     bs = 4
     cin = 3
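
The negative-padding tests pin the semantics down by comparison: padding=-1 must match cropping one pixel off every border before the convolution, since torch.nn.functional.conv2d itself rejects negative padding. A standalone sketch of that reference behavior:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 10, 10)
w = torch.randn(1, 1, 3, 3)

# torch has no negative padding, so cropping serves as the reference:
ref = F.conv2d(x[:, :, 1:-1, 1:-1], w)   # 10x10 -> crop to 8x8 -> conv to 6x6
assert ref.shape == (1, 1, 6, 6)
# Tensor.conv2d(x, w, padding=-1) is expected to produce the same values.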

tinygrad/helpers.py

@@ -5,18 +5,19 @@ def prod(x): return math.prod(x)
 def reduce_shape(shape, axis):
   return [1 if i in axis else shape[i] for i in range(len(shape))]
-ConvArgs = namedtuple('ConvArgs', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs', 'cout', 'py', 'px', 'dy', 'dx', 'out_shape'])
+ConvArgs = namedtuple('ConvArgs', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs', 'cout', 'py', 'py_', 'px', 'px_', 'dy', 'dx', 'out_shape'])
 def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1):
   # TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout
   cout,cin,H,W = w_shape
   ys,xs = (stride, stride) if isinstance(stride, int) else stride
-  py,px = (padding, padding) if isinstance(padding, int) else padding
+  if not isinstance(padding, int) and len(padding) == 4: px,px_,py,py_ = padding
+  else: py,px = (padding, padding) if isinstance(padding, int) else padding; py_, px_ = py, px
   dy,dx = (dilation, dilation) if isinstance(dilation, int) else dilation
   bs,cin_,iy,ix = x_shape
-  # TODO: should be easy to support asymmetric padding by changing output size
   # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html describes these sizes well
-  oy = (iy + 2*py - dy * (H-1) - 1)//ys + 1
-  ox = (ix + 2*px - dx * (W-1) - 1)//xs + 1
+  oy = (iy + py + py_ - dy * (H-1) - 1)//ys + 1
+  ox = (ix + px + px_ - dx * (W-1) - 1)//xs + 1
   if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
   assert cout % groups == 0
-  return ConvArgs(H, W, groups, cout//groups, cin, oy, ox, iy, ix, ys, xs, bs, cout, py, px, dy, dx, (bs, cout, oy, ox))
+  return ConvArgs(H, W, groups, cout//groups, cin, oy, ox, iy, ix, ys, xs, bs, cout, py, py_, px, px_, dy, dx, (bs, cout, oy, ox))
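
The output-size formula replaces 2*py with py + py_ so the two edges of each axis contribute independently. A worked check of the formula; out_size is an illustrative helper mirroring the oy/ox lines above:

def out_size(i, p0, p1, k, stride=1, dilation=1):
  return (i + p0 + p1 - dilation * (k-1) - 1) // stride + 1

assert out_size(10, 0, 0, 3) == 8             # plain 3x3 conv
assert out_size(10, 2, 1, 3) == 11            # asymmetric pads add independently
assert out_size(10, -1, -1, 3) == 6           # negative padding crops the input
assert out_size(10, 1, 1, 3, stride=2) == 5   # (10+2-2-1)//2 + 1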

tinygrad/llops/ops_cpu.py

@@ -61,7 +61,8 @@ class CPUBuffer(np.ndarray):
   def processing_op(x,op,w,C):
     assert op == ProcessingOps.CONV, f"{op} isn't supported"
-    if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
+    if C.px != 0 or C.py != 0 or C.px_ != 0 or C.py_ != 0:
+      x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
     gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
     tx = np.lib.stride_tricks.as_strided(gx,
       shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
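
The CPU path drops np.pad in favor of a single SLICE movement op whose per-axis (start, end) may extend past the array. A NumPy sketch of the semantics assumed here (out-of-range positions zero-fill, so a negative start pads while an in-range start crops); slice_op is illustrative, not the library implementation:

import numpy as np

def slice_op(x, arg):
  # arg gives (start, end) per axis, as in the movement_op call above;
  # positions outside the array are zero-filled (the assumed SLICE semantics)
  pad = [(max(0, -s), max(0, e - x.shape[i])) for i, (s, e) in enumerate(arg)]
  xp = np.pad(x, pad)
  return xp[tuple(slice(s + p[0], e + p[0]) for (s, e), p in zip(arg, pad))]

a = np.arange(16).reshape(1, 1, 4, 4)
b = slice_op(a, ((0, 1), (0, 1), (-1, 5), (-1, 5)))  # pad one pixel on each side
assert b.shape == (1, 1, 6, 6) and b[0, 0, 0, 0] == 0
c = slice_op(a, ((0, 1), (0, 1), (1, 3), (1, 3)))    # crop to the 2x2 center
assert c.shape == (1, 1, 2, 2) and c[0, 0, 0, 0] == 5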

tinygrad/llops/ops_gpu.py

@@ -126,7 +126,7 @@ class GPUBuffer:
     if C is not None:
       ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "ys", "xs", "dx", "dy", "px", "py", "groups", "rcout", "cin"])
       params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]]
-      if C.px == 0 and C.py == 0: options.append("-DALLVALID")
+      if C.px == 0 and C.py == 0 and C.px_ == 0 and C.py_ == 0: options.append("-DALLVALID")
       if C.oy == 1 and C.ox == 1: options.append("-DONEBYONE")
       global_size = [C.bs*C.cout, C.oy, C.ox]
       assert bufs[0][0] == "input" and bufs[1][0] == "weight"
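
-DALLVALID lets the GPU kernel skip its bounds check, which is only safe when all four pads are zero; with any nonzero pad, some reads fall outside the input and must yield zero. The kernel source is not part of this diff; the sketch below only illustrates the standard convolution index math behind the check:

# y_in is the input row that output row o, kernel tap h reads
def in_bounds(o, h, stride, dilation, pad, size):
  y_in = o * stride + h * dilation - pad
  return 0 <= y_in < size

# 3x3 conv over 4 input rows: with pad=1 the first output row reads row -1 ...
assert not in_bounds(0, 0, 1, 1, 1, 4)
# ... while with zero padding every tap of every output row stays in range.
assert all(in_bounds(o, h, 1, 1, 0, 4) for o in range(2) for h in range(3))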

tinygrad/llops/ops_torch.py

@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from tinygrad.llops.ops_cpu import CPUBuffer
-from tinygrad.ops import ProcessingOps
+from tinygrad.ops import MovementOps, ProcessingOps
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 class TorchBuffer(torch.Tensor):
@@ -23,4 +23,6 @@ class TorchBuffer(torch.Tensor):
   def processing_op(x,op,w,C):
     assert op == ProcessingOps.CONV, f"{op} isn't supported"
-    return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
+    if C.px != C.px_ or C.py != C.py_: padding, x = 0, x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
+    else: padding = (C.py, C.px)
+    return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=padding)
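
torch.conv2d only takes a symmetric (py, px) padding, so the torch backend now pre-pads/crops with a SLICE whenever the two edges of an axis differ, then calls torch with padding=0; symmetric pads still go through torch's own argument. A standalone sketch of the explicit-pad equivalent (shapes only):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, 8)
w = torch.randn(1, 1, 3, 3)
px, px_, py, py_ = 2, 0, 2, 1                    # left, right, top, bottom

out = F.conv2d(F.pad(x, (px, px_, py, py_)), w)  # explicit pad, then padding=0
assert out.shape == (1, 1, 8 + py + py_ - 2, 8 + px + px_ - 2)   # (1, 1, 9, 8)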

tinygrad/mlops.py

@@ -203,8 +203,8 @@ class Conv2D(Function):
     wt = ctx.movement_op(MovementOps.FLIP, wt, (3, 4))
     wt = ctx.movement_op(MovementOps.PERMUTE, wt, (0, 2, 1, 3, 4))
     wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W))
-    Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=C.groups)
-    # TODO: this shape can be wrong. support asymmetric padding to remove the slice
+    Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=((C.W-1)*C.dx-C.px, (C.W-1)*C.dx-C.px_, (C.H-1)*C.dy-C.py, (C.H-1)*C.dy-C.py_), groups=C.groups)
+    # TODO: this shape can be wrong when strided. support asymmetric padding to remove the slice
     dx = ctx._conv(xt, wt, Cdx)
     dx = ctx.movement_op(MovementOps.SLICE, dx, [(0,s) for s in x.shape])
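
The input gradient is computed as a convolution with the flipped weights and "full" padding, (k-1)*dilation - p, now evaluated per edge. A stride-1 worked example (full_pad is an illustrative helper) showing the per-edge pads recover the original input height:

def full_pad(k, dilation, p):   # pad needed on one edge for the dx conv
  return (k - 1) * dilation - p

# forward: 3x3 conv (stride 1, dilation 1) on 10 rows, pads (top, bottom) = (2, 1)
oy = (10 + 2 + 1 - 1*(3-1) - 1) // 1 + 1
assert oy == 11
# backward: convolve the 11 grad rows with the flipped kernel using per-edge
# pads (3-1)*1-2 = 0 and (3-1)*1-1 = 1, recovering the 10 input rows:
assert (full_pad(3, 1, 2), full_pad(3, 1, 1)) == (0, 1)
assert (11 + 0 + 1 - 1*(3-1) - 1) // 1 + 1 == 10
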
@@ -215,7 +215,7 @@ class Conv2D(Function):
     xdw = ctx.movement_op(MovementOps.RESHAPE, xdw, (C.cin, C.groups*C.bs, C.iy, C.ix))
     grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3))
     grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.cout, C.bs, C.oy, C.ox))
-    Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups)
+    Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.px, C.px_, C.py, C.py_), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups)
     grad_weight = ctx._conv(xdw, grad_output_dw, Cdw)
     grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3))
     # TODO: remove this slice using asymmetric padding
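
For the weight gradient, the forward stride and dilation swap roles and the forward pads (px, px_, py, py_) are reused, so sliding grad_output over the padded input yields exactly a kernel-sized result. A stride-1 arithmetic check, reusing the illustrative out_size helper from the earlier sketch:

def out_size(i, p0, p1, k, stride=1, dilation=1):
  return (i + p0 + p1 - dilation * (k-1) - 1) // stride + 1

iy, k, py, py_ = 10, 3, 2, 1
oy = out_size(iy, py, py_, k)      # forward: 11 output rows
assert oy == 11
# weight grad: treat the 11 grad_output rows as the kernel over the same
# padded 10-row input; the result has exactly k = 3 rows, the kernel height.
assert out_size(iy, py, py_, oy) == k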