add openpilot onnx parser test (#15334)

This commit is contained in:
Christopher Milan
2026-03-17 21:12:02 -07:00
committed by GitHub
parent 0222bfdf69
commit 864d3917d5

View File

@@ -3,7 +3,8 @@ import numpy as np
from tinygrad import dtypes, Tensor
from tinygrad.uop.ops import Ops
from tinygrad.device import is_dtype_supported
from tinygrad.nn.onnx import OnnxRunner, OnnxDataType
from typing import Any
from tinygrad.nn.onnx import OnnxRunner, OnnxPBParser, OnnxDataType
from hypothesis import given, strategies as st
# copied from test_const_folding.py
@@ -136,5 +137,40 @@ class TestOnnxRunnerDtypes(unittest.TestCase):
from_disk=False)
self.assertEqual(runner.graph_nodes[0].opts['value'].dtype, expected_dtype)
# from openpilot selfdrive/modeld/get_model_metadata.py
class MetadataOnnxPBParser(OnnxPBParser):
def _parse_ModelProto(self) -> dict:
obj: dict[str, Any] = {"graph": {"input": [], "output": []}, "metadata_props": []}
for fid, wire_type in self._parse_message(self.reader.len):
match fid:
case 7: obj["graph"] = self._parse_GraphProto()
case 14: obj["metadata_props"].append(self._parse_StringStringEntryProto())
case _: self.reader.skip_field(wire_type)
return obj
class TestOnnxMetadata(unittest.TestCase):
def test_metadata_props(self):
graph = onnx.helper.make_graph(
nodes=[onnx.helper.make_node('Identity', ['input'], ['output'])],
name='test',
inputs=[onnx.helper.make_tensor_value_info('input', onnx.TensorProto.FLOAT, (1, 3))],
outputs=[onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, (1, 3))],
)
model = onnx.helper.make_model(graph)
model.metadata_props.append(onnx.StringStringEntryProto(key="model_checkpoint", value="v1.0"))
model.metadata_props.append(onnx.StringStringEntryProto(key="output_slices", value="dGVzdA=="))
with tempfile.TemporaryDirectory() as tmpdir:
model_path = pathlib.Path(tmpdir) / "model.onnx"
onnx.save(model, model_path)
parsed = MetadataOnnxPBParser(model_path).parse()
# metadata_props should be accessible as dicts with "key" and "value"
self.assertEqual(len(parsed["metadata_props"]), 2)
self.assertEqual(parsed["metadata_props"][0]["key"], "model_checkpoint")
self.assertEqual(parsed["metadata_props"][0]["value"], "v1.0")
self.assertEqual(parsed["metadata_props"][1]["key"], "output_slices")
self.assertEqual(parsed["metadata_props"][1]["value"], "dGVzdA==")
if __name__ == '__main__':
unittest.main()