From d33311ebe0269561d0cd3ac17a8c7f22cf5cf8f5 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 6 Mar 2024 21:12:11 -0500 Subject: [PATCH] remove parens of ALU if it has associative property (#3635) need to remove SUB since it's possible to have (const - (const - const)) in test/test_ops.py::TestOps::test_cos, in which case cannot remove the parens of children --- test/test_linearizer_failures.py | 2 +- tinygrad/renderer/cstyle.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 137ee0bc9f..b26d1b1ada 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -94,7 +94,7 @@ class TestLinearizerFailures(unittest.TestCase): # fatal error: bracket nesting level exceeded maximum of 256 # note: use -fbracket-depth=N to increase maximum nesting level ast = helper_add_store(ast) - helper_test_lin(Linearizer(ast), opts, failed_platforms=["CLANG", "METAL", "GPU"]) + helper_test_lin(Linearizer(ast), opts, failed_platforms=[]) def test_failure_9(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 4500, 0, 0, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)))) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 1426b0b2b1..c6d73c9315 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -127,10 +127,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str depth += 1 elif uop is UOps.ALU: # remove parens if ALU types are the same. TODO: can do more here - if vin[0].uop is UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.XOR}: - val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype) - else: - val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype]) + if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin] + else: operands = [r[v] for v in vin] + val = lang.code_for_op[args](*operands, dtype) assert child_count[u] != 0, f"childless ALU op found {u}" # TODO: fix index rendering issue. fix clang nested max macro issue if child_count[u] <= 1 and args != BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val