add convtranspose (#809)

* add convtranspose

* onnx convtranspose
George Hotz, 2023-05-26 12:35:03 -07:00 (committed by GitHub)
parent 04284414db
commit 26014a0fa1
4 changed files with 57 additions and 7 deletions


@@ -37,7 +37,7 @@ def BatchNormalization(X, scale, B, input_mean, input_var, epsilon=1e-05, moment
  else:
    invstd = (input_var + epsilon)**-0.5
  return X.batchnorm(scale, B, input_mean, invstd)
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
  assert stash_type == 1, "only float32 is supported"
  axis = tuple(i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape)))
@@ -46,7 +46,7 @@ def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_typ
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
  return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
def _format_padding(onnx_pads, ndims=None, axes=None):
@@ -90,6 +90,9 @@ def MaxPool(X, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=N
def Conv(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=(pads[1], pads[3], pads[0], pads[2]) if pads is not None else 0)
def ConvTranspose(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=(pads[1], pads[3], pads[0], pads[2]) if pads is not None else 0)
# TODO: copied from tensor.py
def Dropout(data, ratio=0.5, training_mode=False, seed=None):
  # TODO: mask should be a boolean tensor
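
Note: the pads reordering in Conv and ConvTranspose follows the layouts noted earlier in this file. ONNX stores spatial pads as [x1_begin, x2_begin, ..., x1_end, x2_end, ...] (i.e. [top, left, bottom, right] for the 2D case), while tinygrad's conv2d takes a torch-style (left, right, top, bottom) padding tuple. A minimal sketch of the 2D mapping (illustrative only, not part of the diff):

onnx_pads = [1, 2, 3, 4]  # [H_begin, W_begin, H_end, W_end] = [top, left, bottom, right]
tg_padding = (onnx_pads[1], onnx_pads[3], onnx_pads[0], onnx_pads[2])  # (left, right, top, bottom)
print(tg_padding)  # (2, 4, 1, 3)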
@@ -116,7 +119,7 @@ def Expand(input, shape):
  return input.reshape(x_shape).expand(shape_ret)
def LRN(input, size, alpha=1e-4, beta=0.75, bias=1.0):
  bs, c, iy, ix = input.shape
  return input / input.mul(input).reshape(bs,1,c,iy*ix).pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta)
def Identity(input): return input


@@ -30,12 +30,12 @@ class TinygradBackend(Backend):
    print("prepare", cls, device, net_feed_input)
    run_onnx = get_run_onnx(model)
    return TinygradModel(run_onnx, net_feed_input)

  @classmethod
  def supports_device(cls, device: str) -> bool:
    return device == "CPU"

backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
# add support for SoftmaxCrossEntropyLoss and NegativeLogLikelihoodLoss
backend_test.exclude('test_sce_*')


@@ -131,14 +131,14 @@ class TestOps(unittest.TestCase):
    helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x)
  def test_sqrt(self):
    helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0)
  def test_sin(self):
    helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0)
  def test_cos(self):
    helper_test_op([(45,65)], lambda x: x.cos(), Tensor.cos, a=0)
  def test_tan(self):
    helper_test_op([(45,65)], lambda x: x.tan(), Tensor.tan, a=0)
  def test_relu(self):
    helper_test_op([(64,64)], lambda x: x.relu(), Tensor.relu)
  def test_relu_exact(self):
@@ -415,6 +415,40 @@ class TestOps(unittest.TestCase):
      lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
      lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)

  # conv transpose

  def test_simple_conv_transpose2d(self):
    helper_test_op([(2,4,9,9), (4,4,3,3)],
      lambda x,w: torch.nn.functional.conv_transpose2d(x,w).relu(),
      lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)

  def test_grouped_conv_transpose2d(self):
    helper_test_op([(2,4,9,9), (4,4,3,3)],
      lambda x,w: torch.nn.functional.conv_transpose2d(x,w,groups=2).relu(),
      lambda x,w: Tensor.conv_transpose2d(x,w,groups=2).relu(), atol=1e-4, grad_rtol=1e-5)

  def test_padded_conv_transpose2d(self):
    helper_test_op([(2,4,9,9), (4,4,3,3)],
      lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=1).relu(),
      lambda x,w: Tensor.conv_transpose2d(x,w,padding=1).relu(), atol=1e-4, grad_rtol=1e-5)

  def test_dilated_conv_transpose2d(self):
    helper_test_op([(2,4,9,9), (4,4,3,3)],
      lambda x,w: torch.nn.functional.conv_transpose2d(x,w,dilation=2).relu(),
      lambda x,w: Tensor.conv_transpose2d(x,w,dilation=2).relu(), atol=1e-4, grad_rtol=1e-5)

  @unittest.skip("not currently supported")
  def test_strided_conv_transpose2d(self):
    helper_test_op([(2,4,9,9), (4,4,3,3)],
      lambda x,w: torch.nn.functional.conv_transpose2d(x,w,stride=2).relu(),
      lambda x,w: Tensor.conv_transpose2d(x,w,stride=2).relu(), atol=1e-4, grad_rtol=1e-5)

  @unittest.skipIf(IMAGE>0, "no conv3d on images")
  def test_simple_conv_transpose3d(self):
    helper_test_op([(2,4,9,9,9), (4,4,3,3,3)],
      lambda x,w: torch.nn.functional.conv_transpose3d(x,w).relu(),
      lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)

  def test_conv2d(self):
    for bs in [1,8]:
      for cin in [1,3]:
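
As a reference for the spatial sizes these tests exercise, transposed convolution output size follows the standard formula out = (in - 1)*stride - 2*pad + dilation*(kernel - 1) + 1 (no output_padding). A quick check for the 9x9 inputs and 3x3 kernels used above (illustrative only, not part of the diff):

for name, (s, p, d) in {"simple": (1, 0, 1), "padded": (1, 1, 1), "dilated": (1, 0, 2), "strided": (2, 0, 1)}.items():
  print(name, (9 - 1)*s - 2*p + d*(3 - 1) + 1)  # simple 11, padded 9, dilated 13, strided 19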


@@ -365,6 +365,19 @@ class Tensor:
  def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
  def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))

  def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
    HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
    x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
    assert stride == 1, "stride is not supported in transposed conv"
    # TODO: stride support. i believe this is correct, but you have to SHRINK the output tensor
    #stride = make_pair(stride, len(HW))
    #if any(s>1 for s in stride):
    #  x = x.reshape(bs, cin_, *flatten((k,1) for k in x.shape[2:]))
    #  x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
    #  x = x.reshape(bs, cin_, *[k*s for k,s in zip(x.shape[2::2], stride)])
    # TODO: the make_pair on padding is wrong in the asymmetric padding case
    return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=flatten(((k-1)*d-p,(k-1)*d-p) for k,p,d in zip(HW, make_pair(padding, len(HW)), make_pair(dilation, len(HW)))))

  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
    (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
    assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
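
The commented-out stride TODO above describes the usual zero-insertion trick: a stride-s transposed convolution equals spreading the input out with s-1 zeros between elements, running a stride-1 transposed convolution, and then shrinking the trailing output rows/columns. The padding passed to conv2d, (k-1)*dilation - padding on each side, is likewise the standard relation for expressing a transposed convolution as an ordinary convolution over the flipped kernel. A minimal PyTorch sketch of the zero-insertion equivalence (illustrative only, not the tinygrad implementation):

import torch
import torch.nn.functional as F

x, w, s = torch.randn(2, 4, 9, 9), torch.randn(4, 4, 3, 3), 2
ref = F.conv_transpose2d(x, w, stride=s)       # (2, 4, 19, 19)

# insert s-1 zeros between input elements along each spatial dim
xz = torch.zeros(2, 4, 9*s, 9*s)
xz[:, :, ::s, ::s] = x

out = F.conv_transpose2d(xz, w, stride=1)      # (2, 4, 20, 20)
out = out[:, :, :ref.shape[2], :ref.shape[3]]  # SHRINK the trailing rows/cols
print(torch.allclose(ref, out, atol=1e-5))     # True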