From 596f41eb46859cc8368a77144d2e2988c991c0ef Mon Sep 17 00:00:00 2001
From: chenyu
Date: Mon, 16 Sep 2024 22:54:07 -0400
Subject: [PATCH] simple drop image valid case (#6548)

* simple drop image valid case

started unit test, 530 -> 473 valids

* cleanup
---
Review notes (reviewer commentary, not part of the commit message; text
placed between the three-dash line and the first "diff --git" line is not
recorded in the commit by git-am):

* simplify_valid_image_load() bails out only when the LOAD has fewer than
  four sources, then unpacks exactly four with
  `buf, idx, _, valid = load.src`; a LOAD carrying more than four sources
  would raise ValueError at that unpack.
* The third LOAD source is discarded (`_`) without being inspected. The
  unit test passes a float4 of zeros there; dropping the gate presumably
  relies on that value being zero -- TODO confirm.
* `valid.src[1].arg == 0` compares only the arg and does not check the op
  of the right-hand operand, and `idx.src[1]` assumes idx has at least two
  sources.
* NOTE(review): this patch was recovered from a copy whose line breaks had
  been collapsed; indentation, in particular on context lines, is
  reconstructed and should be checked against the target tree before
  applying.

 .github/workflows/test.yml    |  2 +-
 test/unit/test_image_valid.py | 34 ++++++++++++++++++++++++++++++++++
 tinygrad/codegen/uopgraph.py  | 12 ++++++++++++
 3 files changed, 47 insertions(+), 1 deletion(-)
 create mode 100644 test/unit/test_image_valid.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9ada02057e..6c844b77a4 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -199,7 +199,7 @@ jobs:
     - if: ${{ matrix.task == 'optimage' }}
       name: Test openpilot model compile and size
       run: |
-        PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=530 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
+        PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=473 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
         python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
     - if: ${{ matrix.task == 'optimage' }}
       name: Test openpilot model correctness (float32)
diff --git a/test/unit/test_image_valid.py b/test/unit/test_image_valid.py
new file mode 100644
index 0000000000..732f3a79d7
--- /dev/null
+++ b/test/unit/test_image_valid.py
@@ -0,0 +1,34 @@
+import unittest
+
+from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
+from tinygrad.dtype import dtypes
+from tinygrad.ops import UOp, UOps, BinaryOps
+
+def render(image_shape, valid:UOp, idx:UOp) -> str:
+  uops = linearize_uop(full_graph_rewrite(UOp(UOps.LOAD, dtypes.float.vec(4), (
+    UOp(UOps.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
+    idx,
+    UOp(UOps.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0),)*4),
+    valid
+  )).sink()))
+  from tinygrad.renderer.cstyle import OpenCLRenderer
+  class TestRenderer(OpenCLRenderer):
+    code_for_op = {**OpenCLRenderer().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
+  fxn = TestRenderer().render("", uops)
+  return fxn.split("float4 val0 = ")[1].split(";")[0]
+
+def Variable(expr, nmax):
+  return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
+
+class TestValidSimplification(unittest.TestCase):
+  def test_idx_lt_0(self):
+    # (idx1 * (-1) < 0) ? (..., idx1-1) : 0 can drop the valid
+    gidx0 = Variable("gidx0", 32)
+    gidx1 = Variable("gidx1", 32)
+    self.assertEqual(
+      render((10, 10, 4), (gidx1*(-1)).lt(0), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-1))),
+      "read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-1))))"
+    )
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index e8da1fba1c..b0aa04c625 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -155,6 +155,16 @@ def fold_unrolled_divs(divs:UOp, c:UOp):
       seen_const.append(s0.src[1].arg)
   return ans if sorted(seen_const)==list(range(c.arg, c.arg+len(add_chain))) and ans is not None and (ans.vmin, ans.vmax)==(0, c.arg) else None
 
+# ***** image load valid simplification *****
+
+def simplify_valid_image_load(load:UOp, buf:UOp):
+  if not isinstance(buf.dtype, ImageDType) or len(load.src) < 4: return None
+  buf, idx, _, valid = load.src
+  if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT:
+    if valid.src[1].arg == 0 and graph_rewrite(valid.src[0]*(-1)-1, constant_folder).key == idx.src[1].key:
+      # valid: A*(-1) < 0, idx: (..., A-1) -> okay to drop valid
+      return UOp(UOps.LOAD, dtype=load.dtype, src=(buf, idx))
+
 # ***** transcendental *****
 
 @functools.lru_cache(None)
@@ -494,6 +504,8 @@ reducer = PatternMatcher([
   (UPat(UOps.STORE, name="root"), delete_redundant_gates),
   # late fixup of unfoldable image loads
   (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
+  # image load valid simplification
+  (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), simplify_valid_image_load),
 ])
 
 no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),