oops, compare with abs

This commit is contained in:
George Hotz
2022-08-28 11:23:21 -07:00
parent dc7af8c3ac
commit dd587d26e3

View File

@@ -229,7 +229,7 @@ if __name__ == "__main__":
"big_input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
"desire": np.zeros((1, 8)),
"traffic_convention": np.array([[1., 0.]]),
"initial_state": np.zeros((1, 512))
"initial_state": np.random.randn(*(1, 512))
}
np_inputs = {k:v.astype(np.float32) for k,v in np_inputs.items()}
inputs = list(np_inputs.values())[::-1]
@@ -247,13 +247,17 @@ if __name__ == "__main__":
out = run_onnx_torch(onnx_model, np_inputs).numpy()[0]
diff = 0
diffs = []
for i in range(ret.shape[0]):
if out[i]-ret[i] > 0.1 and (out[i]-ret[i])/out[i] > 0.01:
print(i, out[i], ret[i], out[i]-ret[i])
if abs(out[i]-ret[i]) > 0.1 and abs((out[i]-ret[i])/out[i]) > 0.01:
diff += 1
diffs.append(out[i] - ret[i])
if diff == 10:
print("...")
break
elif diff < 10:
print(i, out[i], ret[i], out[i]-ret[i])
if len(diffs) > 0:
print("%d differences min: %f max: %f" % (diff, min(diffs), max(diffs)))
assert diff == 0
"""