fix onnx mobilenetv2-7-quantized.onnx (#8574)

* is 67% considered fixed?

* move test up

* share function

* add qgemm too

* make sure qgemm comes out as int

* actually that note is not right

* remove qgemm (I did it wrong) and add it later lol.
Author: geohotstan
Date: 2025-01-14 01:25:06 +08:00
Committed by: GitHub
Parent: d19c1c7f03
Commit: 4abe631b56
3 changed files with 18 additions and 13 deletions


@@ -11,10 +11,10 @@ from tinygrad.helpers import fetch, getenv
 # ~43% - https://github.com/onnx/models/raw/refs/heads/main/Computer_Vision/alexnet_Opset16_torch_hub/alexnet_Opset16.onnx
 # ~72% - https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
 # ~71% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenetv2_1.0.opt.onnx
+# ~67% - https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7-quantized.onnx
 # broken:
 # https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
 # https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx
-# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7-quantized.onnx
 # ~35% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenev2_quantized.onnx
 def imagenet_dataloader(cnt=0):
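The percentages above are rough ImageNet top-1 accuracies; with this change the quantized mobilenetv2-7 model moves from the broken list to ~67%. As an illustration of what such a figure means, here is a minimal accuracy loop. It is a hypothetical sketch only: `run_model` is a made-up stand-in for whatever callable executes the fetched ONNX model, and `imagenet_dataloader` is assumed to yield (image, label) pairs.

# Hypothetical sketch, not code from this repo.
def top1_accuracy(run_model, cnt=1000):
  correct = 0
  for img, label in imagenet_dataloader(cnt=cnt):
    pred = run_model(img).argmax()      # index of the highest logit
    correct += int(pred == label)
  return correct / cnt                  # e.g. ~0.67 for mobilenetv2-7-quantized after this fix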


@@ -424,23 +424,27 @@ def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:i
   x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
   return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
+def _quantize_linear(y:Tensor, y_scale:Tensor, y_zero_point:Tensor):
+  assert y_scale.dtype is dtypes.float32 and y_zero_point.dtype in {dtypes.uint8, dtypes.int8}, "used only for qlinear ops"
+  y = (y / y_scale + y_zero_point).round()
+  return y.clamp(dtypes.min(y_zero_point.dtype), dtypes.max(y_zero_point.dtype)).cast(y_zero_point.dtype)
 def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
                 y_zero_point: Tensor|int, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:int|list[int]=1, group:int=1,
                 kernel_shape:list[int]|None=None, pads:int|list[int]=0, strides:int|list[int]=1):
   x = x.int() - x_zero_point
   w = w.int() - w_zero_point
   y = Conv(x, w, B, auto_pad, dilations, group, kernel_shape, pads, strides)
-  y = ((y * (x_scale * w_scale / y_scale)) + y_zero_point).round()
-  return y.cast(y_zero_point.dtype)
+  y_scale = y_scale / (x_scale * w_scale)
+  return _quantize_linear(y, y_scale, y_zero_point)
 def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
                   y_zero_point:Tensor|int) -> Tensor:
   a = a.int() - a_zero_point
   b = b.int() - b_zero_point
   y = Tensor.matmul(a, b, acc_dtype=dtypes.int32)
-  y = ((y * (a_scale * b_scale / y_scale)) + y_zero_point).round()
-  # cast to int first because result expects overflow/underflow wrap around
-  return y.int().cast(y_zero_point.dtype)
+  y_scale = y_scale / (a_scale * b_scale)
+  return _quantize_linear(y, y_scale, y_zero_point)
 def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None,
                 auto_pad: AUTO_PAD_OPTIONS = "NOTSET", dilations: int | list[int] = 1, group: int = 1, kernel_shape: list[int] | None = None,
@@ -552,16 +556,14 @@ def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:
   a = a.int() - a_zero_point
   b = b.int() - b_zero_point
   c = (a * a_scale + b * b_scale)
-  c = ((c / c_scale) + c_zero_point).round()
-  return c.cast(c_zero_point.dtype)
+  return _quantize_linear(c, c_scale, c_zero_point)
 def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int):
   assert channels_last in {0, 1}
   if channels_last == 1: X = X.permute(0, 2, 3, 1)
   X = (X.int() - x_zero_point) * x_scale
   y = GlobalAveragePool(X)
-  y = (y / y_scale + y_zero_point).round()
-  return y.cast(y_zero_point.dtype)
+  return _quantize_linear(y, y_scale, y_zero_point)
 # **************** ai.onnx.preview.training Ops ****************
 # NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested
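The shared `_quantize_linear` helper keeps the math of the per-op versions: for QLinearConv and QLinearMatMul the int32 accumulator `acc` represents `acc * x_scale * w_scale` in real terms, so quantizing it against the output scale `y_scale` is the same as calling the helper with the folded scale `y_scale / (x_scale * w_scale)`. The helper also saturates to the zero-point dtype's range, which the old per-op `cast` did not. A standalone numpy check of that equivalence (illustrative only, not tinygrad code; the parameter values are made up):

import numpy as np

rng = np.random.default_rng(0)
acc = rng.integers(-512, 512, size=1000).astype(np.int32)   # int32 conv/matmul accumulator
x_scale, w_scale, y_scale, y_zp = 0.02, 0.003, 0.05, 128    # made-up uint8 quantization params

# old formulation: rescale the accumulator against y_scale directly, then round
old = np.clip(np.rint(acc * (x_scale * w_scale / y_scale) + y_zp), 0, 255).astype(np.uint8)

# new formulation: fold the input scales into y_scale and reuse one quantize helper
def quantize_linear(y, scale, zp, lo=0, hi=255):
  return np.clip(np.rint(y / scale + zp), lo, hi).astype(np.uint8)

new = quantize_linear(acc, y_scale / (x_scale * w_scale), y_zp)
assert (old == new).all()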


@@ -44,6 +44,12 @@ backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
 backend_test.exclude('test_adam_multiple_cpu')
 backend_test.exclude('test_averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_True_cpu')
+# BUG: onnxruntime 1.20.1 fails these tests too
+backend_test.exclude('test_qlinearmatmul_2D_int8_float16_cpu')
+backend_test.exclude('test_qlinearmatmul_3D_int8_float16_cpu')
+backend_test.exclude('test_qlinearmatmul_2D_int8_float32_cpu')
+backend_test.exclude('test_qlinearmatmul_3D_int8_float32_cpu')
 # about different dtypes
 if not is_dtype_supported(dtypes.float64):
   backend_test.exclude('float64')
@@ -86,9 +92,6 @@ backend_test.exclude('test_dequantizelinear_e4m3fn_cpu')
 backend_test.exclude('test_dequantizelinear_e4m3fn_zero_point_cpu')
 backend_test.exclude('test_dequantizelinear_e5m2_cpu')
-# TODO: unsure. The number is off by 1. and it's not because of float16
-backend_test.exclude('test_qlinearmatmul_3D_int8_float16_cpu')
 # we don't support indexes
 backend_test.exclude('test_nonzero_*')
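The four new exclusions lean on the `# BUG` note above: onnxruntime 1.20.1 reportedly fails the same qlinearmatmul node tests. A rough way to reproduce that comparison locally is to feed the stored node-test data through onnxruntime; the data-directory layout below is an assumption about how the onnx package ships its backend tests, so treat the paths as a sketch rather than a guaranteed recipe.

import glob, os
import numpy as np
import onnx
from onnx import numpy_helper
import onnxruntime as ort

case = "test_qlinearmatmul_2D_int8_float16"   # backend test name without the "_cpu" suffix
base = os.path.join(os.path.dirname(onnx.__file__), "backend", "test", "data", "node", case)

sess = ort.InferenceSession(os.path.join(base, "model.onnx"), providers=["CPUExecutionProvider"])

def load_pb(path):
  t = onnx.TensorProto()
  with open(path, "rb") as f: t.ParseFromString(f.read())
  return numpy_helper.to_array(t)

data = os.path.join(base, "test_data_set_0")
inputs = [load_pb(p) for p in sorted(glob.glob(os.path.join(data, "input_*.pb")))]
expected = [load_pb(p) for p in sorted(glob.glob(os.path.join(data, "output_*.pb")))]

feed = {i.name: arr for i, arr in zip(sess.get_inputs(), inputs)}
got = sess.run(None, feed)
for out, want in zip(got, expected):
  np.testing.assert_allclose(out, want)   # expected to fail on onnxruntime 1.20.1 per the comment above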