diff --git a/openpilot/compile.py b/openpilot/compile.py index c83fd897b9..fbe8491ccc 100644 --- a/openpilot/compile.py +++ b/openpilot/compile.py @@ -43,14 +43,11 @@ def get_random_input_tensors(): for _,v in inputs.items(): v.realize() return inputs, np_inputs -# OPTWG=1 UNSAFE_FLOAT4=1 DEBUGCL=1 FLOAT16=1 MATMUL=1 python3 openpilot/compile.py -# 22.59 ms -if __name__ == "__main__": +def compile(input, output_fn): Tensor.no_grad = True using_graph = ops.GRAPH ops.GRAPH = False - dat = fetch(OPENPILOT_MODEL) onnx_model = onnx.load(io.BytesIO(dat)) run_onnx = get_run_onnx(onnx_model) inputs, _ = get_random_input_tensors() @@ -241,9 +238,20 @@ if __name__ == "__main__": }) print("saving thneed") - with open("/tmp/output.thneed", "wb") as f: + with open(output_fn, "wb") as f: j = json.dumps(jdat, ensure_ascii=False).encode('latin_1') f.write(struct.pack("I", len(j))) f.write(j) f.write(b''.join(weights)) f.write(b''.join(binaries)) + +# OPTWG=1 UNSAFE_FLOAT4=1 DEBUGCL=1 FLOAT16=1 MATMUL=1 python3 openpilot/compile.py +# 22.59 ms +if __name__ == "__main__": + if len(sys.argv) >= 3: + with open(sys.argv[1], "rb") as f: + dat = f.read() + compile(dat, sys.argv[2]) + else: + dat = fetch(OPENPILOT_MODEL) + compile(dat, "/tmp/output.thneed")