simpler can_pad (#10364)

* simpler can_pad [pr]

* 3 kernels

* tests

* fewer kernels
This commit is contained in:
qazal
2025-05-18 10:00:07 +03:00
committed by GitHub
parent c91f2c4580
commit 0294bfe507
5 changed files with 8 additions and 12 deletions

View File

@@ -436,7 +436,7 @@ jobs:
llvm: 'true'
- name: Test openpilot model kernel count and gate usage
run: |
PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2138 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2137 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- name: Test openpilot alt model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
- name: Test openpilot fastvits model correctness (float32)

View File

@@ -93,12 +93,11 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(z, 2))
self.assertEqual(z.item(), 32)
# TODO: same issue in precompute_freqs_cis
# Builds z = sum(pad(expand(1/x) * y)) from two contiguous inputs, schedules and runs it, then checks the scalar result.
# The expected 32 equals 16 * (0.5 * 4), so the elements introduced by the pad must contribute nothing to the sum.
# NOTE(review): the second argument of check_schedule looks like the expected kernel count — confirm against its definition.
def test_push_pads_contiguous(self):
x = Tensor.full((4,1), 2.).contiguous()
y = Tensor.full((4,4), 4.).contiguous()
# presumably pad((None, (0,1))) leaves the first axis untouched and pads only the second — verify against Tensor.pad
z = (x.reciprocal().expand(4,4)*y).pad((None, (0,1),)).sum()
run_schedule(check_schedule(z, 3, [x,y]))
run_schedule(check_schedule(z, 2, [x,y]))
self.assertEqual(z.item(), 32)
def test_rand(self):
@@ -1860,10 +1859,10 @@ class TestIndexing(unittest.TestCase):
args = {"dim":32 if CI else 128, "end":2048 if CI else 8192, "theta":10000}
fused = precompute_freqs_cis(**args)
with Context(FUSE_ARANGE=1):
run_schedule(check_schedule(fused, 5)) # TODO: this is too many
run_schedule(check_schedule(fused, 3))
if getenv("CHECK", 1):
ref = precompute_freqs_cis(**args)
run_schedule(check_schedule(ref, 5))
run_schedule(check_schedule(ref, 3))
np.testing.assert_equal(fused.numpy(), ref.numpy())
def test_fuse_assign_contiguous(self):

View File

@@ -413,7 +413,7 @@ class Kernel:
check(not self.vars, "does not work with symbolic shape")
check(axis < self.first_upcast, "cannot pad upcasted")
# ok to pad SUM if all parent ALU ops have f(0) = 0
if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, cache={}), f"cannot pad {r}")
if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
padded = False
for i,st in enumerate(self.sts):
if (s:=st.shape[axis]) == 1: continue # reduced

View File

@@ -131,7 +131,7 @@ def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
st = unwrap(view.st)
# always realize unsafe pad ops before masked view
if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx, cache=dict()): return realize(ctx, tr)
if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx): return realize(ctx, tr)
# fold simple pads
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([y-x for x,y in m])): return
# realize before expand

View File

@@ -188,11 +188,8 @@ class GroupOp:
# https://en.wikipedia.org/wiki/Identity_element
# Returns the identity element of `op` as a constant of dtype `dt`: 0 for ADD, 1 for MUL, dtypes.min(dt) for MAX.
# Any other op is not in the lookup table, so the dict indexing raises KeyError.
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
# Recursive form: returns False as soon as `u`, or anything reached through the `.base` of its sources, has an op in GroupOp.UnsafePad.
# `edges` stops the descent; `cache` records UOps that were already visited and is mutated by this call.
def can_pad(u:UOp, edges:dict[UOp, None], cache:dict[UOp, None]) -> bool:
# the UnsafePad test comes first, so it also applies to UOps that are in `edges` or `cache`
if u.op in GroupOp.UnsafePad: return False
# edge or already-visited UOps are accepted without looking at their sources
if u in edges or u in cache: return True
# mark as visited before recursing so a UOp reached again is not walked twice
cache[u] = None
return all(can_pad(x.base, edges, cache) for x in u.src)
# Returns True when none of the UOps yielded by root.toposort(...) has an op in GroupOp.UnsafePad.
# The gate lambda is true only for UOps that are not in `edges`.
# NOTE(review): presumably toposort uses the gate to decide whether to walk past a UOp into its sources — confirm against UOp.toposort.
def can_pad(root:UOp, edges:dict[UOp, None]) -> bool:
return all(u.op not in GroupOp.UnsafePad for u in root.toposort(gate=lambda x:x not in edges))
# With True as the default, this matches the old symbolic behavior
def resolve(x:UOp|bool, default:bool=True):