still broken

This commit is contained in:
George Hotz
2022-08-29 19:08:07 -07:00
parent 5efab7cf1d
commit 33ac355bcd

View File

@@ -34,10 +34,16 @@ def get_random_input_tensors():
"big_input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
"desire": np.zeros((1, 8)),
"traffic_convention": np.array([[1., 0.]]),
"initial_state": np.zeros((1, 512))
"initial_state": np.random.randn(*(1, 512))
#"initial_state": np.zeros((1, 768))
}
#import pickle
#frames, big_frames, last_state, frame_inputs, policy_outs = pickle.load(open("openpilot/test/frame_0.pkl", "rb"))
#np_inputs["input_imgs"] = frames
#np_inputs["big_input_imgs"] = big_frames
#np_inputs["initial_state"] = last_state[0]
#for i,k in enumerate(np_inputs.keys()):
# dat = open("/home/batman/openpilot/xx/ml_tools/snpe/compile_test_data/dlc_input_%d" % i, "rb").read()
# np_inputs[k] = np.frombuffer(dat, np.float32).reshape(np_inputs[k].shape)
@@ -172,7 +178,7 @@ def compile(input, output_fn):
try:
from test.test_onnx import run_onnx_torch
torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
print(tinygrad_out_np, torch_out, "mse", np.sum((tinygrad_out_np-torch_out)**2))
print(tinygrad_out_np, torch_out, "mse", np.sum((tinygrad_out_np-torch_out)**2), "max err", np.max(np.abs((tinygrad_out_np-torch_out))))
np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)
except ModuleNotFoundError:
pass