diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index d7c7ebf29c..fad34c6d1b 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -131,7 +131,7 @@ jobs: - name: UsbGPU copy speeds run: sudo -E PYTHONPATH=. AMD=1 AMD_IFACE=USB python3.11 test/external/external_test_usb_asm24.py TestDevCopySpeeds #- name: UsbGPU openpilot test - # run: sudo -E PYTHONPATH=. AMD=1 AMD_IFACE=USB NOLOCALS=0 IMAGE=0 GRAPH_ONE_KERNEL=1 python3.11 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx + # run: sudo -E PYTHONPATH=. AMD=1 AMD_IFACE=USB GRAPH_ONE_KERNEL=1 python3.11 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx - uses: actions/upload-artifact@v4 with: name: Speed (Mac) @@ -626,15 +626,15 @@ jobs: - name: benchmark openpilot 0.9.9 dmonitoring run: BENCHMARK_LOG=openpilot_0_9_9_dmonitoring PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx - name: openpilot compile3 0.9.9 driving_vision - run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=18 QCOM=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_vision.onnx + run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=18 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_vision.onnx - name: openpilot compile3 0.9.9 driving_policy - run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=7 QCOM=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_policy.onnx + run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=7 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_policy.onnx - name: openpilot compile3 0.9.9 dmonitoring - run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=12 QCOM=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx + run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=12 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx - name: openpilot compile3 Space Lab policy + vision run: | - PYTHONPATH="." ASSERT_MIN_STEP_TIME=5 QCOM=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/22aec22a10ce09384d4a4af2a0bbff08d54af7e0c888503508f356fae4ff0e29 - PYTHONPATH="." ASSERT_MIN_STEP_TIME=26 QCOM=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/c824f68646a3b94f117f01c70dc8316fb466e05fbd42ccdba440b8a8dc86914b + PYTHONPATH="." ASSERT_MIN_STEP_TIME=5 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/22aec22a10ce09384d4a4af2a0bbff08d54af7e0c888503508f356fae4ff0e29 + PYTHONPATH="." ASSERT_MIN_STEP_TIME=26 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/c824f68646a3b94f117f01c70dc8316fb466e05fbd42ccdba440b8a8dc86914b - name: benchmark MobileNetV2 on DSP run: | # generate quantized weights diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9e6ab355ae..76619c80e1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -374,15 +374,13 @@ jobs: llvm: 'true' - name: Test openpilot model kernel count and gate usage run: | - ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2081 ALLOWED_GATED_READ_IMAGE=28 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx - - name: Test openpilot alt model correctness (float32) - run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx - - name: Test openpilot fastvits model correctness (float32) - run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx - # - name: Test openpilot simple_plan vision model correctness (float32) - # run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/35ff4f4577002f2685e50c8346addae33fe8da27a41dd4d6a0f14d1f4b1af81b - - name: Test openpilot LLVM compile - run: CPU=1 CPU_LLVM=1 LLVMOPT=1 JIT=2 BEAM=0 IMAGE=0 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx + ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1452 ALLOWED_GATED_READ_IMAGE=122 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 + - name: Test openpilot CL compile fp16 + run: FLOAT16=1 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 + - name: Test openpilot CL compile fp32 (test correctness) + run: DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx + - name: Test openpilot LLVM compile fp16 + run: FLOAT16=1 CPU=1 CPU_LLVM=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 - name: Run process replay tests uses: ./.github/actions/process-replay diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py index c89920d83b..02b8496b26 100644 --- a/examples/openpilot/compile3.py +++ b/examples/openpilot/compile3.py @@ -1,9 +1,5 @@ import os, sys, pickle, time, re import numpy as np -if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1" -if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2" -if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1" -if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0" from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device, dtypes from tinygrad.helpers import DEBUG, getenv @@ -21,11 +17,14 @@ def compile(onnx_file): input_shapes = {name: spec.shape for name, spec in run_onnx.graph_inputs.items()} input_types = {name: spec.dtype for name, spec in run_onnx.graph_inputs.items()} + # Float inputs and outputs to tinyjits for openpilot are always float32 + # TODO this seems dumb input_types = {k:(dtypes.float32 if v is dtypes.float16 else v) for k,v in input_types.items()} Tensor.manual_seed(100) - new_inputs = {k:Tensor.randn(*shp, dtype=input_types[k]).mul(8).realize() for k,shp in sorted(input_shapes.items())} - new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()} + inputs = {k:Tensor(Tensor.randn(*shp, dtype=input_types[k]).mul(8).realize().numpy(), device='NPY') for k,shp in sorted(input_shapes.items())} + if not getenv("NPY_IMG"): + inputs = {k:Tensor(v.numpy(), device=Device.DEFAULT).realize() if 'img' in k else v for k,v in inputs.items()} print("created tensors") run_onnx_jit = TinyJit(lambda **kwargs: @@ -33,8 +32,6 @@ def compile(onnx_file): for i in range(3): GlobalCounters.reset() print(f"run {i}") - inputs = {**{k:v.clone() for k,v in new_inputs.items() if 'img' in k}, - **{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}} with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)): ret = run_onnx_jit(**inputs).numpy() # copy i == 1 so use of JITBEAM is okay @@ -69,14 +66,9 @@ def compile(onnx_file): print(f"mdl size is {mdl_sz/1e6:.2f}M") print(f"pkl size is {pkl_sz/1e6:.2f}M") print("**** compile done ****") - return test_val + return inputs, test_val -def test_vs_compile(run, new_inputs, test_val=None): - new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()} - - # create fake "from_blob" tensors for the inputs, and wrapped NPY tensors for the numpy inputs (these have the same underlying memory) - inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k}, - **{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}} +def test_vs_compile(run, inputs, test_val=None): # run 20 times step_times = [] @@ -93,68 +85,49 @@ def test_vs_compile(run, new_inputs, test_val=None): min_time = min(step_times) assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms" - print(out, val.shape, val.dtype) if test_val is not None: np.testing.assert_equal(test_val, val) print("**** test done ****") # test that changing the numpy changes the model outputs - if any([x.device == 'NPY' for x in inputs.values()]): - for v in new_inputs_numpy.values(): v *= 2 - out = run(**inputs) - changed_val = out.numpy() - np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val) + inputs_2x = {k: Tensor(v.numpy()*2, device=v.device) for k,v in inputs.items()} + out = run(**inputs_2x) + changed_val = out.numpy() + np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val) return val -def test_vs_onnx(new_inputs, test_val, onnx_file, ort=False): - new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()} +def test_vs_onnx(new_inputs, test_val, onnx_file, tol): + import onnxruntime as ort + + onnx_inputs = {k:v.numpy() for k,v in new_inputs.items()} onnx_model = onnx.load(onnx_file) - timings = [] - if ort: - # test with onnxruntime - import onnxruntime as ort - onnx_session = ort.InferenceSession(onnx_file) - for _ in range(1 if test_val is not None else 5): - st = time.perf_counter() - onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()}) - timings.append(time.perf_counter() - st) - new_torch_out = onnx_output[0] - else: - # test with torch - import torch - from onnx2torch import convert - inputs = {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input} - torch_model = convert(onnx_model).float() - with torch.no_grad(): - for _ in range(1 if test_val is not None else 5): - st = time.perf_counter() - torch_out = torch_model(*[torch.tensor(x) for x in inputs.values()]) - timings.append(time.perf_counter() - st) - new_torch_out = torch_out.numpy() + ORT_TO_NP_DTYPES: dict[str, np.dtype] = { + 'tensor(float)': np.dtype('float32'), + 'tensor(float16)': np.dtype('float16'), + 'tensor(uint8)': np.dtype('uint8'), + } - if test_val is not None: - np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2) - print("test vs onnx passed") + timings = [] + onnx_session = ort.InferenceSession(onnx_file) + onnx_types = {x.name: ORT_TO_NP_DTYPES[x.type] for x in onnx_session.get_inputs()} + onnx_inputs = {k:onnx_inputs[k].astype(onnx_types[k]) for k in onnx_inputs} + + for _ in range(1 if test_val is not None else 5): + st = time.perf_counter() + onnx_output = onnx_session.run([onnx_model.graph.output[0].name], onnx_inputs) + timings.append(time.perf_counter() - st) + + np.testing.assert_allclose(onnx_output[0].reshape(test_val.shape), test_val, atol=tol, rtol=tol) + print("test vs onnx passed") return timings if __name__ == "__main__": onnx_file = fetch(OPENPILOT_MODEL) - test_val = compile(onnx_file) if not getenv("RUN") else None + inputs, outputs = compile(onnx_file) with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f) - # same randomness as compile - Tensor.manual_seed(100) - new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in - sorted(zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_st_vars_dtype_device))} - - test_val = test_vs_compile(pickle_loaded, new_inputs, test_val) - if getenv("BENCHMARK"): - for be in ["torch", "ort"]: - try: - timings = test_vs_onnx(new_inputs, None, onnx_file, be=="ort") - print(f"timing {be}: {min(timings)*1000:.2f} ms") - except Exception as e: - print(f"{be} fail with {e}") - if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file, getenv("ORT")) + test_vs_compile(pickle_loaded, inputs, outputs) + if not getenv("FLOAT16"): + test_vs_onnx(inputs, outputs, onnx_file, 1e-4)