mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
onnx helper intermediate node output validation (#12740)
* start
* update comments
* good
* add comments and better printing
* done
@@ -3,8 +3,21 @@ from tinygrad.tensor import _to_np_dtype
 from tinygrad.nn.onnx import OnnxRunner, OnnxValue
 import numpy as np
 import onnxruntime as ort
+ort_options = ort.SessionOptions()
+ort_options.log_severity_level = 3
 
 def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}):
+  """
+  Generate example input tensors based on the provided ONNX graph input specifications.
+
+  NOTE: This is not guaranteed to be reliable. It's a best-effort helper
+  that uses heuristics to guess input shapes and values.
+
+  Example:
+    from tinygrad.nn.onnx import OnnxRunner
+    from extra.onnx_helpers import get_example_inputs
+    inputs = get_example_inputs(OnnxRunner(model_path).graph_inputs)
+  """
   def _get_shape(onnx_shape: tuple[str|int]):
     shape = []
     for onnx_dim in onnx_shape:
@@ -44,11 +57,9 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}):
     ret.update({name:value})
   return ret
 
-def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
+def _get_tinygrad_and_ort_np_outputs(onnx_file, inputs):
   run_onnx = OnnxRunner(onnx_file)
 
-  ort_options = ort.SessionOptions()
-  ort_options.log_severity_level = 3
   ort_sess = ort.InferenceSession(onnx_file, ort_options, ["CPUExecutionProvider"])
   np_inputs = {k:v.numpy() if isinstance(v, Tensor) else v for k,v in inputs.items()}
   out_names = list(run_onnx.graph_outputs)
@@ -56,9 +67,101 @@ def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
   ort_out = dict(zip(out_names, out_values))
 
   tinygrad_out = run_onnx(inputs)
+  Tensor.realize(*(x for x in tinygrad_out.values() if x is not None))
+  tinygrad_out = {k:v.numpy() if v is not None else None for k,v in tinygrad_out.items()}
+  return tinygrad_out, ort_out
+
+def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
+  """
+  Compares the final output tensors of an onnx model run in tinygrad and onnxruntime.
+  """
+  tinygrad_out, ort_out = _get_tinygrad_and_ort_np_outputs(onnx_file, inputs)
+
   assert tinygrad_out.keys() == ort_out.keys()
   for k in tinygrad_out.keys():
     tiny_v, onnx_v = tinygrad_out[k], ort_out[k]
     if tiny_v is None: assert onnx_v is None, f"{k}: {tiny_v=}, {onnx_v=}"
-    else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tinygrad_out.keys()}")
+    else: np.testing.assert_allclose(tiny_v, onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tinygrad_out.keys()}")
 
+def validate_all_intermediates(onnx_file, inputs, rtol=1e-5, atol=1e-5):
+  """
+  Compares all intermediate node outputs of an onnx model run in tinygrad and onnxruntime.
+  """
+  report = generate_node_output_report(onnx_file, inputs)
+  for i, node in enumerate(report):
+    node_name = node["node"]
+    op = node["op"]
+    outputs = node["outputs"]
+    for output in outputs:
+      output_name = output["name"]
+      tinygrad_out = output["tinygrad"]
+      ort_out = output["onnxruntime"]
+      try:
+        if tinygrad_out is None: assert ort_out is None, f"None outputs are not equal {tinygrad_out=} {ort_out=}"
+        else: np.testing.assert_allclose(tinygrad_out, ort_out, rtol=rtol, atol=atol)
+        print(f"Validated {i}: {op=} {node_name=} {output_name=}")
+      except AssertionError as e:
+        print(f"FAILED {i}: {op=} {node_name=} {output_name=}")
+        print(str(e).strip() + "\n")
+
+def generate_node_output_report(onnx_file, inputs):
+  """
+  Build a report of all ONNX node outputs from tinygrad and onnxruntime.
+
+  Returns:
+    A list of dictionaries, where each entry corresponds to one
+    node in the ONNX graph. The structure is as follows:
+    [
+      {
+        "node": str,  # The name of the ONNX node.
+        "op": str,  # The operation type of the ONNX node.
+        "outputs": [
+          {
+            "name": str,  # The name of the output tensor.
+            "tinygrad": np.ndarray | None,  # The output value from tinygrad.
+            "onnxruntime": np.ndarray | None,  # The output value from onnxruntime.
+          },
+          ...
+        ]
+      },
+      ...
+    ]
+  """
+  import onnx_graphsurgeon as gs
+  import onnx
+  import tempfile
+
+  # rewrite the model to output all the node outputs
+  # `infer_shapes` here tries to fill in the shapes and dtypes of intermediate values, which graphsurgeon requires when assigning them as outputs
+  inferred_model = onnx.shape_inference.infer_shapes(onnx.load(onnx_file))
+  model = gs.import_onnx(inferred_model)
+  model_nodes = model.nodes
+  node_outputs = [n.outputs for n in model.nodes]
+  model.outputs = [
+    each_output for outputs in node_outputs for each_output in outputs
+    if not (each_output.dtype is None and each_output.shape is None)  # an output with None dtype and None shape is likely a `None` value
+  ]
+  rewritten_model = gs.export_onnx(model)
+
+  # TODO: remove this once ORT supports onnx 1.18.0
+  if getattr(rewritten_model, "ir_version", 0) > 10:
+    rewritten_model.ir_version = 10
+
+  with tempfile.NamedTemporaryFile(suffix=".onnx") as f:
+    onnx.save(rewritten_model, f.name)
+    rewritten_model_path = f.name
+    tinygrad_out, ort_out = _get_tinygrad_and_ort_np_outputs(rewritten_model_path, inputs)
+
+  report = []
+  for node in model_nodes:
+    outputs = []
+    for each_output in node.outputs:
+      if each_output.dtype is None and each_output.shape is None:
+        continue
+      name = each_output.name
+      tinygrad_output = tinygrad_out[name]
+      ort_output = ort_out[name]
+      outputs.append({"name": name, "tinygrad": tinygrad_output, "onnxruntime": ort_output})
+    report.append({"node": node.name, "op": node.op, "outputs": outputs})
+
+  return report
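
For reference, the docstring example above extends naturally to the new validators. A minimal driver sketch, assuming a hypothetical local "model.onnx" path (everything else is imported from this file and from tinygrad.nn.onnx):

  from tinygrad.nn.onnx import OnnxRunner
  from extra.onnx_helpers import get_example_inputs, validate, validate_all_intermediates

  model_path = "model.onnx"  # hypothetical path, substitute your own model
  # heuristically generate example inputs from the graph input specs (best-effort, see NOTE above)
  inputs = get_example_inputs(OnnxRunner(model_path).graph_inputs)
  # compare only the final graph outputs; raises AssertionError on mismatch
  validate(model_path, inputs, rtol=1e-4, atol=1e-4)
  # compare every intermediate node output; prints Validated/FAILED per output instead of raising
  validate_all_intermediates(model_path, inputs, rtol=1e-4, atol=1e-4)

The tolerances here are illustrative; both functions default to rtol=1e-5, atol=1e-5.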
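Since generate_node_output_report returns plain numpy data, the report also supports custom analysis beyond the per-output printing in validate_all_intermediates. A sketch, under the same hypothetical model_path and inputs as above, that stops at the first diverging output:

  import numpy as np
  from extra.onnx_helpers import generate_node_output_report

  report = generate_node_output_report(model_path, inputs)
  for i, node in enumerate(report):
    for out in node["outputs"]:
      tiny, ort = out["tinygrad"], out["onnxruntime"]
      if tiny is None and ort is None: continue  # both absent, nothing to compare
      if tiny is None or ort is None or not np.allclose(tiny, ort, rtol=1e-4, atol=1e-4):
        print(f"first divergence at node {i}: op={node['op']} node={node['node']} output={out['name']}")
        break
    else:
      continue  # all of this node's outputs matched, check the next node
    break  # inner loop broke: divergence found, stop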