add onnx DequantizeLinear (#8468)
* is this right?
* small changes
* dont support float8
* mergeable?
@@ -36,8 +36,7 @@ DTYPE_MAP: dict[int, DType] = {
   TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
   TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
   TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
-  TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16, TensorProto.FLOAT8E4M3FN:dtypes.float,
-  TensorProto.FLOAT8E4M3FNUZ:dtypes.float, TensorProto.FLOAT8E5M2:dtypes.float, TensorProto.FLOAT8E5M2FNUZ:dtypes.float
+  TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
 }
 def dtype_parse(onnx_dtype: int) -> DType:
   if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
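With the FLOAT8E* entries gone from DTYPE_MAP, a model carrying float8 tensors now fails loudly in dtype_parse instead of silently decoding as float32. A minimal sketch of the new behavior (illustrative only; assumes dtype_parse and DTYPE_MAP from the patched module are in scope):

from onnx import TensorProto

# float8 no longer maps to anything, so dtype_parse raises
try: dtype_parse(TensorProto.FLOAT8E4M3FN)
except NotImplementedError as e: print(e)  # onnx dtype FLOAT8E4M3FN is not supported
assert TensorProto.BFLOAT16 in DTYPE_MAP   # bfloat16 support is unchanged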
@@ -25,7 +25,6 @@ def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
 def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
 def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
 def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
-# NOTE: does not support saturate
 def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
 def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
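The standalone "does not support saturate" NOTE could go because, per the ONNX spec, saturate only changes behavior when casting to float8 types, and this PR drops float8 from DTYPE_MAP entirely; for every remaining dtype, ignoring the flag matches the spec. Illustrative usage (my example, not from the PR; assumes tinygrad is importable and Cast/CastLike are in scope):

from onnx import TensorProto
from tinygrad import Tensor

t = Tensor([1.9, -2.9])
print(Cast(t, TensorProto.INT32, saturate=0).numpy())  # [ 1 -2], saturate ignored
print(CastLike(t, Tensor([True])).numpy())             # [ True  True], dtype taken from target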
@@ -448,14 +447,22 @@ def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
 
 def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated
 
-def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int = 0, axis:int=1, block_size:int=0):
+def _prepare_quantize_linear(x, scale, zero_point, axis, block_size):
   if axis < 0: axis += x.ndim
-  if not isinstance(x_zero_point, Tensor): x_zero_point = Tensor(x_zero_point)
-  if block_size: x_zer, x_sc = x_zero_point.repeat_interleave(block_size, axis), x_scale.repeat_interleave(block_size, axis)
-  else:
-    shape = (*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
-    x_sc, x_zer = x_scale.reshape(shape), x_zero_point.reshape(shape)
-  return ((x.float() - x_zer) * x_sc).cast(x_scale.dtype)
+  if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
+  if block_size == 0:
+    shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
+    return scale.reshape(shape), zero_point.reshape(shape)
+  return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)
+
+def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
+  out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
+  y_scale, y_zero_point = _prepare_quantize_linear(x, y_scale, y_zero_point, axis, block_size)
+  return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype)
+
+def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
+  x_scale, x_zero_point = _prepare_quantize_linear(x, x_scale, x_zero_point, axis, block_size)
+  return ((x.float() - x_zero_point) * x_scale).cast(x_scale.dtype)
 
 # copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
 def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
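ONNX defines the pair as y = saturate(round(x / scale) + zero_point) for QuantizeLinear and y = (x - zero_point) * scale for DequantizeLinear; the shared _prepare_quantize_linear helper just broadcasts scale and zero_point against x along axis, or expands them block-wise when block_size is set. A round-trip sketch with made-up values (mine, not from the PR; assumes both ops are in scope):

from tinygrad import Tensor, dtypes

x = Tensor([[0.0, 1.0, 2.0], [0.0, 10.0, 20.0]])
scale = Tensor([0.01, 0.1])              # one scale per row -> per-axis mode on axis=0
zp = Tensor([0, 0], dtype=dtypes.uint8)  # zero_point's dtype selects the output dtype

q = QuantizeLinear(x, scale, zp, axis=0)    # uint8: [[0 100 200] [0 100 200]]
r = DequantizeLinear(q, scale, zp, axis=0)  # float: [[0. 1. 2.] [0. 10. 20.]]
print(q.numpy(), r.numpy())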
test/external/external_test_onnx_backend.py (vendored): 12 changes
@@ -79,6 +79,17 @@ backend_test.exclude('test_matmulinteger_*')
 
+backend_test.exclude('test_dequantizelinear_int4_cpu')
+backend_test.exclude('test_dequantizelinear_uint4_cpu')
+backend_test.exclude('test_quantizelinear_int4_cpu')
+backend_test.exclude('test_quantizelinear_uint4_cpu')
+
+# no support for FLOAT8
+backend_test.exclude('test_quantizelinear_e4m3fn_cpu')
+backend_test.exclude('test_quantizelinear_e5m2_cpu')
+backend_test.exclude('test_dequantizelinear_e4m3fn_cpu')
+backend_test.exclude('test_dequantizelinear_e4m3fn_zero_point_cpu')
+backend_test.exclude('test_dequantizelinear_e5m2_cpu')
 
 # we don't support indexes
 backend_test.exclude('test_nonzero_*')
@@ -101,7 +112,6 @@ backend_test.exclude('test_regex_*')
 backend_test.exclude('test_dynamicquantizelinear_*')
 backend_test.exclude('test_qlinearmatmul_*')
 backend_test.exclude('test_qlinearconv_*')
-backend_test.exclude('test_quantizelinear_*')
 
 # no rnn
 backend_test.exclude('test_gru_*')
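These excludes sit in onnx's stock backend-test harness: the int4/uint4 and float8 cases are skipped because those dtypes have no DTYPE_MAP entry (see the first hunk), while the old blanket test_quantizelinear_* exclude disappears now that the op is implemented. For orientation, the usual shape of such a test file (a sketch only; TinygradBackend stands in for whatever Backend subclass the real file defines):

import onnx.backend.test

# collect the stock ONNX test cases, then carve out what the backend can't run
backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
backend_test.exclude('test_quantizelinear_int4_cpu')  # no 4-bit dtypes in tinygrad
globals().update(backend_test.enable_report().test_cases)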