mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
add test_onnx_ops.py (#8569)
* boom * fix webgpu * use exact variable names in test so that AI can read easier * add tag for specific test name like test a specific dtype * fix ruff * astype everything * dtype in array creation * just arange * is 67% considered fixed? * move test up * small cleanups * share function * add qgemm as well * add qgemm too * make sure qgemm comes out as int * take out qgemm for now * fixed test * add correct qgemm * addressing feedback here too, early naive fix for now * simplify bias and c to be minimalistic enough to test correctness * refactored qlinearops * maybe these asserts aren't the best.. * fix test * updated tests to cover new ops * try to add to CI * move test_onnx_ops into testextra/ * more attention tests * qlinear_add atol=1 * attention still not fullllllly correct * it is what it is --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -16,7 +16,6 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue]):
|
||||
|
||||
def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
|
||||
run_onnx = OnnxRunner(onnx.load(onnx_file))
|
||||
tinygrad_out = run_onnx(inputs)
|
||||
|
||||
ort_options = ort.SessionOptions()
|
||||
ort_options.log_severity_level = 3
|
||||
@@ -26,8 +25,10 @@ def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
|
||||
out_values = ort_sess.run(out_names, np_inputs)
|
||||
ort_out = dict(zip(out_names, out_values))
|
||||
|
||||
assert len(tinygrad_out) == len(ort_out) and tinygrad_out.keys() == ort_out.keys()
|
||||
tinygrad_out = run_onnx(inputs)
|
||||
|
||||
assert tinygrad_out.keys() == ort_out.keys()
|
||||
for k in tinygrad_out.keys():
|
||||
tiny_v, onnx_v = tinygrad_out[k], ort_out[k]
|
||||
if tiny_v is None: assert tiny_v == onnx_v
|
||||
if tiny_v is None: assert onnx_v is None, f"{k}: {tiny_v=}, {onnx_v=}"
|
||||
else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tinygrad_out.keys()}")
|
||||
Reference in New Issue
Block a user