mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
still broken
This commit is contained in:
@@ -34,10 +34,16 @@ def get_random_input_tensors():
|
||||
"big_input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
|
||||
"desire": np.zeros((1, 8)),
|
||||
"traffic_convention": np.array([[1., 0.]]),
|
||||
"initial_state": np.zeros((1, 512))
|
||||
"initial_state": np.random.randn(*(1, 512))
|
||||
#"initial_state": np.zeros((1, 768))
|
||||
}
|
||||
|
||||
#import pickle
|
||||
#frames, big_frames, last_state, frame_inputs, policy_outs = pickle.load(open("openpilot/test/frame_0.pkl", "rb"))
|
||||
#np_inputs["input_imgs"] = frames
|
||||
#np_inputs["big_input_imgs"] = big_frames
|
||||
#np_inputs["initial_state"] = last_state[0]
|
||||
|
||||
#for i,k in enumerate(np_inputs.keys()):
|
||||
# dat = open("/home/batman/openpilot/xx/ml_tools/snpe/compile_test_data/dlc_input_%d" % i, "rb").read()
|
||||
# np_inputs[k] = np.frombuffer(dat, np.float32).reshape(np_inputs[k].shape)
|
||||
@@ -172,7 +178,7 @@ def compile(input, output_fn):
|
||||
try:
|
||||
from test.test_onnx import run_onnx_torch
|
||||
torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
|
||||
print(tinygrad_out_np, torch_out, "mse", np.sum((tinygrad_out_np-torch_out)**2))
|
||||
print(tinygrad_out_np, torch_out, "mse", np.sum((tinygrad_out_np-torch_out)**2), "max err", np.max(np.abs((tinygrad_out_np-torch_out))))
|
||||
np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user