mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 22:38:16 -05:00
fix thneed self test
This commit is contained in:
@@ -35,6 +35,7 @@ def get_random_input_tensors():
|
||||
"desire": np.zeros((1,100, 8)),
|
||||
"traffic_convention": np.array([[1., 0.]]),
|
||||
"features_buffer": np.random.randn(*(1, 99, 128))
|
||||
#"features_buffer": np.random.randn(*(1, 99, 3*512))
|
||||
#"initial_state": np.zeros((1, 768))
|
||||
}
|
||||
if int(os.getenv("ZERO_OUT", "0")):
|
||||
@@ -124,22 +125,26 @@ def compile(dat, output_fn):
|
||||
torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
|
||||
print(thneed_out, torch_out, "mse", np.sum((thneed_out-torch_out)**2), "max err", np.max(np.abs((thneed_out-torch_out))))
|
||||
np.testing.assert_allclose(torch_out, thneed_out, atol=1e-4, rtol=1e-2)
|
||||
|
||||
# test loading/run thneed
|
||||
_, new_np_inputs = get_random_input_tensors()
|
||||
new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
|
||||
|
||||
nt = Thneed()
|
||||
nt.load(output_fn)
|
||||
|
||||
# inputs
|
||||
for k,v in nt.inputs.items():
|
||||
CL.enqueue_copy(v, new_np_inputs[k], is_blocking=True)
|
||||
|
||||
nt.run()
|
||||
new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
|
||||
CL.enqueue_copy(new_thneed_out, nt.outputs[0], is_blocking=True)
|
||||
np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)
|
||||
print("thneed self-test passed!")
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
# test loading/run thneed
|
||||
nt = Thneed()
|
||||
nt.load(output_fn)
|
||||
|
||||
# inputs
|
||||
for k,v in nt.inputs.items():
|
||||
CL.enqueue_copy(v, np_inputs[k], is_blocking=True)
|
||||
|
||||
nt.run()
|
||||
new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
|
||||
CL.enqueue_copy(new_thneed_out, t.outputs[0], is_blocking=True)
|
||||
np.testing.assert_allclose(torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)
|
||||
print("thneed self-test passed!")
|
||||
|
||||
|
||||
# UNSAFE_FLOAT4=1 DEBUGCL=1 FLOAT16=1 python3 openpilot/compile.py
|
||||
|
||||
Reference in New Issue
Block a user