other side of simple out of bound valid case (#6552)

462 -> 455
This commit is contained in:
chenyu
2024-09-16 23:57:15 -04:00
committed by GitHub
parent aeaf7894a7
commit 7c942418a1
3 changed files with 17 additions and 3 deletions

View File

@@ -199,7 +199,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model compile and size
run: |
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=462 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=455 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model correctness (float32)

View File

@@ -21,7 +21,7 @@ def Variable(expr, nmax):
return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
class TestValidSimplification(unittest.TestCase):
def test_idx_lt_c(self):
def test_idx_neg_lt_c(self):
# (idx1 * (-1) < c) ? (..., idx1-1+c) : 0 can drop the valid
gidx0 = Variable("gidx0", 32)
gidx1 = Variable("gidx1", 32)
@@ -32,5 +32,15 @@ class TestValidSimplification(unittest.TestCase):
self.assertEqual(render((10, 10, 4), (gidx1*(-1)).lt(1), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
"read_imagef(data0, smp, (int2)(gidx0,gidx1))")
def test_idx_lt_bound(self):
# (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
gidx0 = Variable("gidx0", 32)
gidx1 = Variable("gidx1", 32)
self.assertEqual(render((10, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
"read_imagef(data0, smp, (int2)(gidx0,gidx1))")
# 10x20 image, not out of bound
self.assertEqual(render((20, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
"((gidx1<10)?read_imagef(data0, smp, (int2)(gidx0,gidx1)):(float4)(0.0f,0.0f,0.0f,0.0f))")
if __name__ == '__main__':
unittest.main()

View File

@@ -158,13 +158,17 @@ def fold_unrolled_divs(divs:UOp, c:UOp):
# ***** image load valid simplification *****
def simplify_valid_image_load(load:UOp, buf:UOp):
if not isinstance(buf.dtype, ImageDType) or len(load.src) < 4: return None
if not isinstance(buf_dtype:=buf.dtype, ImageDType) or len(load.src) < 4: return None
buf, idx, _, valid = load.src
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT:
if graph_rewrite(valid.src[0]*(-1)-1+valid.src[1].arg, constant_folder).key == idx.src[1].key:
# valid: A*(-1) < c, idx: (..., A-1+c) -> okay to drop valid because A*(-1) >= c -> A <= -c -> A-1+c <= -1 is out of bound
return UOp(UOps.LOAD, dtype=load.dtype, src=(buf, idx))
if valid.src[1].arg == buf_dtype.shape[0] and idx.src[1].key == valid.src[0].key:
# valid: A < image bound, idx: (..., A) -> okay to drop valid
return UOp(UOps.LOAD, dtype=load.dtype, src=(buf, idx))
# ***** transcendental *****
@functools.lru_cache(None)