add onnx DequantizeLinear (#8468)

* is this right?

* small changes

* dont support float8

* mergeable?
Author: geohotstan
Date: 2025-01-02 22:52:49 +08:00 (committed by GitHub)
Parent: f2bee34197
Commit: c4b13e2f6d
3 changed files with 27 additions and 11 deletions


@@ -36,8 +36,7 @@ DTYPE_MAP: dict[int, DType] = {
   TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
   TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
   TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
-  TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16, TensorProto.FLOAT8E4M3FN:dtypes.float,
-  TensorProto.FLOAT8E4M3FNUZ:dtypes.float, TensorProto.FLOAT8E5M2:dtypes.float, TensorProto.FLOAT8E5M2FNUZ:dtypes.float
+  TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
 }
 def dtype_parse(onnx_dtype: int) -> DType:
   if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
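
Note: the FLOAT8 entries used to alias to dtypes.float (i.e. float32), so dropping them means float8 models now fail loudly in dtype_parse instead of silently losing their dtype. A standalone sketch of the new behavior (the trimmed map here is illustrative, not the full table above):

from onnx import TensorProto

# illustrative stand-in for DTYPE_MAP above; float8 entries intentionally absent
SUPPORTED = {TensorProto.FLOAT: "float32", TensorProto.BFLOAT16: "bfloat16"}

def dtype_parse(onnx_dtype: int) -> str:
  if onnx_dtype not in SUPPORTED:
    raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
  return SUPPORTED[onnx_dtype]

print(dtype_parse(TensorProto.FLOAT))       # float32
try: dtype_parse(TensorProto.FLOAT8E4M3FN)  # previously aliased to float32, now raises
except NotImplementedError as e: print(e)   # onnx dtype FLOAT8E4M3FN is not supported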


@@ -25,7 +25,6 @@ def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
 def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
 def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
 def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
-# NOTE: does not support saturate
 def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
 def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
@@ -448,14 +447,22 @@ def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
 def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated
-def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int = 0, axis:int=1, block_size:int=0):
+def _prepare_quantize_linear(x, scale, zero_point, axis, block_size):
   if axis < 0: axis += x.ndim
-  if not isinstance(x_zero_point, Tensor): x_zero_point = Tensor(x_zero_point)
-  if block_size: x_zer, x_sc = x_zero_point.repeat_interleave(block_size, axis), x_scale.repeat_interleave(block_size, axis)
-  else:
-    shape = (*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
-    x_sc, x_zer = x_scale.reshape(shape), x_zero_point.reshape(shape)
-  return ((x.float() - x_zer) * x_sc).cast(x_scale.dtype)
+  if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
+  if block_size == 0:
+    shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
+    return scale.reshape(shape), zero_point.reshape(shape)
+  return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)
+def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
+  out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
+  y_scale, y_zero_point = _prepare_quantize_linear(x, y_scale, y_zero_point, axis, block_size)
+  return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype)
+def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
+  x_scale, x_zero_point = _prepare_quantize_linear(x, x_scale, x_zero_point, axis, block_size)
+  return ((x.float() - x_zero_point) * x_scale).cast(x_scale.dtype)
 # copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
 def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
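
The refactor preserves the ONNX semantics: QuantizeLinear computes saturate(round(x / scale) + zero_point) in the target integer dtype, and DequantizeLinear inverts it as (x - zero_point) * scale, with scale/zero_point broadcast along axis (or repeated per block when block_size > 0). A minimal numpy round-trip sketch of the per-axis uint8 case (numpy stands in for tinygrad's Tensor; the helper names are hypothetical, not part of this diff):

import numpy as np

def quantize_linear(x, scale, zero_point, axis=1):
  # one scale/zero_point per slice along `axis`, broadcast like the reshape in _prepare_quantize_linear
  shape = [1] * x.ndim; shape[axis] = -1
  q = np.round(x / scale.reshape(shape)) + zero_point.reshape(shape)
  return np.clip(q, 0, 255).astype(np.uint8)  # saturate to the uint8 range

def dequantize_linear(q, scale, zero_point, axis=1):
  shape = [1] * q.ndim; shape[axis] = -1
  return (q.astype(np.float32) - zero_point.reshape(shape)) * scale.reshape(shape)

x = np.array([[-1.0, 0.0, 0.5], [0.25, 1.0, 2.0]], dtype=np.float32)
scale = np.array([0.01, 0.02, 0.05], dtype=np.float32)  # per-column scales (axis=1)
zero_point = np.array([128, 128, 128], dtype=np.uint8)

q = quantize_linear(x, scale, zero_point)        # [[ 28 128 138] [153 178 168]]
x_hat = dequantize_linear(q, scale, zero_point)  # recovers x up to rounding error
assert np.allclose(x, x_hat, atol=float(scale.max()) / 2 + 1e-6)

Each value is recovered to within half a quantization step (|x - x_hat| <= scale/2 per column); with block_size > 0 the scale would instead be repeated so every contiguous block of block_size elements along axis shares one scale, as in the repeat_interleave branch above.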


@@ -79,6 +79,17 @@ backend_test.exclude('test_matmulinteger_*')
 backend_test.exclude('test_dequantizelinear_int4_cpu')
 backend_test.exclude('test_dequantizelinear_uint4_cpu')
+backend_test.exclude('test_quantizelinear_int4_cpu')
+backend_test.exclude('test_quantizelinear_uint4_cpu')
+# no support for FLOAT8
+backend_test.exclude('test_quantizelinear_e4m3fn_cpu')
+backend_test.exclude('test_quantizelinear_e5m2_cpu')
+backend_test.exclude('test_dequantizelinear_e4m3fn_cpu')
+backend_test.exclude('test_dequantizelinear_e4m3fn_zero_point_cpu')
+backend_test.exclude('test_dequantizelinear_e5m2_cpu')
 # we don't support indexes
 backend_test.exclude('test_nonzero_*')
@@ -101,7 +112,6 @@ backend_test.exclude('test_regex_*')
 backend_test.exclude('test_dynamicquantizelinear_*')
 backend_test.exclude('test_qlinearmatmul_*')
 backend_test.exclude('test_qlinearconv_*')
-backend_test.exclude('test_quantizelinear_*')
 # no rnn
 backend_test.exclude('test_gru_*')