mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
455 -> 364 valids. generalize `idx < image bound` to `idx < image bound + c` for some `c`
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -199,7 +199,7 @@ jobs:
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model compile and size
|
||||
run: |
|
||||
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=455 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
|
||||
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=364 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
|
||||
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model correctness (float32)
|
||||
|
||||
@@ -51,5 +51,14 @@ class TestValidSimplification(unittest.TestCase):
|
||||
self.assertEqual(render((20, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
|
||||
"((gidx1<10)?read_imagef(data0, smp, (int2)(gidx0,gidx1)):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
def test_generic_idx_lt_bound(self):
|
||||
# (idx1 < image_bound - c) ? (..., idx1 + c) : 0 can drop the valid
|
||||
gidx0 = Variable("gidx0", 32)
|
||||
gidx1 = Variable("gidx1", 32)
|
||||
self.assertEqual(render((10, 10, 4), (gidx1).lt(8), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1+2))),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+2)))")
|
||||
self.assertEqual(render((10, 10, 4), (gidx1).lt(5), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1+5))),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+5)))")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -162,11 +162,11 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
|
||||
buf, idx, invalid_val, valid = load.src
|
||||
drop = False
|
||||
for stmt in _get_chain(valid, BinaryOps.AND):
|
||||
if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT:
|
||||
if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT and stmt.src[1].op is UOps.CONST:
|
||||
# valid: A*(-1) < c, idx: (..., A-1+c) -> okay to drop valid because A*(-1) >= c -> A <= -c -> A-1+c <= -1 is out of bound
|
||||
if graph_rewrite(stmt.src[0]*(-1)-1+stmt.src[1].arg, constant_folder).key == idx.src[1].key: drop = True
|
||||
# valid: A < image bound, idx: (..., A) -> okay to drop valid
|
||||
elif stmt.src[1].arg == buf_dtype.shape[0] and idx.src[1].key == stmt.src[0].key: drop = True
|
||||
# valid: A < image bound - c, idx: (..., A+c) -> okay to drop valid because A >= bound - c -> A + c >= bound is out of bound
|
||||
elif graph_rewrite(stmt.src[0]+(buf_dtype.shape[0]-stmt.src[1].arg), constant_folder).key == idx.src[1].key: drop = True
|
||||
|
||||
if drop:
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s is not stmt]) else None
|
||||
|
||||
Reference in New Issue
Block a user