simpler grouptop in hcopt (#11219)

* simpler grouptop in hcopt

keep only the perf-relevant conditions; the rest is handled by try/except

* update openpilot read image count
This commit is contained in:
chenyu
2025-07-13 16:06:09 -04:00
committed by GitHub
parent 40847ca29c
commit 85ddd72038
3 changed files with 11 additions and 13 deletions

View File

@@ -467,7 +467,7 @@ jobs:
llvm: 'true'
- name: Test openpilot model kernel count and gate usage
run: |
PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2137 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2134 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- name: Test openpilot alt model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
- name: Test openpilot fastvits model correctness (float32)

View File

@@ -54,11 +54,11 @@ def compile(onnx_file):
gated_read_image_count += ei.prg.p.src.count("?read_image")
print(f"{kernel_count=}, {read_image_count=}, {gated_read_image_count=}")
if (allowed_kernel_count:=getenv("ALLOWED_KERNEL_COUNT", -1)) != -1:
assert kernel_count <= allowed_kernel_count, f"too many kernels! {kernel_count=}, {allowed_kernel_count=}"
assert kernel_count == allowed_kernel_count, f"different kernels! {kernel_count=}, {allowed_kernel_count=}"
if (allowed_read_image:=getenv("ALLOWED_READ_IMAGE", -1)) != -1:
assert read_image_count == allowed_read_image, f"different read_image! {read_image_count=}, {allowed_read_image=}"
if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
assert gated_read_image_count <= allowed_gated_read_image, f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"
assert gated_read_image_count == allowed_gated_read_image, f"different gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"
with open(OUTPUT, "wb") as f:
pickle.dump(run_onnx_jit, f)

View File

@@ -1,6 +1,6 @@
import itertools
from tinygrad.opt.kernel import Kernel, Opt, OptOps, KernelOptError, AxisType
from tinygrad.helpers import getenv, DEBUG, all_int, prod, NOLOCALS
from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS
from tinygrad.dtype import ImageDType
from tinygrad.uop.ops import Ops, resolve
@@ -26,15 +26,13 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
return k.applied_opts
if k.opts.has_local and k.opts.has_shared and all_int(k.sts[0].shape[:k.first_reduce]):
# are we grouping? (requires local shape support)
if k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
for sz in ([256, 16] if prod(k.sts[0].shape[:k.first_reduce]) <= 32 else [16]):
try: # may fail due to excessive smem usage
k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
break
except KernelOptError: pass
# are we grouping? (requires local shape support)
if resolve(prod(k.sts[0].shape[:k.first_reduce]) <= 2048, False):
for sz in [16]:
try: # may fail due to excessive smem usage
k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
break
except KernelOptError: pass
# upcast float4 images
for buf_index,buf in enumerate(k.bufs):