mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
local
This commit is contained in:
@@ -66,7 +66,6 @@ class TestOnnxModel(unittest.TestCase):
|
||||
ps.print_stats(30)
|
||||
|
||||
def test_openpilot_model(self):
|
||||
Tensor.no_grad = True
|
||||
dat = fetch(OPENPILOT_MODEL)
|
||||
onnx_model = onnx.load(io.BytesIO(dat))
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
@@ -90,7 +89,9 @@ class TestOnnxModel(unittest.TestCase):
|
||||
et = time.monotonic()
|
||||
print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue")
|
||||
|
||||
Tensor.no_grad = True
|
||||
torch_out = run_onnx_torch(onnx_model, inputs).numpy()
|
||||
Tensor.no_grad = False
|
||||
print(tinygrad_out, torch_out)
|
||||
np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user