diff --git a/test/external/external_test_onnx_runner.py b/test/external/external_test_onnx_runner.py index 4d162c8187..3d58f9c323 100644 --- a/test/external/external_test_onnx_runner.py +++ b/test/external/external_test_onnx_runner.py @@ -3,7 +3,8 @@ import numpy as np from tinygrad import dtypes, Tensor from tinygrad.uop.ops import Ops from tinygrad.device import is_dtype_supported -from tinygrad.nn.onnx import OnnxRunner, OnnxDataType +from typing import Any +from tinygrad.nn.onnx import OnnxRunner, OnnxPBParser, OnnxDataType from hypothesis import given, strategies as st # copied from test_const_folding.py @@ -136,5 +137,40 @@ class TestOnnxRunnerDtypes(unittest.TestCase): from_disk=False) self.assertEqual(runner.graph_nodes[0].opts['value'].dtype, expected_dtype) +# from openpilot selfdrive/modeld/get_model_metadata.py +class MetadataOnnxPBParser(OnnxPBParser): + def _parse_ModelProto(self) -> dict: + obj: dict[str, Any] = {"graph": {"input": [], "output": []}, "metadata_props": []} + for fid, wire_type in self._parse_message(self.reader.len): + match fid: + case 7: obj["graph"] = self._parse_GraphProto() + case 14: obj["metadata_props"].append(self._parse_StringStringEntryProto()) + case _: self.reader.skip_field(wire_type) + return obj + +class TestOnnxMetadata(unittest.TestCase): + def test_metadata_props(self): + graph = onnx.helper.make_graph( + nodes=[onnx.helper.make_node('Identity', ['input'], ['output'])], + name='test', + inputs=[onnx.helper.make_tensor_value_info('input', onnx.TensorProto.FLOAT, (1, 3))], + outputs=[onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, (1, 3))], + ) + model = onnx.helper.make_model(graph) + model.metadata_props.append(onnx.StringStringEntryProto(key="model_checkpoint", value="v1.0")) + model.metadata_props.append(onnx.StringStringEntryProto(key="output_slices", value="dGVzdA==")) + + with tempfile.TemporaryDirectory() as tmpdir: + model_path = pathlib.Path(tmpdir) / "model.onnx" + onnx.save(model, model_path) + parsed = MetadataOnnxPBParser(model_path).parse() + + # metadata_props should be accessible as dicts with "key" and "value" + self.assertEqual(len(parsed["metadata_props"]), 2) + self.assertEqual(parsed["metadata_props"][0]["key"], "model_checkpoint") + self.assertEqual(parsed["metadata_props"][0]["value"], "v1.0") + self.assertEqual(parsed["metadata_props"][1]["key"], "output_slices") + self.assertEqual(parsed["metadata_props"][1]["value"], "dGVzdA==") + if __name__ == '__main__': unittest.main() \ No newline at end of file