diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d48489c8ca..5fb815d7ed 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -142,8 +142,8 @@ jobs: mypy -c "from tinygrad.tensor import Tensor; print(Tensor([1,2,3,4,5]))" - name: Test DEBUG run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())" - - name: Repo line count <9600 lines - run: MAX_LINE_COUNT=9600 python sz.py + - name: Repo line count <9700 lines + run: MAX_LINE_COUNT=9700 python sz.py testopencl: strategy: @@ -197,7 +197,7 @@ jobs: - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model compile and size run: | - PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=364 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py + PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=356 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000' - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model correctness (float32) diff --git a/test/unit/test_image_valid.py b/test/unit/test_image_valid.py index 71808bfddc..afbf4cd2e3 100644 --- a/test/unit/test_image_valid.py +++ b/test/unit/test_image_valid.py @@ -15,6 +15,7 @@ def render(image_shape, valid:UOp, idx:UOp) -> str: class TestRenderer(OpenCLRenderer): code_for_op = {**OpenCLRenderer().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"} fxn = TestRenderer().render("", uops) + # print(fxn) return fxn.split("float4 val0 = ")[1].split(";")[0] def Variable(expr, nmax): @@ -60,5 +61,30 @@ class TestValidSimplification(unittest.TestCase): self.assertEqual(render((10, 10, 4), (gidx1).lt(5), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1+5))), "read_imagef(data0, 
smp, (int2)(gidx0,(gidx1+5)))") + def test_simplify1(self): + # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1) + gidx = Variable("gidx", 512) + valid = gidx.lt(488) & (-gidx).lt(-479) + idx = ((gidx*3+18)%26, (gidx*3+18)//26-56) + # alu0 is ((gidx*3)+18) + self.assertEqual(render((1, 26, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)), + "read_imagef(data0, smp, (int2)(((gidx*3)+(-1438)),0))") + + def test_simplify2(self): + # from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d + lidx = Variable("lidx", 4) + valid = lidx.lt(3) & (-lidx).lt(0) + idx = ((lidx+1)%2, (lidx+1)//2-1) + self.assertEqual(render((1, 2, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)), + "read_imagef(data0, smp, (int2)((lidx+(-1)),0))") + + def test_simplify3(self): + # from openpilot + idx0 = Variable("idx0", 265) + valid = (-idx0).lt(-200) + idx = ((idx0+55)%64, (idx0+55)//64-4) + self.assertEqual(render((1, 64, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)), + "read_imagef(data0, smp, (int2)((idx0+(-201)),0))") + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 8bb897efd0..665ac8b6f6 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -160,6 +160,8 @@ def fold_unrolled_divs(divs:UOp, c:UOp): def simplify_valid_image_load(load:UOp, buf:UOp): if not isinstance(buf_dtype:=buf.dtype, ImageDType) or len(load.src) < 4: return None buf, idx, invalid_val, valid = load.src + + # TODO: merge this into the generic case drop = False for stmt in _get_chain(valid, BinaryOps.AND): if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT and stmt.src[1].op is UOps.CONST: @@ -172,6 +174,44 @@ def simplify_valid_image_load(load:UOp, buf:UOp): new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s is not stmt]) else None return UOp(UOps.LOAD, load.dtype, (buf, 
idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, idx)) + # We want to simplify expressions like (X*c+d)%m in the idx, with optional *c and +d. m is the total length of the row. + # If the constraints in valid imply that it "spans" the whole row, we can rewrite it to X*c+k for some k and drop the valid. + m = mod.src[1].arg if (mod:=idx.src[0]).op is UOps.ALU and mod.arg is BinaryOps.MOD and mod.src[1].op is UOps.CONST else None + if not m or m != buf_dtype.shape[1]: return None + d = add.src[1].arg if (add:=mod.src[0]).op is UOps.ALU and add.arg is BinaryOps.ADD and add.src[1].op is UOps.CONST else 0 + mul = add.src[0] if d else add # + d is optional + c = mul.src[1].arg if mul.op is UOps.ALU and mul.arg is BinaryOps.MUL and mul.src[1].op is UOps.CONST else 1 + X = mul.src[0] if c != 1 else mul # * c is optional + + lower, upper = X.vmin, X.vmax + drop_stmt = [] + + for stmt in _get_chain(valid, BinaryOps.AND): + if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT and stmt.src[1].op is UOps.CONST: + if stmt.src[0].key == X.key: # X < c + upper = stmt.src[1].arg-1 + drop_stmt.append(stmt) + elif stmt.src[0].key == (-X).key: # -X < -c -> X > c + lower = -stmt.src[1].arg+1 + drop_stmt.append(stmt) + + new_indx0, new_indx1 = None, None + if (L:=(lower * c + d)) // m == (U:=(upper * c + d)) // m: # in the same row + if (L % m - c < 0) and (U % m + c >= m): # spans the whole row + new_indx0 = graph_rewrite(mul - ((L // m) * m - d), constant_folder) + + # Because (X * c + d) % m spans the whole row, (X * c + d) // m has a fixed value. + # check if idx1 is a div that can be simplified. 
idx1 = (add // m + e) + e = add1.src[1].arg if (add1:=idx.src[1]).op is UOps.ALU and add1.arg is BinaryOps.ADD and add1.src[1].op is UOps.CONST else 0 + div = add1.src[0] if e else add1 + m_ = div.src[1].arg if div.op is UOps.ALU and div.arg is BinaryOps.IDIV and div.src[1].op is UOps.CONST else None + if m_ == m and div.src[0] == add: new_indx1 = idx.src[1].const_like(L // m + e) + + if new_indx0 and new_indx1: + new_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (new_indx0, new_indx1)) + new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None + return UOp(UOps.LOAD, load.dtype, (buf, new_idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, new_idx)) + # ***** transcendental ***** @functools.lru_cache(None)