run_onnx_torch

Comma Device
2022-08-18 08:30:12 -07:00
parent 1f23517d92
commit 85453288d7


@@ -22,7 +22,6 @@ import tinygrad.ops as ops
 from tinygrad.llops.ops_gpu import CL, CLProgram, CLBuffer
 from extra.utils import fetch
 from extra.onnx import get_run_onnx
-from test.test_onnx import run_onnx_torch
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod
@@ -165,6 +164,7 @@ def compile(input, output_fn):
   # float32 only
   FLOAT16 = int(os.getenv("FLOAT16", 0))
   if FLOAT16 == 0:
+    from test.test_onnx import run_onnx_torch
     torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
     print(tinygrad_out_np, torch_out)
     np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)
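
The change moves the torch-backed import out of module scope and into the branch that actually uses it, so torch only needs to be installed when the FLOAT16=0 reference comparison runs. A minimal sketch of the same deferred-import pattern; `check_against_torch`, its parameters, and the use of `onnx.load` are illustrative assumptions, while `run_onnx_torch` and the tolerance check come from the diff above (this assumes tinygrad's `test.test_onnx` helper is importable):

import os
import numpy as np
import onnx

def check_against_torch(onnx_path, np_inputs, tinygrad_out_np):
  # FLOAT16=1 runs skip the float32 torch reference entirely, so torch
  # (pulled in via test.test_onnx) never has to be installed for them.
  if int(os.getenv("FLOAT16", 0)) != 0:
    return
  # Deferred import, as in the commit: only loaded when this path executes.
  from test.test_onnx import run_onnx_torch
  torch_out = run_onnx_torch(onnx.load(onnx_path), np_inputs).numpy()
  np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)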