mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
inputs and outputs
This commit is contained in:
@@ -160,13 +160,13 @@ def compile(input, output_fn):
|
||||
print(f"{k:20s} runtime: {v/1e6:.2f} ms")
|
||||
print(f"total runtime: {total_runtime/1e6:.2f} ms")
|
||||
|
||||
tinygrad_out = tinygrad_out.numpy()
|
||||
tinygrad_out_np = tinygrad_out.numpy()
|
||||
|
||||
# float32 only
|
||||
if int(os.getenv("FLOAT16", 0)) == 0:
|
||||
torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
|
||||
print(tinygrad_out, torch_out)
|
||||
np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2)
|
||||
print(tinygrad_out_np, torch_out)
|
||||
np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)
|
||||
|
||||
# save local_cl_cache as thneed
|
||||
import struct, json
|
||||
@@ -237,6 +237,14 @@ def compile(input, output_fn):
|
||||
"args": targs,
|
||||
"args_size": args_size
|
||||
})
|
||||
|
||||
jdat['outputs'] = [{"buffer_id": tinygrad_out.lazydata.realized.cl.global_id}]
|
||||
jdat['inputs'] = [{
|
||||
"buffer_id": v.lazydata.realized.cl.global_id,
|
||||
"size": v.lazydata.realized.cl.size,
|
||||
"name": k
|
||||
} for k,v in inputs.items()][::-1]
|
||||
print(jdat['inputs'])
|
||||
|
||||
print("saving thneed")
|
||||
with open(output_fn, "wb") as f:
|
||||
|
||||
Reference in New Issue
Block a user