Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
Simplify openpilot compile3.py (#12748)
* Simpler compile3
* tests
* remove default args
* onnx file is still fp16
* self-test FP16 too
* allow test disable
* absurd tolerance
* Just do latest
* Try simplest
* use later models
* kernel count not relevant if speed is good
* dead improts
* Revert "dead improts" — This reverts commit f68c2cd15d.
* Revert "kernel count not relevant if speed is good" — This reverts commit 0955ca4ee0.
* add back kernal count check on latest model
@@ -1,9 +1,5 @@
import os, sys, pickle, time, re
import numpy as np
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0"

from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device, dtypes
from tinygrad.helpers import DEBUG, getenv
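These "if ... not in os.environ" guards are just manual defaults applied before tinygrad is imported, so values set in the shell still win. A minimal equivalent sketch, assuming the same variable names:

import os

# set a default only when the variable is unset in the shell
os.environ.setdefault("FLOAT16", "1")
os.environ.setdefault("IMAGE", "2")
os.environ.setdefault("NOLOCALS", "1")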
@@ -21,11 +17,14 @@ def compile(onnx_file):

  input_shapes = {name: spec.shape for name, spec in run_onnx.graph_inputs.items()}
  input_types = {name: spec.dtype for name, spec in run_onnx.graph_inputs.items()}

  # Float inputs and outputs to tinyjits for openpilot are always float32
  # TODO this seems dumb
  input_types = {k:(dtypes.float32 if v is dtypes.float16 else v) for k,v in input_types.items()}
  Tensor.manual_seed(100)
  new_inputs = {k:Tensor.randn(*shp, dtype=input_types[k]).mul(8).realize() for k,shp in sorted(input_shapes.items())}
  new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
  inputs = {k:Tensor(Tensor.randn(*shp, dtype=input_types[k]).mul(8).realize().numpy(), device='NPY') for k,shp in sorted(input_shapes.items())}
  if not getenv("NPY_IMG"):
    inputs = {k:Tensor(v.numpy(), device=Device.DEFAULT).realize() if 'img' in k else v for k,v in inputs.items()}
  print("created tensors")

  run_onnx_jit = TinyJit(lambda **kwargs:
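The input construction above seeds tinygrad's RNG and draws each graph input from Tensor.randn scaled by 8, with float16 inputs promoted to float32 per the comment. A minimal sketch of that pattern, using a made-up shape and dtype in place of run_onnx.graph_inputs:

from tinygrad import Tensor, dtypes

# hypothetical single input; the real shapes/dtypes come from run_onnx.graph_inputs
shape, dtype = (1, 128), dtypes.float16
dtype = dtypes.float32 if dtype is dtypes.float16 else dtype  # promote fp16 inputs
Tensor.manual_seed(100)                                        # same seed as compile()
x = Tensor.randn(*shape, dtype=dtype).mul(8).realize()
print(x.shape, x.dtype)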
@@ -33,8 +32,6 @@ def compile(onnx_file):
  for i in range(3):
    GlobalCounters.reset()
    print(f"run {i}")
    inputs = {**{k:v.clone() for k,v in new_inputs.items() if 'img' in k},
              **{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
    with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)):
      ret = run_onnx_jit(**inputs).numpy()
    # copy i == 1 so use of JITBEAM is okay
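The three-iteration loop is the usual TinyJit warm-up: the first calls run eagerly and capture the kernels, and only the final iteration exercises the jitted path, which is why DEBUG is only raised on the last run. A minimal sketch of the pattern, with a toy function standing in for run_onnx_jit:

from tinygrad import Tensor, TinyJit, GlobalCounters

@TinyJit
def toy_step(x: Tensor) -> Tensor: return (x * 2 + 1).realize()  # stand-in for run_onnx_jit

for i in range(3):
  GlobalCounters.reset()
  out = toy_step(Tensor.randn(8, 8).realize())
  print(f"run {i}: {GlobalCounters.kernel_count} kernels")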
@@ -69,14 +66,9 @@ def compile(onnx_file):
  print(f"mdl size is {mdl_sz/1e6:.2f}M")
  print(f"pkl size is {pkl_sz/1e6:.2f}M")
  print("**** compile done ****")
  return test_val
  return inputs, test_val

def test_vs_compile(run, new_inputs, test_val=None):
  new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}

  # create fake "from_blob" tensors for the inputs, and wrapped NPY tensors for the numpy inputs (these have the same underlying memory)
  inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k},
            **{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
def test_vs_compile(run, inputs, test_val=None):

  # run 20 times
  step_times = []
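The "same underlying memory" comment is the key trick here: wrapping a numpy array as an NPY-device tensor keeps pointing at the caller's buffer, so in-place edits to the array feed straight into the next run. A minimal sketch of that behavior, assuming NPY-device tensors do not copy on construction (as the comment states):

import numpy as np
from tinygrad import Tensor

arr = np.ones(4, dtype=np.float32)
t = Tensor(arr, device="NPY").realize()   # wraps arr's buffer, per the comment above
arr *= 2                                  # mutate the numpy side in place
print(t.numpy())                          # expected to reflect the doubled values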
@@ -93,68 +85,49 @@ def test_vs_compile(run, new_inputs, test_val=None):
  min_time = min(step_times)
  assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"

  print(out, val.shape, val.dtype)
  if test_val is not None: np.testing.assert_equal(test_val, val)
  print("**** test done ****")

  # test that changing the numpy changes the model outputs
  if any([x.device == 'NPY' for x in inputs.values()]):
    for v in new_inputs_numpy.values(): v *= 2
    out = run(**inputs)
    changed_val = out.numpy()
    np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
  inputs_2x = {k: Tensor(v.numpy()*2, device=v.device) for k,v in inputs.items()}
  out = run(**inputs_2x)
  changed_val = out.numpy()
  np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
  return val

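The changed-input check relies on a small numpy idiom: np.testing.assert_raises passes exactly when the inner equality assertion fails, i.e. when the two outputs differ somewhere. A minimal self-contained sketch:

import numpy as np

val = np.array([1.0, 2.0, 3.0])
changed_val = val * 2
# passes because assert_array_equal raises AssertionError for differing arrays
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
print("outputs changed, as required")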
def test_vs_onnx(new_inputs, test_val, onnx_file, ort=False):
  new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
def test_vs_onnx(new_inputs, test_val, onnx_file, tol):
  import onnxruntime as ort

  onnx_inputs = {k:v.numpy() for k,v in new_inputs.items()}
  onnx_model = onnx.load(onnx_file)

  timings = []
  if ort:
    # test with onnxruntime
    import onnxruntime as ort
    onnx_session = ort.InferenceSession(onnx_file)
    for _ in range(1 if test_val is not None else 5):
      st = time.perf_counter()
      onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
      timings.append(time.perf_counter() - st)
    new_torch_out = onnx_output[0]
  else:
    # test with torch
    import torch
    from onnx2torch import convert
    inputs = {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}
    torch_model = convert(onnx_model).float()
    with torch.no_grad():
      for _ in range(1 if test_val is not None else 5):
        st = time.perf_counter()
        torch_out = torch_model(*[torch.tensor(x) for x in inputs.values()])
        timings.append(time.perf_counter() - st)
      new_torch_out = torch_out.numpy()
ORT_TO_NP_DTYPES: dict[str, np.dtype] = {
  'tensor(float)': np.dtype('float32'),
  'tensor(float16)': np.dtype('float16'),
  'tensor(uint8)': np.dtype('uint8'),
}

  if test_val is not None:
    np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
    print("test vs onnx passed")
  timings = []
  onnx_session = ort.InferenceSession(onnx_file)
  onnx_types = {x.name: ORT_TO_NP_DTYPES[x.type] for x in onnx_session.get_inputs()}
  onnx_inputs = {k:onnx_inputs[k].astype(onnx_types[k]) for k in onnx_inputs}

  for _ in range(1 if test_val is not None else 5):
    st = time.perf_counter()
    onnx_output = onnx_session.run([onnx_model.graph.output[0].name], onnx_inputs)
    timings.append(time.perf_counter() - st)

  np.testing.assert_allclose(onnx_output[0].reshape(test_val.shape), test_val, atol=tol, rtol=tol)
  print("test vs onnx passed")
  return timings

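The rewritten test_vs_onnx feeds onnxruntime directly and casts each numpy input to the dtype the session reports, via the ORT_TO_NP_DTYPES table. A minimal sketch of that query-and-cast step, with a hypothetical model path and zero-filled inputs (dynamic dims assumed to be 1):

import numpy as np
import onnxruntime

ORT_TO_NP = {'tensor(float)': np.float32, 'tensor(float16)': np.float16, 'tensor(uint8)': np.uint8}

sess = onnxruntime.InferenceSession("model.onnx")  # hypothetical path
feeds = {}
for inp in sess.get_inputs():
  shape = [d if isinstance(d, int) else 1 for d in inp.shape]  # dynamic dims -> 1
  feeds[inp.name] = np.zeros(shape, dtype=ORT_TO_NP[inp.type])
(out,) = sess.run([sess.get_outputs()[0].name], feeds)
print(out.shape, out.dtype)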
if __name__ == "__main__":
  onnx_file = fetch(OPENPILOT_MODEL)
  test_val = compile(onnx_file) if not getenv("RUN") else None
  inputs, outputs = compile(onnx_file)

  with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)

  # same randomness as compile
  Tensor.manual_seed(100)
  new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
                sorted(zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_st_vars_dtype_device))}

  test_val = test_vs_compile(pickle_loaded, new_inputs, test_val)
  if getenv("BENCHMARK"):
    for be in ["torch", "ort"]:
      try:
        timings = test_vs_onnx(new_inputs, None, onnx_file, be=="ort")
        print(f"timing {be}: {min(timings)*1000:.2f} ms")
      except Exception as e:
        print(f"{be} fail with {e}")
  if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file, getenv("ORT"))
  test_vs_compile(pickle_loaded, inputs, outputs)
  if not getenv("FLOAT16"):
    test_vs_onnx(inputs, outputs, onnx_file, 1e-4)