Files
tinygrad/test/external/external_test_onnx_runner.py
geohotstan 5ce278b245 OnnxRunner file as input (#10789)
* file path as input and have parse be in OnnxRunner.__init__

* modelproto_to_onnxrunner -> modelproto_to_runner

* whoops, fix import

* oh flakiness again, is it because it's getting gc-ed?

* small changes

* CI flaky so just move compile4 fix in

* copy typing of onnx_load

* actually can just import onnx_load instead of onnx.load

* fix external_benchmark_openpilot

* fix onnx_runner test to use onnx_helper

* rerun CI

* try run_modelproto

* spam CI a few times

* revert run_modelproto since that's flaky also

* no external onnx_load usage except onnx.py

* cursor tab complete is evil. It snuck a darn sorted() in. But does the order change the result? Why?

* model_benchmark 193s -> 80s, add OnnxRunner.to()...

* minimize diff and clean up

* device can be None, weird but eh

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2025-07-12 14:27:46 -04:00

65 lines
3.9 KiB
Python

import unittest, onnx
from tinygrad import dtypes, Tensor
from tinygrad.device import is_dtype_supported
from extra.onnx import data_types
from tinygrad.frontend.onnx import OnnxRunner
from hypothesis import given, settings, strategies as st
import numpy as np
# TODO: data type 16 is bf16, need to support double parsing first.
data_types.pop(16)
# Partition the ONNX data-type codes by whether the current device supports the
# corresponding tinygrad dtype; iteration order of data_types is preserved.
device_supported_dtypes, device_unsupported_dtypes = [], []
for odt, dtype in data_types.items():
  (device_supported_dtypes if is_dtype_supported(dtype) else device_unsupported_dtypes).append(odt)
class TestOnnxRunnerDtypes(unittest.TestCase):
  """Verify that OnnxRunner maps ONNX tensor data types to the expected tinygrad dtypes.

  Covers the three places a dtype enters a model: graph inputs (value_info),
  initializers (raw tensor data), and node attributes (a Constant's `value` tensor).
  Hypothesis samples the ONNX data-type codes from the device-supported and
  device-unsupported partitions computed at module level.
  """
  def _make_runner(self, graph):
    # Shared model-construction path: serialize the graph into a ModelProto and
    # feed the bytes to OnnxRunner as a PYTHON-device Tensor, the same ingestion
    # route the three dtype checks below all exercised individually before.
    model = onnx.helper.make_model(graph)
    return OnnxRunner(Tensor(model.SerializeToString(), device="PYTHON"))

  def _test_input_spec_dtype(self, onnx_data_type, tinygrad_dtype):
    # A graph input's declared ONNX data type should surface as `tinygrad_dtype`
    # on the runner's input spec.
    input_tensor = onnx.helper.make_tensor_value_info('input', onnx_data_type, ())
    output_tensor = onnx.helper.make_tensor_value_info('output', onnx_data_type, ())
    node = onnx.helper.make_node('Identity', inputs=['input'], outputs=['output'])
    graph = onnx.helper.make_graph([node], 'identity_test', [input_tensor], [output_tensor])
    runner = self._make_runner(graph)
    self.assertEqual(len(runner.graph_inputs), 1)
    self.assertEqual(runner.graph_inputs['input'].dtype, tinygrad_dtype)

  def _test_initializer_dtype(self, onnx_data_type, tinygrad_dtype):
    # An initializer's raw bytes should be parsed into a tensor of `tinygrad_dtype`.
    arr = np.array([0, 1], dtype=onnx.helper.tensor_dtype_to_np_dtype(onnx_data_type))
    initializer = onnx.helper.make_tensor('initializer', onnx_data_type, arr.shape, arr.tobytes(), raw=True)
    input_tensor = onnx.helper.make_tensor_value_info('input', onnx_data_type, ())
    output_tensor = onnx.helper.make_tensor_value_info('output', onnx_data_type, ())
    node = onnx.helper.make_node('Identity', inputs=['input'], outputs=['output'])
    graph = onnx.helper.make_graph([node], 'identity_test', [input_tensor], [output_tensor], [initializer])
    runner = self._make_runner(graph)
    self.assertEqual(len(runner.graph_inputs), 1)
    self.assertEqual(runner.graph_values['initializer'].dtype, tinygrad_dtype)

  def _test_node_attribute_dtype(self, onnx_data_type, tinygrad_dtype):
    # A Constant node's `value` attribute tensor should be parsed into `tinygrad_dtype`.
    arr = np.array([0, 1], dtype=onnx.helper.tensor_dtype_to_np_dtype(onnx_data_type))
    output_tensor = onnx.helper.make_tensor_value_info('output', onnx_data_type, arr.shape)
    value_tensor = onnx.helper.make_tensor('value', onnx_data_type, arr.shape, arr.tobytes(), raw=True)
    node = onnx.helper.make_node('Constant', inputs=[], outputs=['output'], value=value_tensor)
    graph = onnx.helper.make_graph([node], 'attribute_test', [], [output_tensor])
    runner = self._make_runner(graph)
    self.assertEqual(runner.graph_nodes[0].opts['value'].dtype, tinygrad_dtype)

  @settings(deadline=1000) # TODO investigate unreliable timing
  @given(onnx_data_type=st.sampled_from(device_supported_dtypes))
  def test_supported_dtype_spec(self, onnx_data_type):
    # Supported dtypes must round-trip exactly through all three entry points.
    tinygrad_dtype = data_types[onnx_data_type]
    self._test_input_spec_dtype(onnx_data_type, tinygrad_dtype)
    self._test_initializer_dtype(onnx_data_type, tinygrad_dtype)
    self._test_node_attribute_dtype(onnx_data_type, tinygrad_dtype)

  @unittest.skipUnless(device_unsupported_dtypes, "No unsupported dtypes for this device to test.")
  @settings(deadline=1000) # TODO investigate unreliable timing
  @given(onnx_data_type=st.sampled_from(device_unsupported_dtypes))
  def test_unsupported_dtype_spec(self, onnx_data_type):
    # The input spec keeps the declared (true) dtype, while actual tensor data
    # (initializers, Constant values) falls back to the default int/float dtype
    # when the device can't represent the true dtype.
    true_dtype = data_types[onnx_data_type]
    default_dtype = dtypes.default_int if dtypes.is_int(true_dtype) else dtypes.default_float
    self._test_input_spec_dtype(onnx_data_type, true_dtype)
    self._test_initializer_dtype(onnx_data_type, default_dtype)
    self._test_node_attribute_dtype(onnx_data_type, default_dtype)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  unittest.main()