mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix openpilot gated reads (#12570)
* fix gated image counts * slice correctly
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -377,7 +377,7 @@ jobs:
|
||||
llvm: 'true'
|
||||
- name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2041 ALLOWED_GATED_READ_IMAGE=543 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2041 ALLOWED_GATED_READ_IMAGE=41 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot alt model correctness (float32)
|
||||
run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot fastvits model correctness (float32)
|
||||
|
||||
@@ -509,6 +509,9 @@ sym = symbolic_flat+PatternMatcher([
|
||||
# fold gated LOAD/STORE
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
|
||||
# # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
|
||||
(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("c2").where(UPat(), invalid_pat)).or_casted(),), allow_any_len=True, name="l"), 0),
|
||||
lambda c1,c2,l,i: l.replace(src=(l.src[0],)+l.src[1:]) if any(c in list(c2.split_uop(Ops.AND)) for c in c1.split_uop(Ops.AND)) else None),
|
||||
# remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
|
||||
(UPat(Ops.BARRIER, name="root"),
|
||||
lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)
|
||||
|
||||
Reference in New Issue
Block a user