diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4138b17888..194a5f6743 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -377,7 +377,7 @@ jobs: llvm: 'true' - name: Test openpilot model kernel count and gate usage run: | - ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2041 ALLOWED_GATED_READ_IMAGE=33 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx + ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2041 ALLOWED_GATED_READ_IMAGE=543 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx - name: Test openpilot alt model correctness (float32) run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx - name: Test openpilot fastvits model correctness (float32) diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py index 6624ce1c9f..c89920d83b 100644 --- a/examples/openpilot/compile3.py +++ b/examples/openpilot/compile3.py @@ -1,4 +1,4 @@ -import os, sys, pickle, time +import os, sys, pickle, time, re import numpy as np if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1" if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2" @@ -52,6 +52,8 @@ def compile(onnx_file): kernel_count += 1 read_image_count += ei.prg.p.src.count("read_image") gated_read_image_count += ei.prg.p.src.count("?read_image") + for v in [m.group(1) for m in re.finditer(r'(val\d+)\s*=\s*read_imagef\(', ei.prg.p.src)]: + if len(re.findall(fr'[\?\:]{v}\.[xyzw]', ei.prg.p.src)) > 0: gated_read_image_count += 1 print(f"{kernel_count=}, {read_image_count=}, {gated_read_image_count=}") if (allowed_kernel_count:=getenv("ALLOWED_KERNEL_COUNT", -1)) != -1: assert kernel_count == allowed_kernel_count, f"different kernels! {kernel_count=}, {allowed_kernel_count=}"