mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
oops, compare with abs
This commit is contained in:
@@ -229,7 +229,7 @@ if __name__ == "__main__":
|
||||
"big_input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
|
||||
"desire": np.zeros((1, 8)),
|
||||
"traffic_convention": np.array([[1., 0.]]),
|
||||
"initial_state": np.zeros((1, 512))
|
||||
"initial_state": np.random.randn(*(1, 512))
|
||||
}
|
||||
np_inputs = {k:v.astype(np.float32) for k,v in np_inputs.items()}
|
||||
inputs = list(np_inputs.values())[::-1]
|
||||
@@ -247,13 +247,17 @@ if __name__ == "__main__":
|
||||
out = run_onnx_torch(onnx_model, np_inputs).numpy()[0]
|
||||
|
||||
diff = 0
|
||||
diffs = []
|
||||
for i in range(ret.shape[0]):
|
||||
if out[i]-ret[i] > 0.1 and (out[i]-ret[i])/out[i] > 0.01:
|
||||
print(i, out[i], ret[i], out[i]-ret[i])
|
||||
if abs(out[i]-ret[i]) > 0.1 and abs((out[i]-ret[i])/out[i]) > 0.01:
|
||||
diff += 1
|
||||
diffs.append(out[i] - ret[i])
|
||||
if diff == 10:
|
||||
print("...")
|
||||
break
|
||||
elif diff < 10:
|
||||
print(i, out[i], ret[i], out[i]-ret[i])
|
||||
if len(diffs) > 0:
|
||||
print("%d differences min: %f max: %f" % (diff, min(diffs), max(diffs)))
|
||||
assert diff == 0
|
||||
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user