mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Revert "don't use intermediate dict in onnx parse" (#15332)
This commit is contained in:
committed by
GitHub
parent
94926d00d8
commit
0222bfdf69
@@ -153,19 +153,21 @@ class OnnxPBParser:
|
||||
|
||||
def _parse_ModelProto(self) -> dict:
|
||||
"""Entry point for parsing the ONNX model."""
|
||||
graph: dict|None = None
|
||||
opset_imports: list[OpSetId] = []
|
||||
obj: dict[str, Any] = {"opset_import": []}
|
||||
for fid, wire_type in self._parse_message(self.reader.len):
|
||||
match fid:
|
||||
case 7: graph = self._parse_GraphProto()
|
||||
case 8: opset_imports.append(self._parse_OperatorSetIdProto())
|
||||
case 4: obj["domain"] = self.reader.read_string()
|
||||
case 5: obj["model_version"] = self.reader.read_int64()
|
||||
case 7: obj["graph"] = self._parse_GraphProto()
|
||||
case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto())
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
assert graph is not None
|
||||
|
||||
# update opset version
|
||||
versions = {opset.domain: opset.version for opset in opset_imports}
|
||||
graph["node"] = [OnnxNode(n.op, OpSetId(n.opset_id.domain, versions.get(n.opset_id.domain, 1)), n.inputs, n.outputs, n.opts)
|
||||
for n in graph["node"]]
|
||||
return graph
|
||||
opset_imports = {Domain.from_onnx(x.get('domain')):x.get('version', 1) for x in obj["opset_import"]}
|
||||
for n in obj["graph"]["node"]:
|
||||
n_ = n["parsed_node"]
|
||||
n["parsed_node"] = OnnxNode(n_.op, OpSetId(n_.opset_id.domain, opset_imports.get(n_.opset_id.domain, 1)), n_.inputs, n_.outputs, n_.opts)
|
||||
return obj
|
||||
|
||||
def _parse_GraphProto(self) -> dict:
|
||||
obj: dict[str, Any] = {"node": [], "initializer": [], "input": [], "output": []}
|
||||
@@ -179,23 +181,26 @@ class OnnxPBParser:
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return obj
|
||||
|
||||
def _parse_NodeProto(self) -> OnnxNode:
|
||||
inputs: list[str] = []
|
||||
outputs: list[str] = []
|
||||
attributes: list[tuple[str, Any]] = []
|
||||
domain: str|None = None
|
||||
op_type = ""
|
||||
def _parse_NodeProto(self) -> dict:
|
||||
obj: dict[str, Any] = {"input": [], "output": [], "attribute": [], "domain": None}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: inputs.append(self.reader.read_string())
|
||||
case 2: outputs.append(self.reader.read_string())
|
||||
case 4: op_type = self.reader.read_string()
|
||||
case 5: attributes.append(self._parse_AttributeProto())
|
||||
case 7: domain = self.reader.read_string()
|
||||
case 1: obj["input"].append(self.reader.read_string())
|
||||
case 2: obj["output"].append(self.reader.read_string())
|
||||
case 3: obj["name"] = self.reader.read_string()
|
||||
case 4: obj["op_type"] = self.reader.read_string()
|
||||
case 5: obj["attribute"].append(self._parse_AttributeProto())
|
||||
case 6: obj["doc_string"] = self.reader.read_string()
|
||||
case 7: obj["domain"] = self.reader.read_string()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return OnnxNode(op_type, OpSetId(Domain.from_onnx(domain), 1), tuple(inputs), tuple(outputs), dict(attributes))
|
||||
|
||||
def _parse_TensorProto(self) -> tuple[str, Tensor]:
|
||||
# parse node
|
||||
attributes = {attr_dict["name"]: attr_dict[AttributeType(attr_dict["type"]).to_field_name()] for attr_dict in obj["attribute"]}
|
||||
opset_id = OpSetId(Domain.from_onnx(obj.get('domain')), 1) # default version, to be updated later in _parse_ModelProto
|
||||
obj["parsed_node"] = OnnxNode(obj["op_type"], opset_id, tuple(obj["input"]), tuple(obj["output"]), attributes)
|
||||
return obj
|
||||
|
||||
def _parse_TensorProto(self) -> dict:
|
||||
obj: dict[str, Any] = {"dims": []}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
@@ -215,16 +220,18 @@ class OnnxPBParser:
|
||||
# load external data
|
||||
if self.load_external_data and obj.get("data_location", 0) == 1:
|
||||
if "external_data" not in obj: raise ValueError("no external_data")
|
||||
ext = dict(obj["external_data"])
|
||||
if "location" not in ext: raise ValueError("no location in external_data")
|
||||
offset = int(ext.get("offset", "0"))
|
||||
length = int(ext["length"]) if "length" in ext else None
|
||||
location, length, offset = None, None, 0
|
||||
for kv in obj["external_data"]:
|
||||
if kv["key"] == "location": location = kv["value"]
|
||||
elif kv["key"] == "offset": offset = int(kv["value"])
|
||||
elif kv["key"] == "length": length = int(kv["value"])
|
||||
if location is None: raise ValueError("no location in external_data")
|
||||
|
||||
if self.file_path is None:
|
||||
if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"):
|
||||
self.file_path = pathlib.Path(self.tensor.device[5:])
|
||||
else: raise ValueError("onnx external_data needs the origin file path, try passing onnx file path to onnx_load")
|
||||
ext_path = self.file_path.parent.joinpath(ext["location"])
|
||||
ext_path = self.file_path.parent.joinpath(location)
|
||||
if not ext_path.exists(): raise FileNotFoundError(f"external location not exists: {ext_path}")
|
||||
|
||||
ext_tensor = Tensor(ext_path)
|
||||
@@ -234,20 +241,23 @@ class OnnxPBParser:
|
||||
# parse tensor
|
||||
to_dtype = dtype_fallback(true_dtype := OnnxDataType(obj['data_type']).to_dtype(), "buffer parse")
|
||||
shape = tuple(obj['dims'])
|
||||
data_fields = [f for f in ('float_data','int32_data','int64_data','double_data','uint64_data','raw_data') if f in obj]
|
||||
data = obj[get_single_element(data_fields)]
|
||||
name = obj.get("name", "")
|
||||
if not isinstance(data, Tensor): return name, Tensor(data, dtype=to_dtype).reshape(shape)
|
||||
assert data.dtype == dtypes.uint8, data
|
||||
present_fields = [field for field in ['float_data', 'int32_data', 'int64_data', 'double_data', 'uint64_data', 'raw_data'] if field in obj]
|
||||
assert len(present_fields) == 1, f"only 1 data field is allowed from {obj=}"
|
||||
data = obj[present_fields[0]]
|
||||
if not isinstance(data, Tensor):
|
||||
obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape)
|
||||
return obj
|
||||
assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data
|
||||
data = data.bitcast(true_dtype).reshape(shape)
|
||||
data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT)
|
||||
# const folding
|
||||
if shape == ():
|
||||
if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32)
|
||||
data = Tensor(data.item(), dtype=to_dtype).reshape(shape)
|
||||
return name, data
|
||||
obj["parsed_tensor"] = data
|
||||
return obj
|
||||
|
||||
def _parse_AttributeProto(self) -> tuple[str, Any]:
|
||||
def _parse_AttributeProto(self) -> dict:
|
||||
obj: dict[str, Any] = {"floats": [], "ints": [], "strings": []}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
@@ -255,7 +265,7 @@ class OnnxPBParser:
|
||||
case 2: obj["f"] = self.reader.read_float()
|
||||
case 3: obj["i"] = self.reader.read_int64()
|
||||
case 4: obj["s"] = self.reader.read_bytes().data().tobytes().decode("utf8")
|
||||
case 5: obj["t"] = self._parse_TensorProto()[1]
|
||||
case 5: obj["t"] = self._parse_TensorProto()['parsed_tensor']
|
||||
case 6: obj["g"] = OnnxRunner._from_subgraph(self._parse_GraphProto())
|
||||
case 7: obj["floats"].append(self.reader.read_float())
|
||||
case 8: obj["ints"].append(self.reader.read_int64())
|
||||
@@ -263,22 +273,26 @@ class OnnxPBParser:
|
||||
case 20: obj["type"] = self.reader.read_int64()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
obj["floats"], obj["ints"], obj["strings"] = tuple(obj["floats"]), tuple(obj["ints"]), tuple(obj["strings"])
|
||||
return obj["name"], obj[AttributeType(obj["type"]).to_field_name()]
|
||||
return obj
|
||||
|
||||
def _parse_ValueInfoProto(self) -> tuple[str, OnnxValue|None]:
|
||||
name, type_obj = "", None
|
||||
def _parse_ValueInfoProto(self) -> dict:
|
||||
obj: dict[str, Any] = {}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: name = self.reader.read_string()
|
||||
case 2: type_obj = self._parse_TypeProto()
|
||||
case 1: obj["name"] = self.reader.read_string()
|
||||
case 2: obj["type"] = self._parse_TypeProto()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
if type_obj is None: return name, None
|
||||
|
||||
# parse type
|
||||
if "type" not in obj: return {**obj, "parsed_type": None}
|
||||
type_obj = obj["type"]
|
||||
if is_optional := "optional_type" in type_obj: type_obj = type_obj["optional_type"]["elem_type"]
|
||||
if is_sequence := "sequence_type" in type_obj: type_obj = type_obj["sequence_type"]["elem_type"]
|
||||
assert "tensor_type" in type_obj, type_obj
|
||||
shape_dims = type_obj['tensor_type'].get('shape', {}).get('dim', [])
|
||||
return name, OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims),
|
||||
OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)
|
||||
obj['parsed_type'] = OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims),
|
||||
OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)
|
||||
return obj
|
||||
|
||||
def _parse_TypeProto(self) -> dict:
|
||||
obj: dict[str, Any] = {}
|
||||
@@ -324,24 +338,23 @@ class OnnxPBParser:
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return obj
|
||||
|
||||
def _parse_StringStringEntryProto(self) -> tuple[str, str]:
|
||||
key, value = "", ""
|
||||
def _parse_StringStringEntryProto(self) -> dict:
|
||||
obj: dict[str, Any] = {}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: key = self.reader.read_string()
|
||||
case 2: value = self.reader.read_string()
|
||||
case 1: obj["key"] = self.reader.read_string()
|
||||
case 2: obj["value"] = self.reader.read_string()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return key, value
|
||||
return obj
|
||||
|
||||
def _parse_OperatorSetIdProto(self) -> OpSetId:
|
||||
domain: str|None = None
|
||||
version = 1
|
||||
def _parse_OperatorSetIdProto(self) -> dict:
|
||||
obj: dict[str, Any] = {}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: domain = self.reader.read_string()
|
||||
case 2: version = self.reader.read_int64()
|
||||
case 1: obj["domain"] = self.reader.read_string()
|
||||
case 2: obj["version"] = self.reader.read_int64()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return OpSetId(Domain.from_onnx(domain), version)
|
||||
return obj
|
||||
|
||||
# ***** python const *****
|
||||
required_input_python_consts: dict[str, tuple[int, ...]] = {
|
||||
@@ -367,15 +380,16 @@ class OnnxRunner:
|
||||
model_path: The ONNX model, provided as a file path (a string or Path object) or a Tensor.
|
||||
"""
|
||||
def __init__(self, model_path: Tensor | str | pathlib.Path):
|
||||
self._init_from_graph(OnnxPBParser(model_path, load_external_data=True).parse())
|
||||
model = OnnxPBParser(model_path, load_external_data=True).parse()
|
||||
self._init_from_graph(model["graph"])
|
||||
|
||||
def _init_from_graph(self, graph: dict, is_subgraph: bool = False):
|
||||
self.is_training = any(n.opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
|
||||
self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
|
||||
self.graph_name = graph["name"] if is_subgraph else ""
|
||||
self.graph_values: dict[str, Any] = {"": None, **dict(graph["initializer"])}
|
||||
self.graph_inputs = {name: typ for name, typ in graph["input"] if name not in self.graph_values}
|
||||
self.graph_outputs = tuple(name for name, _ in graph["output"])
|
||||
self.graph_nodes = tuple(graph["node"])
|
||||
self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}}
|
||||
self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values}
|
||||
self.graph_outputs = tuple(o["name"] for o in graph["output"])
|
||||
self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"])
|
||||
# track names from initializers and Constant nodes for fast path optimizations
|
||||
self.const_names: set[str] = set(self.graph_values.keys()) | {o for n in self.graph_nodes if n.op == "Constant" for o in n.outputs}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user