mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add openpilot onnx parser test (#15334)
This commit is contained in:
committed by
GitHub
parent
0222bfdf69
commit
864d3917d5
38
test/external/external_test_onnx_runner.py
vendored
38
test/external/external_test_onnx_runner.py
vendored
@@ -3,7 +3,8 @@ import numpy as np
|
||||
from tinygrad import dtypes, Tensor
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.nn.onnx import OnnxRunner, OnnxDataType
|
||||
from typing import Any
|
||||
from tinygrad.nn.onnx import OnnxRunner, OnnxPBParser, OnnxDataType
|
||||
from hypothesis import given, strategies as st
|
||||
|
||||
# copied from test_const_folding.py
|
||||
@@ -136,5 +137,40 @@ class TestOnnxRunnerDtypes(unittest.TestCase):
|
||||
from_disk=False)
|
||||
self.assertEqual(runner.graph_nodes[0].opts['value'].dtype, expected_dtype)
|
||||
|
||||
# from openpilot selfdrive/modeld/get_model_metadata.py
|
||||
class MetadataOnnxPBParser(OnnxPBParser):
|
||||
def _parse_ModelProto(self) -> dict:
|
||||
obj: dict[str, Any] = {"graph": {"input": [], "output": []}, "metadata_props": []}
|
||||
for fid, wire_type in self._parse_message(self.reader.len):
|
||||
match fid:
|
||||
case 7: obj["graph"] = self._parse_GraphProto()
|
||||
case 14: obj["metadata_props"].append(self._parse_StringStringEntryProto())
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return obj
|
||||
|
||||
class TestOnnxMetadata(unittest.TestCase):
|
||||
def test_metadata_props(self):
|
||||
graph = onnx.helper.make_graph(
|
||||
nodes=[onnx.helper.make_node('Identity', ['input'], ['output'])],
|
||||
name='test',
|
||||
inputs=[onnx.helper.make_tensor_value_info('input', onnx.TensorProto.FLOAT, (1, 3))],
|
||||
outputs=[onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, (1, 3))],
|
||||
)
|
||||
model = onnx.helper.make_model(graph)
|
||||
model.metadata_props.append(onnx.StringStringEntryProto(key="model_checkpoint", value="v1.0"))
|
||||
model.metadata_props.append(onnx.StringStringEntryProto(key="output_slices", value="dGVzdA=="))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model_path = pathlib.Path(tmpdir) / "model.onnx"
|
||||
onnx.save(model, model_path)
|
||||
parsed = MetadataOnnxPBParser(model_path).parse()
|
||||
|
||||
# metadata_props should be accessible as dicts with "key" and "value"
|
||||
self.assertEqual(len(parsed["metadata_props"]), 2)
|
||||
self.assertEqual(parsed["metadata_props"][0]["key"], "model_checkpoint")
|
||||
self.assertEqual(parsed["metadata_props"][0]["value"], "v1.0")
|
||||
self.assertEqual(parsed["metadata_props"][1]["key"], "output_slices")
|
||||
self.assertEqual(parsed["metadata_props"][1]["value"], "dGVzdA==")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user