mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
run_onnx_torch
This commit is contained in:
@@ -22,7 +22,6 @@ import tinygrad.ops as ops
|
||||
from tinygrad.llops.ops_gpu import CL, CLProgram, CLBuffer
|
||||
from extra.utils import fetch
|
||||
from extra.onnx import get_run_onnx
|
||||
from test.test_onnx import run_onnx_torch
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import prod
|
||||
|
||||
@@ -165,6 +164,7 @@ def compile(input, output_fn):
|
||||
# float32 only
|
||||
FLOAT16 = int(os.getenv("FLOAT16", 0))
|
||||
if FLOAT16 == 0:
|
||||
from test.test_onnx import run_onnx_torch
|
||||
torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
|
||||
print(tinygrad_out_np, torch_out)
|
||||
np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user