[fixed] onnx pool cleanup (#8474)

* pool janitor duty

* actually conv allows asymmetric pads

* a little prettier
geohotstan authored on 2025-01-03 05:56:10 +08:00; committed by GitHub
parent 08c9d980dc, commit de306c615b
2 changed files with 76 additions and 101 deletions

@@ -2,7 +2,7 @@ import functools, io, math
from typing import cast, Literal
from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import prod, flatten
from tinygrad.helpers import prod, flatten, make_tuple
from extra.onnx import dtype_parse, to_python_const
import numpy as np
@@ -181,110 +181,79 @@ def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon
def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
def _format_padding(onnx_pads, ndims=None, axes=None):
if ndims and len(onnx_pads)//2 != ndims: onnx_pads = onnx_pads * ndims # in OnnxBackendPyTorchConvertedModelTest, len(onnx_pads) == 2
if ndims is None: ndims = len(onnx_pads) // 2
if axes is None: axes = list(range(ndims))
num_axes = len(axes)
np_pads = [(0,0)] * ndims
for i in range(num_axes):
np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
return np_pads
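# illustrative sanity check of the layout conversion described above (hypothetical values):
#   _format_padding([1, 2, 3, 4], ndims=2) == [(1, 3), (2, 4)]  # onnx [x1_begin, x2_begin, x1_end, x2_end] -> numpy ((x1_begin, x1_end), (x2_begin, x2_end))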
# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
def _onnx_pads_to_tiny_pads(pads): return flatten(reversed([(pB,pA) for pB, pA in zip(pads, pads[len(pads)//2:])]))
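# illustrative example: [pad_top=1, pad_left=2, pad_bottom=3, pad_right=4] -> [2, 4, 1, 3], last spatial dim first:
#   _onnx_pads_to_tiny_pads([1, 2, 3, 4]) == [2, 4, 1, 3]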
def _padded(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None, ceil_mode=0):
if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
elif ceil_mode:
if strides is not None: strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides or [1]*len(kernel_shape)
if dilations is not None: dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
out_spatial_shape = [math.ceil((sh - dil * (ker-1)-1)/st + 1) if ceil_mode else math.floor((sh - dil * (ker-1)-1)/st + 1) for sh, st, ker, dil in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)]
pad_shape = [(osh-1)*st+((ks-1)*dil+1)-ish for osh, st, ks, dil, ish in zip(out_spatial_shape, strides, kernel_shape, dilations, X.shape[-len(kernel_shape):])]
pad_shape = [[sh//2, sh-sh//2] for sh in pad_shape]
# ceil_mode case follows NOTE in https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
# so if any kernel window starts in the right padded region, we decrease the right pads to omit it. Only 1 window is omitted for now.
pad_shape = [[start,end-rpad] if (rpad := ks + st%(st-(((start+xs)%st)))) <= end else [start,end]
for (start,end), ks, st, xs in zip(pad_shape, kernel_shape, strides, X.shape[-len(kernel_shape):])]
pad_shape = flatten(pad_shape)
pads = pad_shape[::2] + pad_shape[1::2]
if pads is None: return X
pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
return X.pad(tuple(pads), value=constant_value)
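# illustrative arithmetic for the ceil_mode branch above (hypothetical values): for one spatial dim with
# sh=5, ker=2, st=2, dil=1: floor((5-1-1)/2 + 1) = 2 windows but ceil((5-1-1)/2 + 1) = 3, so a total right
# pad of (3-1)*2 + 2 - 5 = 1 is added; a window starting entirely inside that padding is dropped per the NOTE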
AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
# (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
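# illustrative example: a total pad of 3 on one axis splits as begin 1 / end 2 for SAME_UPPER and 2 / 1 for SAME_LOWER:
#   _auto_pad([3], "SAME_UPPER") == [1, 2] and _auto_pad([3], "SAME_LOWER") == [2, 1]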
def _auto_pad(X: Tensor, auto_pad, strides, kernel_shape, dilations):
strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides or [1]*len(kernel_shape)
dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)]
pad_shape = flatten([[sh//2, sh-sh//2] for sh in pad_shape])
return pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2]
raise NotImplementedError(f"auto_pad={auto_pad} not implemented")
# (x1_begin, x2_begin, ..., x1_end, x2_end, ...) -> (..., x2_begin, x2_end, x1_begin, x1_end)
def _onnx_pads_to_pad2d_pads(pads): return flatten(reversed(list((pB, pE) for pB, pE in zip(pads, pads[len(pads)//2:]))))
def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None, mode:str="constant", value=0):
value, axes = constant_value or value or 0, axes or list(range(x.ndim))
def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
value = constant_value or value
axes = axes or list(range(x.ndim))
real_pads = [0] * (x.ndim*2)
for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
return x.pad(padding=_onnx_pads_to_pad2d_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
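# illustrative example (hypothetical values): pads=[1, 1, 2, 2] on a (3, 3) tensor pads each axis by (1, 2), so
#   Pad(Tensor.ones(3, 3), pads=[1, 1, 2, 2]).shape == (6, 6)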
def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
pixel_axes = tuple(range(2, X.ndim))
ret = _padded(X, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode)
ret = ret.avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
if count_include_pad: return ret
div = _padded(Tensor.ones(X.shape), pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode).avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
return ret / div
# NOTE: the pool ops only take symmetric pads, so verify symmetry and collapse to one pad per spatial dim
def assert_symmetric_pads(pads):
# pads: (padding_left, padding_right, padding_top, padding_bottom, ...)
if not all(pB == pA for pB, pA in zip(pads[::2], pads[1::2])): raise ValueError("Pads must be symmetric")
# collapse symmetric (begin, end) pairs to one pad per spatial dim
return pads[::-2]
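# illustrative example: assert_symmetric_pads([1, 1, 2, 2]) == [2, 1] (one pad per spatial dim),
# while the asymmetric [0, 1, 0, 1] raises ValueError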
def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
pixel_axes = tuple(range(2, X.ndim))
ret = _padded(X, pads, auto_pad, constant_value=-math.inf, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode)
ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations).cast(X.dtype)
ret_len, X_len = ret.numel(), X.numel()
indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().unsqueeze(0).expand(ret_len, X_len)) * \
Tensor.arange(X_len, dtype=dtypes.int64).unsqueeze(0).expand(ret_len, X_len)).sum(1).reshape(ret.shape)
if storage_order: indices = indices.transpose(-2, -1)
return ret, indices
def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS, allow_asymmetric=False):
def _apply_assertion(pads): return pads if allow_asymmetric else assert_symmetric_pads(pads)
i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
if auto_pad == "NOTSET": return _apply_assertion(_onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2))
o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
return _apply_assertion(_onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad)))
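# illustrative example (hypothetical values): a (1, 1, 5, 5) input with 3x3 kernel, stride 1 and SAME_UPPER
# needs a total pad of 2 per axis, split symmetrically into (1, 1), so the collapsed per-dim pads come back:
#   _resolve_pool_pads(Tensor.ones(1, 1, 5, 5), 0, (3, 3), 1, 1, "SAME_UPPER") == [1, 1]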
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Tensor|None=None, kernel_shape=None, pads=None, strides=None):
def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
return X.avg_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
storage_order:int=0, strides:list[int]|int=1):
pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
ret = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode)
# tests expect indices with int64 dtype
# TODO: if there are repeated values, this is wrong
indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices
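# illustrative example: MaxPool(Tensor([[[[1., 2.], [3., 4.]]]]), kernel_shape=[2, 2]) returns the value
# [[[[4.]]]] together with its flat row-major index [[[[3]]]]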
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
pads = _resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad, allow_asymmetric=True)
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=tuple(pads))
# src: https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose
# another src: https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_conv_transpose.py
def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
strides:list[int]|int=1):
input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:])
strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding))
if output_shape is not None: # we pad according to output_shape
pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in
zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad)
if pads is None: # we generate pads
output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))]
pads = [strides[i]*(input_shape[i]-1) + output_padding[i] + ((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))]
pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2
pads = assert_symmetric_pads(_onnx_pads_to_tiny_pads(pads))
return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding)
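# illustrative shape check (hypothetical values): X of (1, 1, 3, 3), W of (1, 1, 3, 3), strides=2 and
# auto_pad NOTSET leave pads at 0, so the output is (1, 1, (3-1)*2 + 3, (3-1)*2 + 3) == (1, 1, 7, 7)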
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides))
out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
outlength = prod(out_sh)
xI = xI.flatten().unsqueeze(1).expand(None, outlength)
arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape)
xT = xT.flatten().unsqueeze(1).expand(None, outlength)
ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
if outshape is not None and outshape != ret.shape:
diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2]
ret = ret.pad((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
return ret
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
if auto_pad != "NOTSET":
padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
else:
# reorder padding
padding = [p for ps in zip(pads[:len(pads)//2][::-1], pads[len(pads)//2:][::-1]) for p in ps] if pads is not None else 0
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)
def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1):
if kernel_shape is None: kernel_shape = W.shape[2:]
if isinstance(strides, int): strides = [strides]*(W.ndim-2)
if isinstance(dilations, int): dilations = [dilations]*(W.ndim-2)
if isinstance(output_padding, int): output_padding = [output_padding]*(W.ndim-2)
out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))] if output_shape is not None or auto_pad != "NOTSET" else []
if pads is None:
if output_shape is None: output_shape = [xs*st for xs, st in zip(X.shape[2:], strides)]
if auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2)
else:
total_padding = [st*(ish-1) + pad + ((ks-1)*dil+1)-osh for st, ish, pad, ks, dil, osh in zip(strides, X.shape[2:], output_padding, kernel_shape, dilations, output_shape)]
pad_shape = flatten([[sh//2, sh-sh//2] for sh in total_padding])
pads = pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2]
else:
if output_shape is None: output_shape = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))]
if out_sh: output_padding = [os - rs for os, rs in zip(output_shape, out_sh)]
return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)
ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh)
if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER")
return ret.pad(_onnx_pads_to_tiny_pads(pads))
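# illustrative example (hypothetical values): each xT value is scattered to the flat position given by xI,
# so with a 2x2 kernel a single value indexed 4 lands at the center of the 3x3 output:
#   MaxUnpool(Tensor([[[[7.]]]]), Tensor([[[[4]]]]), kernel_shape=[2, 2])[0, 0, 1, 1].item() == 7.0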
def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)

@@ -40,11 +40,9 @@ class TinygradBackend(Backend):
backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
# TODO: there isn't an AttributeProto for `epsilon` in the NodeProto for 'test_adam_multiple_cpu'
# [x.name for x in n.attribute] -> ['alpha', 'beta', 'norm_coefficient']
# but in their documentation https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-176, it states an epsilon of 1e-2
# test passes with epsilon = 1e-2
# BUG: buggy onnx tests
backend_test.exclude('test_adam_multiple_cpu')
backend_test.exclude('test_averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_True_cpu')
# about different dtypes
if not is_dtype_supported(dtypes.float64):
@@ -177,6 +175,14 @@ backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
backend_test.exclude('test_scatter_elements_with_reduction_max_cpu') # max not yet supported
backend_test.exclude('test_scatternd_max_cpu') # max not yet supported
# asymmetric pads
backend_test.exclude('test_averagepool_2d_same_lower_cpu')
backend_test.exclude('test_averagepool_2d_same_upper_cpu')
backend_test.exclude('test_convtranspose_autopad_same_cpu')
backend_test.exclude('test_convtranspose_output_shape_cpu')
backend_test.exclude('test_maxpool_2d_same_lower_cpu')
backend_test.exclude('test_maxpool_2d_same_upper_cpu')
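# e.g. a 2x2 SAME_UPPER pool with stride 1 on a 5x5 input needs a total pad of 1 per axis, split (0, 1):
# asymmetric, so assert_symmetric_pads rejects it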
if Device.DEFAULT in ['GPU', 'METAL']:
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_2_3_cpu')
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_3_2_cpu')