Bump onnx to 1.18.0 (#11266)

* bump

* thou hast implemented functions

* hacked in domain support

* some clean ups

* hack quantize_onnx_test too

* add helper lol, why onnx tests why

* better dispatcher, but need tests and better naming

* flaky ci

* change some names

* small clean ups

* make it easier to clean up tests once ORT supports 1.18.0

* nits

* fix bug of Softmax_1 being registered in onnx_ops

* need a default value

* resolve_const is better name

* fix OnnxRunner.to

* use proper domain names
geohotstan
2025-07-18 03:35:41 +08:00
committed by GitHub
parent 1606491b1c
commit 536b254df4
5 changed files with 156 additions and 45 deletions

View File

@@ -1,8 +1,8 @@
 from types import SimpleNamespace
-from typing import Any, Sequence, cast, Literal, Callable
-import dataclasses, functools, io, math, types, warnings, pathlib, sys
+from typing import Any, Sequence, cast, Literal, Callable, get_args, NamedTuple
+import dataclasses, functools, io, math, types, warnings, pathlib, sys, enum
 from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
-from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray
+from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray, get_single_element
 from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype
 from tinygrad.device import is_dtype_supported, Device
 from extra.onnx_parser import onnx_load
@@ -95,10 +95,24 @@ class OnnxValue:
   is_optional: bool
   is_sequence: bool
 
+class Domain(enum.StrEnum):
+  ONNX = "ai.onnx"
+  ONNX_ML = "ai.onnx.ml"
+  AI_ONNX_TRAINING = "ai.onnx.training"
+  AI_ONNX_PREVIEW_TRAINING = "ai.onnx.preview.training"
+  MICROSOFT_CONTRIB_OPS = "com.microsoft"
+  @classmethod
+  def from_onnx(cls, domain: str | None) -> "Domain": return cls.ONNX if domain is None or domain == "" else cls(domain)
+
+class OpSetId(NamedTuple):
+  domain: Domain
+  version: int
+
 @dataclasses.dataclass(frozen=True)
 class OnnxNode:
   num: int
   op: str
+  opset_id: OpSetId
   inputs: tuple[str, ...]
   outputs: tuple[str, ...]
   opts: dict[str, Any]
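
Note: an empty or missing domain string in an ONNX model is shorthand for the default `ai.onnx` domain, which is what `from_onnx` normalizes. A quick illustration of how the new types behave (illustrative only, mirroring the definitions above):

    Domain.from_onnx("")              # -> Domain.ONNX, the default domain
    Domain.from_onnx("com.microsoft") # -> Domain.MICROSOFT_CONTRIB_OPS
    OpSetId(Domain.ONNX, 13)          # hashable key for the per-op version tables below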
@@ -141,15 +155,19 @@ class OnnxRunner:
   """
   def __init__(self, model_path: Tensor | str | pathlib.Path):
     model = onnx_load(model_path)
-    self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node)
+    self.is_training = any(n.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in model.graph.node)
     self.old_training = Tensor.training
     Tensor.training = True if self.is_training else False
     self.graph_values = {"": None, **{x.name:buffer_parse(x) for x in model.graph.initializer}}
     self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
     self.graph_outputs = tuple(x.name for x in model.graph.output)
-    self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute})
-                             for num,n in enumerate(model.graph.node))
-    self.opset_version = model.opset_import[0].version
+    opset_imports = {Domain.from_onnx(getattr(x, "domain", "")):x.version for x in model.opset_import}
+    self.graph_nodes = []
+    for num, n in enumerate(model.graph.node):
+      domain = Domain.from_onnx(n.domain)
+      opset_id = OpSetId(domain, opset_imports.get(domain, 1))
+      self.graph_nodes.append(OnnxNode(num, n.op_type, opset_id, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute}))
+    self.graph_nodes = tuple(self.graph_nodes)
     self.variable_dims: dict[str, int] = {}
     self.onnx_ops = onnx_ops
@@ -171,23 +189,22 @@ class OnnxRunner:
       if user_dim_input != onnx_dim: raise RuntimeError(f"input {name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")
     return tensor
 
-  def _dispatch_op(self, op, inps, opts):
-    if op in self.onnx_ops:
-      fxn = self.onnx_ops[op]
-      if isinstance(fxn, dict):
-        for k in sorted(fxn.keys()):
-          if k <= self.opset_version: real_fxn = fxn[k]
-      else: real_fxn = fxn
-      return real_fxn(*inps, **opts)
-    raise NotImplementedError(f"{op=} not supported")
+  def _select_op(self, op:str, required_opset:OpSetId) -> types.FunctionType:
+    if op not in self.onnx_ops: raise NotImplementedError(f"{op=} is not supported")
+    # return default implementation if no opset_id is specified
+    if isinstance(impl := self.onnx_ops[op], types.FunctionType): return impl
+    # match domain and select implementation with latest compatible version
+    eligible_ops = {impl_opset.version:impl_fxn for impl_opset,impl_fxn in impl.items()
+                    if impl_opset.domain == required_opset.domain and impl_opset.version <= required_opset.version}
+    if not eligible_ops: raise NotImplementedError(f"{op=} is not supported for domain {required_opset.domain} and version {required_opset.version}")
+    return eligible_ops[max(eligible_ops.keys())]
 
   def get_empty_input_data(self, device:str|None=None, dtype:DType|None=None) -> dict[str, Tensor]:
     return {name:Tensor.empty(*spec.shape, device=device, dtype=dtype or spec.dtype) for name, spec in self.graph_inputs.items()}
 
   def to(self, device:str|None):
     self.graph_values = {k:v.to(device) if isinstance(v, Tensor) else v for k,v in self.graph_values.items()}
-    self.graph_nodes = tuple(OnnxNode(n.num, n.op, tuple(n.inputs), tuple(n.outputs),
+    self.graph_nodes = tuple(OnnxNode(n.num, n.op, n.opset_id, tuple(n.inputs), tuple(n.outputs),
                              {k:v.to(device) if isinstance(v, Tensor) else v for k,v in n.opts.items()}) for n in self.graph_nodes)
     return self
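
The old `_dispatch_op` scanned versions in ascending order, kept the last one at or below the model's single global `opset_version`, and keyed versions by bare ints, so the domain was never consulted. `_select_op` filters by domain first, then picks the latest compatible version. A hedged sketch of the selection rule, using the `Softmax` table defined later in this diff (the model path is hypothetical):

    runner = OnnxRunner("model.onnx")  # assume the model imports ai.onnx opset 11
    fxn = runner._select_op("Softmax", OpSetId(Domain.ONNX, 11))
    # eligible versions: {1} (13 > 11 is filtered out) -> softmax_1, default axis=1
    # with OpSetId(Domain.ONNX, 13) the pick would be softmax_13, default axis=-1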
@@ -206,7 +223,7 @@ class OnnxRunner:
       if debug >= 1: print(f"{node.num}: op '{node.op}' opt {opts}")
       if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps)))
 
-      ret = self._dispatch_op(node.op, inps, opts)
+      ret = self._select_op(node.op, node.opset_id)(*inps, **opts)
       ret = ret if isinstance(ret, tuple) else (ret,)
 
       if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret)))
@@ -221,8 +238,10 @@ class OnnxRunner:
####################
##### ONNX OPS #####
####################
def get_onnx_ops():
def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionType]]:
# ***** helper functions *****
def _resolve_const(x: Sequence[ConstType]|ConstType): return x if isinstance(x, get_args(ConstType)) else get_single_element(x)
def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
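
`_resolve_const` gives ops a single way to accept values that may arrive either as a scalar or as a rank-1 input holding exactly one element; `get_single_element` raises on anything longer. Roughly (illustrative values):

    _resolve_const(3)      # -> 3, already a ConstType scalar
    _resolve_const([3])    # -> 3, single-element sequence unwrapped
    _resolve_const([1, 2]) # raises: neither a scalar nor a single element

This is what lets Range, CumSum, Trilu, TopK, and OneHot below drop their ad-hoc `isinstance(x, list)` checks.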
@@ -293,7 +312,8 @@ def get_onnx_ops():
     if value_string is not None or value_strings is not None and sparse_value is not None:
       raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
 
-  def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)
+  def Range(start:float|int|list[float|int], limit:float|int|list[float|int], delta:float|int|list[float|int]):
+    return Tensor.arange(start=_resolve_const(start), stop=_resolve_const(limit), step=_resolve_const(delta))
 
   def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
     try: import PIL.Image
@@ -324,9 +344,9 @@ def get_onnx_ops():
   def IsInf(x:Tensor, detect_negative:int=1, detect_positive:int=1): return x.isinf(bool(detect_positive), bool(detect_negative))
 
   # ***** Unary Ops (activation) *****
-  def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
-  def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
-  Softmax = {1:Softmax_1, 13:Softmax_13}
+  def softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
+  def softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
+  Softmax = {OpSetId(Domain.ONNX, 1):softmax_1, OpSetId(Domain.ONNX, 13):softmax_13}
   def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
   def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
   def BiasGelu(x: Tensor, bias: Tensor, approximate: str | None = None) -> Tensor: return Gelu(x + bias, approximate)
@@ -481,14 +501,16 @@ def get_onnx_ops():
   def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)
 
-  def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
-    axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
+  def CumSum(X:Tensor, axis:int|list[int], exclusive:int=0, reverse:int=0):
+    axis = X._resolve_dim(_resolve_const(axis))
     if reverse: X = X.flip(axis)
     if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
                        .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
     return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)
 
-  def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)
+  def Trilu(x:Tensor, k:int|list[int]=0, upper:int=1):
+    k_ = _resolve_const(k)
+    return x.triu(k_) if upper else x.tril(k_)
 
   def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
              axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
@@ -550,7 +572,7 @@ def get_onnx_ops():
   def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated
 
   def TopK(X:Tensor, K:int|list[int], axis:int=-1, largest:int=1, sorted:int=1): # noqa: A002
-    val, idx = X.topk(K if isinstance(K, int) else K[0], axis, largest, sorted)
+    val, idx = X.topk(_resolve_const(K), axis, largest, sorted)
     return val, idx.cast(dtypes.int64)
 
   # ***** Neural Network Ops *****
@@ -612,9 +634,9 @@ def get_onnx_ops():
   def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
     return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
 
-  def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
+  def OneHot(indices:Tensor, depth:float|int|list[int|float], values:Tensor, axis:int=-1):
     # Scalar or Rank 1 tensor containing exactly one element
-    depth = int(depth[0] if isinstance(depth, list) else depth)
+    depth = int(_resolve_const(depth))
     indices = indices.int()
     indices = (indices < 0).where(indices+depth, indices)
     return indices.unsqueeze(axis)._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
@@ -625,7 +647,7 @@ def get_onnx_ops():
     return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)
 
   # Reimplemented here because you need legacy RNG for passing ONNX tests.
-  def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
+  def dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
     if not training_mode: return data, data.full_like(True, dtype=dtypes.bool)
     if seed is not None:
       rand = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)), requires_grad=False, dtype=data.dtype, device=data.device)
@@ -634,8 +656,8 @@ def get_onnx_ops():
     mask = rand >= ratio
     return data * mask / (1.0 - ratio), mask
   # 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
-  def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
-  Dropout = {6:Dropout_6, 7:Dropout_7}
+  def dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return dropout_7(data, ratio, training_mode=not is_test)
+  Dropout = {OpSetId(Domain.ONNX, 6):dropout_6, OpSetId(Domain.ONNX, 7):dropout_7}
 
   def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
     pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
@@ -657,10 +679,10 @@ def get_onnx_ops():
     base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
     return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
 
-  def Attention(x:Tensor, weights:Tensor, bias:Tensor|None=None, mask_index:Tensor|None=None, past:Tensor|None=None, attention_bias:Tensor|None=None,
-                past_sequence_length:Tensor|None=None, do_rotary:int=0, mask_filter_value:float=-10000.0, num_heads:int|None=None,
-                past_present_share_buffer:int|None=None, qkv_hidden_sizes:list[int]|None=None, rotary_embedding_dim:int|None=None,
-                scale:float|None=None, unidirectional:int=0):
+  def attention_contrib(x:Tensor, weights:Tensor, bias:Tensor|None=None, mask_index:Tensor|None=None, past:Tensor|None=None,
+                        attention_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int=0, mask_filter_value:float=-10000.0,
+                        num_heads:int|None=None, past_present_share_buffer:int|None=None, qkv_hidden_sizes:list[int]|None=None,
+                        rotary_embedding_dim:int|None=None, scale:float|None=None, unidirectional:int=0):
     assert not do_rotary and not attention_bias, "TODO"
     if qkv_hidden_sizes is None: qkv_hidden_sizes = [weights.shape[1] // 3] * 3
     qkv = x.linear(weights, bias)
@@ -701,6 +723,86 @@ def get_onnx_ops():
     output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
     return output, present
 
+  def attention_onnx(Q:Tensor, K:Tensor, V:Tensor, attn_mask:Tensor|None=None, past_key:Tensor|None=None, past_value:Tensor|None=None,
+                     is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None, qk_matmul_output_mode:int=0, scale:float|None=None,
+                     softcap:float=0.0, softmax_precision:int|None=None):
+    input_shape_len = Q.ndim
+    if input_shape_len == 3:
+      assert q_num_heads is not None and kv_num_heads is not None
+      Q = Q.reshape(Q.shape[0], q_num_heads, Q.shape[1], -1)
+      K = K.reshape(K.shape[0], kv_num_heads, K.shape[1], -1)
+      V = V.reshape(V.shape[0], kv_num_heads, V.shape[1], -1)
+
+    if past_key is not None: K = past_key.cat(K, dim=2)
+    if past_value is not None: V = past_value.cat(V, dim=2)
+    present_key, present_value = K, V
+
+    _q_heads, _kv_heads = q_num_heads or Q.shape[1], kv_num_heads or K.shape[1]
+    if _q_heads != _kv_heads:
+      K = K.repeat((1, _q_heads // _kv_heads, 1, 1))
+      V = V.repeat((1, _q_heads // _kv_heads, 1, 1))
+
+    effective_scale = scale if scale is not None else 1.0 / (Q.shape[-1] ** 0.5)
+    scores = (Q @ K.transpose(-1, -2)) * effective_scale
+    qk_matmul_return_val = scores
+
+    if is_causal:
+      causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool, requires_grad=False).tril(0)
+      scores = scores.masked_fill(causal_mask.logical_not(), -float("inf"))
+
+    if attn_mask is not None:
+      mask_to_add = attn_mask.where(0, -float("inf")) if attn_mask.dtype == dtypes.bool else attn_mask
+      scores = scores + mask_to_add
+    if qk_matmul_output_mode == 1: qk_matmul_return_val = scores
+
+    if softcap > 0.0: scores = (scores / softcap).tanh() * softcap
+    if qk_matmul_output_mode == 2: qk_matmul_return_val = scores
+
+    if softmax_precision: scores = scores.cast({1: dtypes.float32, 10: dtypes.float16, 16: dtypes.bfloat16}[softmax_precision])
+    qk_softmax = scores.softmax(-1).cast(Q.dtype)
+    if qk_matmul_output_mode == 3: qk_matmul_return_val = qk_softmax
+
+    output = (qk_softmax @ V).cast(Q.dtype)
+    if input_shape_len == 3: output = output.permute(0, 2, 1, 3).reshape(Q.shape[0], Q.shape[2], -1)
+    return output, present_key, present_value, qk_matmul_return_val
+  Attention = {OpSetId(Domain.ONNX, 1): attention_onnx, OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1): attention_contrib}
+
+  def RMSNormalization(X:Tensor, scale:Tensor, axis:int=-1, epsilon:float=1e-5):
+    norm = X.square().mean(axis=tuple(range(axis + X.ndim if axis < 0 else axis, X.ndim)), keepdim=True).add(epsilon).rsqrt()
+    return X * norm * scale
+
+  def RotaryEmbedding(X:Tensor, cos_cache:Tensor, sin_cache:Tensor, position_ids:Tensor|None=None, interleaved:int=0, num_heads:int|None=None,
+                      rotary_embedding_dim:int=0):
+    original_input_shape = X.shape
+    if X.ndim == 4: X = X.permute(0, 2, 1, 3)
+    elif X.ndim == 3:
+      assert num_heads is not None, "num_heads must be provided for 3D input"
+      X = X.reshape(*X.shape[:-1], num_heads, X.shape[-1] // num_heads)
+    head_size = X.shape[-1]
+    rot_dim = rotary_embedding_dim or head_size
+    x_rotate, x_pass = X[..., :rot_dim], X[..., rot_dim:]
+    cos = cos_cache[position_ids] if position_ids is not None else cos_cache[:X.shape[1]]
+    sin = sin_cache[position_ids] if position_ids is not None else sin_cache[:X.shape[1]]
+    cos = cos[..., :rot_dim//2].unsqueeze(2)
+    sin = sin[..., :rot_dim//2].unsqueeze(2)
+    if interleaved:
+      x1, x2 = x_rotate[..., ::2], x_rotate[..., 1::2]
+      real = x1 * cos - x2 * sin
+      imag = x1 * sin + x2 * cos
+      x_rotated = Tensor.stack(real, imag, dim=-1).flatten(start_dim=-2)
+    else:
+      x1, x2 = x_rotate.chunk(2, dim=-1)
+      real = x1 * cos - x2 * sin
+      imag = x1 * sin + x2 * cos
+      x_rotated = real.cat(imag, dim=-1)
+    output = x_rotated.cat(x_pass, dim=-1)
+    return output.flatten(start_dim=2) if len(original_input_shape) == 3 else output.permute(0, 2, 1, 3)
+
   # ***** Indexing Ops *****
   def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
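
With the two implementations registered under different domains, the same op name now resolves per node: a stock `ai.onnx` Attention node routes to `attention_onnx`, while a `com.microsoft` node keeps the legacy packed-QKV `attention_contrib` path. A minimal sketch of nodes that would hit each branch, using the standard `onnx.helper` API (tensor names are illustrative):

    import onnx.helper
    # default domain ("") -> Domain.ONNX -> attention_onnx
    n1 = onnx.helper.make_node("Attention", ["Q", "K", "V"], ["Y"])
    # contrib domain -> Domain.MICROSOFT_CONTRIB_OPS -> attention_contrib
    n2 = onnx.helper.make_node("Attention", ["x", "weights"], ["y"], domain="com.microsoft", num_heads=4)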

View File

@@ -55,7 +55,7 @@ setup(name='tinygrad',
     ],
     'testing': testing_minimal + [
       "pillow",
-      "onnx==1.17.0",
+      "onnx==1.18.0",
       "onnx2torch",
       "onnxruntime",
       "opencv-python",

View File

@@ -184,11 +184,6 @@ backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad d
 backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
 backend_test.exclude('test_scatternd_max_cpu') # max not yet supported
 
-backend_test.exclude('test_rms_normalization') # RMSNormalization
-backend_test.exclude('test_rotary_embedding') # RotaryEmbedding
-backend_test.exclude('test_attention_3d') # not piped correctly?
-backend_test.exclude('test_attention_4d') # not piped correctly?
-
 if Device.DEFAULT in ['GPU', 'METAL']:
   backend_test.exclude('test_resize_upsample_sizes_nearest_axes_2_3_cpu')
   backend_test.exclude('test_resize_upsample_sizes_nearest_axes_3_2_cpu')

View File

@@ -10,6 +10,11 @@ import numpy as np
 from extra.onnx_helpers import validate
 from onnx.defs import ONNX_DOMAIN, AI_ONNX_PREVIEW_TRAINING_DOMAIN
 MICROSOFT_CONTRIB_OPS_DOMAIN = "com.microsoft"
 
+# TODO: remove this once ORT supports 1.18.0
+from onnx.helper import VERSION_TABLE
+VERSION_MAP = {row[0]: row[1:] for row in VERSION_TABLE}
+IR_VERSION, ai_onnx, ai_onnx_ml, ai_onnx_training = VERSION_MAP["1.17.0"]
+
 class TestOnnxOps(unittest.TestCase):
   DOMAIN = None
@@ -18,7 +23,14 @@ class TestOnnxOps(unittest.TestCase):
     onnx_outputs = [onnx.helper.make_empty_tensor_value_info(name) for name in outs]
     nodes = [onnx.helper.make_node(op, list(inps), list(outs), domain=self.DOMAIN, **opts)]
     graph = onnx.helper.make_graph(nodes, f"test_{op.lower()}", onnx_inputs, onnx_outputs)
-    model = onnx.helper.make_model(graph, producer_name=f"test_{op.lower()}")
+    #model = onnx.helper.make_model(graph, producer_name=f"test_{op.lower()}")
+    # TODO: remove this once ORT supports 1.18.0
+    opset_id = None
+    if type(self).__name__ == "TestMainOnnxOps": opset_id = ai_onnx
+    if type(self).__name__ == "TestTrainingOnnxOps": opset_id = ai_onnx_training
+    if type(self).__name__ == "TestContribOnnxOps": opset_id = 1
+    model = onnx.helper.make_model(graph, producer_name=f"test_{op.lower()}", ir_version=IR_VERSION,
+                                   opset_imports=[onnx.helper.make_opsetid(self.DOMAIN, opset_id)])
     return model
 
   def helper_test_single_op(self, op:str, inps:dict[str, np.ndarray], opts:dict[str, Any], outs:list[str], rtol=1e-3, atol=1e-6):
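
Pinning `ir_version` and the opset keeps the generated test models loadable by onnxruntime, which does not yet accept models stamped with the onnx 1.18.0 defaults. The `VERSION_TABLE` row for 1.17.0 unpacks to IR version 10 and `ai.onnx` opset 22 (matching the literal values pinned in the gemm model below), so for the main-domain tests the call above is roughly equivalent to this hedged sketch:

    model = onnx.helper.make_model(graph, producer_name="test_softmax", ir_version=10,
                                   opset_imports=[onnx.helper.make_opsetid("", 22)])  # "" is the default ai.onnx domain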

View File

@@ -32,7 +32,9 @@ def create_gemm_model(model_path:str, batch_size=N, in_size=N, out_size=N, bias=
   graph_def = helper.make_graph([gemm_node], "SingleGemmGraph", [input_tensor], [output_tensor], initializer=[W_init])
 
   # Create and save the model
-  model_def = helper.make_model(graph_def, producer_name="single_gemm_example")
+  #model_def = helper.make_model(graph_def, producer_name="single_gemm_example")
+  # TODO remove this once ORT supports 1.18.0
+  model_def = helper.make_model(graph_def, producer_name="single_gemm_example", ir_version=10, opset_imports=[helper.make_opsetid("", 22)])
   onnx.save_model(model_def, model_path)
   return model_path