Revert "don't use intermediate dict in onnx parse" (#15332)

This commit is contained in:
Christopher Milan
2026-03-17 20:46:30 -07:00
committed by GitHub
parent 94926d00d8
commit 0222bfdf69

View File

@@ -153,19 +153,21 @@ class OnnxPBParser:
def _parse_ModelProto(self) -> dict:
"""Entry point for parsing the ONNX model."""
graph: dict|None = None
opset_imports: list[OpSetId] = []
obj: dict[str, Any] = {"opset_import": []}
for fid, wire_type in self._parse_message(self.reader.len):
match fid:
case 7: graph = self._parse_GraphProto()
case 8: opset_imports.append(self._parse_OperatorSetIdProto())
case 4: obj["domain"] = self.reader.read_string()
case 5: obj["model_version"] = self.reader.read_int64()
case 7: obj["graph"] = self._parse_GraphProto()
case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto())
case _: self.reader.skip_field(wire_type)
assert graph is not None
# update opset version
versions = {opset.domain: opset.version for opset in opset_imports}
graph["node"] = [OnnxNode(n.op, OpSetId(n.opset_id.domain, versions.get(n.opset_id.domain, 1)), n.inputs, n.outputs, n.opts)
for n in graph["node"]]
return graph
opset_imports = {Domain.from_onnx(x.get('domain')):x.get('version', 1) for x in obj["opset_import"]}
for n in obj["graph"]["node"]:
n_ = n["parsed_node"]
n["parsed_node"] = OnnxNode(n_.op, OpSetId(n_.opset_id.domain, opset_imports.get(n_.opset_id.domain, 1)), n_.inputs, n_.outputs, n_.opts)
return obj
def _parse_GraphProto(self) -> dict:
obj: dict[str, Any] = {"node": [], "initializer": [], "input": [], "output": []}
@@ -179,23 +181,26 @@ class OnnxPBParser:
case _: self.reader.skip_field(wire_type)
return obj
def _parse_NodeProto(self) -> OnnxNode:
inputs: list[str] = []
outputs: list[str] = []
attributes: list[tuple[str, Any]] = []
domain: str|None = None
op_type = ""
def _parse_NodeProto(self) -> dict:
obj: dict[str, Any] = {"input": [], "output": [], "attribute": [], "domain": None}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: inputs.append(self.reader.read_string())
case 2: outputs.append(self.reader.read_string())
case 4: op_type = self.reader.read_string()
case 5: attributes.append(self._parse_AttributeProto())
case 7: domain = self.reader.read_string()
case 1: obj["input"].append(self.reader.read_string())
case 2: obj["output"].append(self.reader.read_string())
case 3: obj["name"] = self.reader.read_string()
case 4: obj["op_type"] = self.reader.read_string()
case 5: obj["attribute"].append(self._parse_AttributeProto())
case 6: obj["doc_string"] = self.reader.read_string()
case 7: obj["domain"] = self.reader.read_string()
case _: self.reader.skip_field(wire_type)
return OnnxNode(op_type, OpSetId(Domain.from_onnx(domain), 1), tuple(inputs), tuple(outputs), dict(attributes))
def _parse_TensorProto(self) -> tuple[str, Tensor]:
# parse node
attributes = {attr_dict["name"]: attr_dict[AttributeType(attr_dict["type"]).to_field_name()] for attr_dict in obj["attribute"]}
opset_id = OpSetId(Domain.from_onnx(obj.get('domain')), 1) # default version, to be updated later in _parse_ModelProto
obj["parsed_node"] = OnnxNode(obj["op_type"], opset_id, tuple(obj["input"]), tuple(obj["output"]), attributes)
return obj
def _parse_TensorProto(self) -> dict:
obj: dict[str, Any] = {"dims": []}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
@@ -215,16 +220,18 @@ class OnnxPBParser:
# load external data
if self.load_external_data and obj.get("data_location", 0) == 1:
if "external_data" not in obj: raise ValueError("no external_data")
ext = dict(obj["external_data"])
if "location" not in ext: raise ValueError("no location in external_data")
offset = int(ext.get("offset", "0"))
length = int(ext["length"]) if "length" in ext else None
location, length, offset = None, None, 0
for kv in obj["external_data"]:
if kv["key"] == "location": location = kv["value"]
elif kv["key"] == "offset": offset = int(kv["value"])
elif kv["key"] == "length": length = int(kv["value"])
if location is None: raise ValueError("no location in external_data")
if self.file_path is None:
if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"):
self.file_path = pathlib.Path(self.tensor.device[5:])
else: raise ValueError("onnx external_data needs the origin file path, try passing onnx file path to onnx_load")
ext_path = self.file_path.parent.joinpath(ext["location"])
ext_path = self.file_path.parent.joinpath(location)
if not ext_path.exists(): raise FileNotFoundError(f"external location not exists: {ext_path}")
ext_tensor = Tensor(ext_path)
@@ -234,20 +241,23 @@ class OnnxPBParser:
# parse tensor
to_dtype = dtype_fallback(true_dtype := OnnxDataType(obj['data_type']).to_dtype(), "buffer parse")
shape = tuple(obj['dims'])
data_fields = [f for f in ('float_data','int32_data','int64_data','double_data','uint64_data','raw_data') if f in obj]
data = obj[get_single_element(data_fields)]
name = obj.get("name", "")
if not isinstance(data, Tensor): return name, Tensor(data, dtype=to_dtype).reshape(shape)
assert data.dtype == dtypes.uint8, data
present_fields = [field for field in ['float_data', 'int32_data', 'int64_data', 'double_data', 'uint64_data', 'raw_data'] if field in obj]
assert len(present_fields) == 1, f"only 1 data field is allowed from {obj=}"
data = obj[present_fields[0]]
if not isinstance(data, Tensor):
obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape)
return obj
assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data
data = data.bitcast(true_dtype).reshape(shape)
data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT)
# const folding
if shape == ():
if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32)
data = Tensor(data.item(), dtype=to_dtype).reshape(shape)
return name, data
obj["parsed_tensor"] = data
return obj
def _parse_AttributeProto(self) -> tuple[str, Any]:
def _parse_AttributeProto(self) -> dict:
obj: dict[str, Any] = {"floats": [], "ints": [], "strings": []}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
@@ -255,7 +265,7 @@ class OnnxPBParser:
case 2: obj["f"] = self.reader.read_float()
case 3: obj["i"] = self.reader.read_int64()
case 4: obj["s"] = self.reader.read_bytes().data().tobytes().decode("utf8")
case 5: obj["t"] = self._parse_TensorProto()[1]
case 5: obj["t"] = self._parse_TensorProto()['parsed_tensor']
case 6: obj["g"] = OnnxRunner._from_subgraph(self._parse_GraphProto())
case 7: obj["floats"].append(self.reader.read_float())
case 8: obj["ints"].append(self.reader.read_int64())
@@ -263,22 +273,26 @@ class OnnxPBParser:
case 20: obj["type"] = self.reader.read_int64()
case _: self.reader.skip_field(wire_type)
obj["floats"], obj["ints"], obj["strings"] = tuple(obj["floats"]), tuple(obj["ints"]), tuple(obj["strings"])
return obj["name"], obj[AttributeType(obj["type"]).to_field_name()]
return obj
def _parse_ValueInfoProto(self) -> tuple[str, OnnxValue|None]:
name, type_obj = "", None
def _parse_ValueInfoProto(self) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: name = self.reader.read_string()
case 2: type_obj = self._parse_TypeProto()
case 1: obj["name"] = self.reader.read_string()
case 2: obj["type"] = self._parse_TypeProto()
case _: self.reader.skip_field(wire_type)
if type_obj is None: return name, None
# parse type
if "type" not in obj: return {**obj, "parsed_type": None}
type_obj = obj["type"]
if is_optional := "optional_type" in type_obj: type_obj = type_obj["optional_type"]["elem_type"]
if is_sequence := "sequence_type" in type_obj: type_obj = type_obj["sequence_type"]["elem_type"]
assert "tensor_type" in type_obj, type_obj
shape_dims = type_obj['tensor_type'].get('shape', {}).get('dim', [])
return name, OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims),
OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)
obj['parsed_type'] = OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims),
OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)
return obj
def _parse_TypeProto(self) -> dict:
obj: dict[str, Any] = {}
@@ -324,24 +338,23 @@ class OnnxPBParser:
case _: self.reader.skip_field(wire_type)
return obj
def _parse_StringStringEntryProto(self) -> tuple[str, str]:
key, value = "", ""
def _parse_StringStringEntryProto(self) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: key = self.reader.read_string()
case 2: value = self.reader.read_string()
case 1: obj["key"] = self.reader.read_string()
case 2: obj["value"] = self.reader.read_string()
case _: self.reader.skip_field(wire_type)
return key, value
return obj
def _parse_OperatorSetIdProto(self) -> OpSetId:
domain: str|None = None
version = 1
def _parse_OperatorSetIdProto(self) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: domain = self.reader.read_string()
case 2: version = self.reader.read_int64()
case 1: obj["domain"] = self.reader.read_string()
case 2: obj["version"] = self.reader.read_int64()
case _: self.reader.skip_field(wire_type)
return OpSetId(Domain.from_onnx(domain), version)
return obj
# ***** python const *****
required_input_python_consts: dict[str, tuple[int, ...]] = {
@@ -367,15 +380,16 @@ class OnnxRunner:
model_path: The ONNX model, provided as a file path (a string or Path object) or a Tensor.
"""
def __init__(self, model_path: Tensor | str | pathlib.Path):
self._init_from_graph(OnnxPBParser(model_path, load_external_data=True).parse())
model = OnnxPBParser(model_path, load_external_data=True).parse()
self._init_from_graph(model["graph"])
def _init_from_graph(self, graph: dict, is_subgraph: bool = False):
self.is_training = any(n.opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
self.graph_name = graph["name"] if is_subgraph else ""
self.graph_values: dict[str, Any] = {"": None, **dict(graph["initializer"])}
self.graph_inputs = {name: typ for name, typ in graph["input"] if name not in self.graph_values}
self.graph_outputs = tuple(name for name, _ in graph["output"])
self.graph_nodes = tuple(graph["node"])
self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}}
self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values}
self.graph_outputs = tuple(o["name"] for o in graph["output"])
self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"])
# track names from initializers and Constant nodes for fast path optimizations
self.const_names: set[str] = set(self.graph_values.keys()) | {o for n in self.graph_nodes if n.op == "Constant" for o in n.outputs}