From dd587d26e318651e52b5160332b583eafbf82fa8 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 28 Aug 2022 11:23:21 -0700 Subject: [PATCH] oops, compare with abs --- openpilot/run_thneed.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/openpilot/run_thneed.py b/openpilot/run_thneed.py index 6f96c8b6c4..3e6599a8f0 100644 --- a/openpilot/run_thneed.py +++ b/openpilot/run_thneed.py @@ -229,7 +229,7 @@ if __name__ == "__main__": "big_input_imgs": np.random.randn(*(1, 12, 128, 256))*256, "desire": np.zeros((1, 8)), "traffic_convention": np.array([[1., 0.]]), - "initial_state": np.zeros((1, 512)) + "initial_state": np.random.randn(*(1, 512)) } np_inputs = {k:v.astype(np.float32) for k,v in np_inputs.items()} inputs = list(np_inputs.values())[::-1] @@ -247,13 +247,17 @@ if __name__ == "__main__": out = run_onnx_torch(onnx_model, np_inputs).numpy()[0] diff = 0 + diffs = [] for i in range(ret.shape[0]): - if out[i]-ret[i] > 0.1 and (out[i]-ret[i])/out[i] > 0.01: - print(i, out[i], ret[i], out[i]-ret[i]) + if abs(out[i]-ret[i]) > 0.1 and abs((out[i]-ret[i])/out[i]) > 0.01: diff += 1 + diffs.append(out[i] - ret[i]) if diff == 10: print("...") - break + elif diff < 10: + print(i, out[i], ret[i], out[i]-ret[i]) + if len(diffs) > 0: + print("%d differences min: %f max: %f" % (diff, min(diffs), max(diffs))) assert diff == 0 """