From 7ce9e4547427e5e0720f37d864af8a8e18434e8d Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 8 Jul 2025 19:50:28 -0400 Subject: [PATCH] mypy onnx_parser (#11141) --- .github/workflows/test.yml | 4 +++- extra/onnx_parser.py | 21 ++++++++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 15ea498aa5..269a0d012c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -331,7 +331,9 @@ jobs: - name: Lint tinygrad with pylint run: python -m pylint tinygrad/ - name: Run mypy - run: python -m mypy --strict-equality --lineprecision-report . && cat lineprecision.txt + run: | + python -m mypy --strict-equality --lineprecision-report . && cat lineprecision.txt + python -m mypy --strict-equality extra/onnx_parser.py unittest: name: Unit Tests diff --git a/extra/onnx_parser.py b/extra/onnx_parser.py index dd8b8946d8..15739422ef 100644 --- a/extra/onnx_parser.py +++ b/extra/onnx_parser.py @@ -2,7 +2,6 @@ import os, pathlib, struct from io import BufferedReader -from typing import Tuple, Union from types import SimpleNamespace from tinygrad.nn.state import TensorIO from tinygrad.tensor import Tensor, dtypes @@ -22,7 +21,7 @@ class AttributeType: class PBType: FLOAT = 1; INT = 2; STRING = 3; FLOATS = 4; INTS = 5; STRINGS = 6; BYTES = 7; SUB = 8 # noqa: E702 -PB_INFOS = { +PB_INFOS: dict[str, dict] = { "OperatorSetIdProto": {1: ("domain", PBType.STRING), 2: ("version", PBType.INT)}, "StringStringEntryProto": {1: ("key", PBType.STRING), 2: ("value", PBType.STRING)}, # TODO: support uint64 parsing (11: "uint64_data") and double parsing (10: "double_data") @@ -55,7 +54,7 @@ PB_INFOS = { "TypeProtoTensor": {1: ("elem_type", PBType.INT), 2: ("shape", PBType.SUB, False, ("TensorShapeProto", lambda: {"dim": []}))}, } -def onnx_load(fn: Union[Tensor, str, pathlib.Path], load_external_data: bool=True): +def onnx_load(fn: Tensor|str|pathlib.Path, load_external_data: bool=True): parser = OnnxParser(fn, load_external_data) onnx_model = parser.parse() model = dict_to_namespace(onnx_model) @@ -71,8 +70,8 @@ def dict_to_namespace(d): return d class OnnxParser: - def __init__(self, inp: Union[Tensor, str, pathlib.Path], load_external_data: bool=True): - self.file_path: Union[pathlib.Path, None] = None + def __init__(self, inp: Tensor|str|pathlib.Path, load_external_data: bool=True): + self.file_path: pathlib.Path|None = None self.load_external_data = load_external_data if not isinstance(inp, Tensor): self.file_path = pathlib.Path(inp) @@ -90,7 +89,6 @@ class OnnxParser: elif len(config) == 4: name, attr, repeated, parser_fn = config handler_fn = self.attr_func_dict[attr] def _wrapper_handler(obj, reader, wt, h=handler_fn, n=name, p=parser_fn, r=repeated): return h(obj, n, reader, wt, parser_func=p, repeated=r) - _wrapper_handler._debug_info = f"{fid}, {name} => {handler_fn}" res[fid] = _wrapper_handler self.registered_handles[pb_name] = res @@ -131,16 +129,19 @@ class OnnxParser: if message_field_handlers_name == "TensorProto" and self.load_external_data and obj.get("data_location", 0) == 1: self._parse_external_data(obj) return obj - def _handle_delimited(self, reader:BufferedReader, use_tensor=False) -> Tuple[bytes, Tensor]: + def _handle_delimited(self, reader:BufferedReader, use_tensor=False) -> Tensor|bytes: str_len = self.decode_varint(reader) if not use_tensor: return reader.read(str_len) - res = reader.raw._tensor[reader.tell():(reader.tell()+str_len)] + raw = reader.raw + assert isinstance(raw, TensorIO) + res = raw._tensor[reader.tell():(reader.tell()+str_len)] reader.seek(str_len, os.SEEK_CUR) return res def _handle_string(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False): if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError(f"Expected length-delimited for string field '{key_name}'") value = self._handle_delimited(reader) + assert isinstance(value, bytes) gen_result(obj, key_name, value.decode("utf-8"), repeated) def _handle_bytes(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False): @@ -171,11 +172,13 @@ class OnnxParser: def _handle_packed_floats(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False): if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError("Packed floats expected length_delimited") value = self._handle_delimited(reader, use_tensor=True) + assert isinstance(value, Tensor) obj[key_name] = value.bitcast(dtypes.float32) def _handle_sub_message(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False): if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError(f"Expected length-delimited for sub-message field '{key_name}'") value = self._handle_delimited(reader, use_tensor=True) + assert isinstance(value, Tensor) if isinstance(parser_func, str): sub_obj = self._parse_message(BufferedReader(TensorIO(value)), parser_func) elif isinstance(parser_func, tuple): sub_obj = self._parse_message(BufferedReader(TensorIO(value)), parser_func[0], parser_func[1]) else: sub_obj = parser_func(BufferedReader(TensorIO(value))) @@ -194,7 +197,7 @@ class OnnxParser: if self.file_path is None: # get onnx file path from Tensor if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"): - self.file_path = self.tensor.device[5:] + self.file_path = pathlib.Path(self.tensor.device[5:]) if not (ext_path := self.file_path.parent.joinpath(location)).exists(): raise Exception(f"external location not exists: {ext_path}, may caused by symbolic link, try passing onnx file path to onnx_load") else: raise Exception("onnx external_data need the origin file path, try passing onnx file path to onnx_load")