mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
output file to disk
This commit is contained in:
@@ -43,14 +43,11 @@ def get_random_input_tensors():
|
||||
for _,v in inputs.items(): v.realize()
|
||||
return inputs, np_inputs
|
||||
|
||||
# OPTWG=1 UNSAFE_FLOAT4=1 DEBUGCL=1 FLOAT16=1 MATMUL=1 python3 openpilot/compile.py
|
||||
# 22.59 ms
|
||||
if __name__ == "__main__":
|
||||
def compile(input, output_fn):
|
||||
Tensor.no_grad = True
|
||||
using_graph = ops.GRAPH
|
||||
ops.GRAPH = False
|
||||
|
||||
dat = fetch(OPENPILOT_MODEL)
|
||||
onnx_model = onnx.load(io.BytesIO(dat))
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
inputs, _ = get_random_input_tensors()
|
||||
@@ -241,9 +238,20 @@ if __name__ == "__main__":
|
||||
})
|
||||
|
||||
print("saving thneed")
|
||||
with open("/tmp/output.thneed", "wb") as f:
|
||||
with open(output_fn, "wb") as f:
|
||||
j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
|
||||
f.write(struct.pack("I", len(j)))
|
||||
f.write(j)
|
||||
f.write(b''.join(weights))
|
||||
f.write(b''.join(binaries))
|
||||
|
||||
# OPTWG=1 UNSAFE_FLOAT4=1 DEBUGCL=1 FLOAT16=1 MATMUL=1 python3 openpilot/compile.py
|
||||
# 22.59 ms
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) >= 3:
|
||||
with open(sys.argv[1], "rb") as f:
|
||||
dat = f.read()
|
||||
compile(dat, sys.argv[2])
|
||||
else:
|
||||
dat = fetch(OPENPILOT_MODEL)
|
||||
compile(dat, "/tmp/output.thneed")
|
||||
|
||||
Reference in New Issue
Block a user