inputs and outputs

This commit is contained in:
Comma Device
2022-07-18 20:17:26 -07:00
parent ae30641b0d
commit 29581b5c85

View File

@@ -160,13 +160,13 @@ def compile(input, output_fn):
print(f"{k:20s} runtime: {v/1e6:.2f} ms")
print(f"total runtime: {total_runtime/1e6:.2f} ms")
tinygrad_out = tinygrad_out.numpy()
tinygrad_out_np = tinygrad_out.numpy()
# float32 only
if int(os.getenv("FLOAT16", 0)) == 0:
torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
print(tinygrad_out, torch_out)
np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2)
print(tinygrad_out_np, torch_out)
np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)
# save local_cl_cache as thneed
import struct, json
@@ -237,6 +237,14 @@ def compile(input, output_fn):
"args": targs,
"args_size": args_size
})
jdat['outputs'] = [{"buffer_id": tinygrad_out.lazydata.realized.cl.global_id}]
jdat['inputs'] = [{
"buffer_id": v.lazydata.realized.cl.global_id,
"size": v.lazydata.realized.cl.size,
"name": k
} for k,v in inputs.items()][::-1]
print(jdat['inputs'])
print("saving thneed")
with open(output_fn, "wb") as f: