Files
tinygrad/examples/openpilot/compile3.py
2025-10-17 14:35:40 -07:00

116 lines
4.4 KiB
Python

import os, sys, pickle, time, re
import numpy as np
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device, dtypes
from tinygrad.helpers import DEBUG, getenv
from tinygrad.engine.realize import CompiledRunner
import onnx
from tinygrad.nn.onnx import OnnxRunner
# Model to compile: first command-line argument if given, otherwise the openpilot v0.9.7 supercombo ONNX URL.
if len(sys.argv) > 1:
  OPENPILOT_MODEL = sys.argv[1]
else:
  OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"

# Path of the pickle output: second command-line argument if given, otherwise a file under /tmp.
if len(sys.argv) > 2:
  OUTPUT = sys.argv[2]
else:
  OUTPUT = "/tmp/openpilot.pkl"
def compile(onnx_file):
  """JIT-capture the ONNX model at `onnx_file`, validate it and pickle the jitted callable to OUTPUT.

  The model is executed three times through a TinyJit on seeded random inputs; the
  output of the second run (index 1) is kept and must equal the output of the third
  run exactly. The TinyJit object is then written to OUTPUT with pickle and the sizes
  of the model file and of the pickle are printed.

  Returns:
    A tuple `(inputs, test_val)`: the dict of input tensors fed to the model and a
    numpy copy of the output produced by the second run.
  """
  runner = OnnxRunner(onnx_file)
  print("loaded model")

  # Float inputs and outputs to tinyjits for openpilot are always float32
  # TODO this seems dumb
  shapes, in_dtypes = {}, {}
  for name, spec in runner.graph_inputs.items():
    shapes[name] = spec.shape
    in_dtypes[name] = dtypes.float32 if spec.dtype is dtypes.float16 else spec.dtype

  # Seed before drawing so the random inputs are reproducible; inputs are created in sorted-name order.
  Tensor.manual_seed(100)
  inputs = {}
  for name, shape in sorted(shapes.items()):
    random_np = Tensor.randn(*shape, dtype=in_dtypes[name]).mul(8).realize().numpy()
    inputs[name] = Tensor(random_np, device='NPY')

  # Unless NPY_IMG is set, inputs whose name contains 'img' are moved to the default device up front.
  if not getenv("NPY_IMG"):
    for name in inputs:
      if 'img' in name:
        inputs[name] = Tensor(inputs[name].numpy(), device=Device.DEFAULT).realize()
  print("created tensors")

  # The jitted function moves every input to the default device and returns the first model output as float32.
  run_onnx_jit = TinyJit(
    lambda **kwargs: next(iter(runner({key: tensor.to(Device.DEFAULT) for key, tensor in kwargs.items()}).values())).cast('float32'),
    prune=True)

  for run_idx in range(3):
    GlobalCounters.reset()
    print(f"run {run_idx}")
    # the last run is executed with DEBUG of at least 2, the earlier ones with at least 1
    debug_floor = 2 if run_idx == 2 else 1
    with Context(DEBUG=max(DEBUG.value, debug_floor)):
      ret = run_onnx_jit(**inputs).numpy()
    # copy i == 1 so use of JITBEAM is okay
    if run_idx == 1:
      test_val = np.copy(ret)

  print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels")
  # `ret` now holds the third run's output; it must match the saved second run.
  np.testing.assert_equal(test_val, ret, "JIT run failed")
  print("jit run validated")

  with open(OUTPUT, "wb") as pkl_file:
    pickle.dump(run_onnx_jit, pkl_file)

  model_bytes = os.path.getsize(onnx_file)
  pickle_bytes = os.path.getsize(OUTPUT)
  print(f"mdl size is {model_bytes/1e6:.2f}M")
  print(f"pkl size is {pickle_bytes/1e6:.2f}M")
  print("**** compile done ****")
  return inputs, test_val
def test_vs_compile(run, inputs, test_val=None):
  """Time `run(**inputs)` over 20 iterations and sanity-check its output.

  Args:
    run: callable invoked as `run(**inputs)`; its result must expose `.numpy()`.
    inputs: dict of keyword arguments for `run`; each value must expose `.numpy()` and `.device`.
    test_val: optional reference array; when given, the last run's output must equal it exactly.

  Returns:
    The numpy output of the last of the 20 timed runs.

  Raises:
    AssertionError: if ASSERT_MIN_STEP_TIME is set and the fastest step is not below it,
      if the output differs from `test_val`, or if doubling every input leaves the output unchanged.
  """
  # run 20 times
  step_times = []
  for _ in range(20):
    t_start = time.perf_counter()
    out = run(**inputs)
    t_enqueued = time.perf_counter()
    val = out.numpy()
    t_done = time.perf_counter()
    total_ms = (t_done - t_start) * 1e3
    step_times.append(total_ms)
    print(f"enqueue {(t_enqueued-t_start)*1e3:6.2f} ms -- total run {total_ms:6.2f} ms")

  # optional speed gate: only active when ASSERT_MIN_STEP_TIME is set to a non-zero value
  assert_time = getenv("ASSERT_MIN_STEP_TIME")
  if assert_time:
    min_time = min(step_times)
    assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"

  if test_val is not None:
    np.testing.assert_equal(test_val, val)
  print("**** test done ****")

  # test that changing the numpy changes the model outputs
  inputs_2x = {}
  for name, tensor in inputs.items():
    inputs_2x[name] = Tensor(tensor.numpy()*2, device=tensor.device)
  changed_val = run(**inputs_2x).numpy()
  # passes only if the two outputs are NOT equal
  np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
  return val
def test_vs_onnx(new_inputs, test_val, onnx_file, tol):
  """Run the model under onnxruntime and compare its first output against `test_val`.

  Args:
    new_inputs: dict of input name -> object exposing `.numpy()`; each array is cast to
      the numpy dtype matching the type the onnxruntime session reports for that input.
    test_val: reference output array, or None. With None the comparison is skipped and
      the model is only timed (5 runs instead of 1).
    onnx_file: model file, passed to both `onnx.load` and `ort.InferenceSession`.
    tol: tolerance used as both `atol` and `rtol` in the comparison.

  Returns:
    List with the wall-clock duration in seconds of each onnxruntime run.

  Raises:
    AssertionError: if `test_val` is given and the onnxruntime output is not within `tol` of it.
    KeyError: if an input's onnxruntime type is missing from ORT_TO_NP_DTYPES, or a key of
      `new_inputs` is not an input of the session.
  """
  import onnxruntime as ort

  onnx_inputs = {k:v.numpy() for k,v in new_inputs.items()}
  onnx_model = onnx.load(onnx_file)

  ORT_TO_NP_DTYPES: dict[str, np.dtype] = {
    'tensor(float)': np.dtype('float32'),
    'tensor(float16)': np.dtype('float16'),
    'tensor(uint8)': np.dtype('uint8'),
  }

  timings = []
  onnx_session = ort.InferenceSession(onnx_file)
  onnx_types = {x.name: ORT_TO_NP_DTYPES[x.type] for x in onnx_session.get_inputs()}
  onnx_inputs = {k:onnx_inputs[k].astype(onnx_types[k]) for k in onnx_inputs}

  # only the first graph output is requested; look its name up once, outside the timed region
  output_names = [onnx_model.graph.output[0].name]
  for _ in range(1 if test_val is not None else 5):
    st = time.perf_counter()
    onnx_output = onnx_session.run(output_names, onnx_inputs)
    timings.append(time.perf_counter() - st)

  # The loop above supports timing-only calls (test_val is None), so the comparison must be
  # guarded too; otherwise `test_val.shape` raises AttributeError after the timing runs.
  if test_val is not None:
    np.testing.assert_allclose(onnx_output[0].reshape(test_val.shape), test_val, atol=tol, rtol=tol)
    print("test vs onnx passed")
  return timings
if __name__ == "__main__":
  # Script flow: fetch the model, compile + pickle it, reload the pickle and check it
  # against the compile-time output, then (unless NO_ORT is set) compare with onnxruntime.
  onnx_file = fetch(OPENPILOT_MODEL)
  inputs, outputs = compile(onnx_file)

  # reload what compile() just wrote, so the test exercises the pickled artifact
  with open(OUTPUT, "rb") as pkl_file:
    pickle_loaded = pickle.load(pkl_file)
  test_vs_compile(pickle_loaded, inputs, outputs)

  if not getenv("NO_ORT"):
    # This tolerance is absurd, but better than nothing
    if getenv("FLOAT16"):
      tol = 1
    else:
      tol = 1e-4
    test_vs_onnx(inputs, outputs, onnx_file, tol)