From c9a9631818245a9e23695326bae3fe748b4ebe28 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 21 Aug 2024 11:08:22 -0400 Subject: [PATCH] no UnaryOps.NEG in generated UOp patterns (#6209) * no UnaryOps.NEG in generated UOp patterns removed pattern `x * (-1) -> -x` and `x != True` * those are fine because NEG became CMPNE and True * fix sd validation L2 norm --- examples/sdxl.py | 2 +- examples/stable_diffusion.py | 2 +- test/test_linearizer_dumb.py | 2 +- test/test_linearizer_failures.py | 2 +- test/unit/test_uop_symbolic.py | 25 ++++++++++++------------- tinygrad/codegen/uopgraph.py | 3 --- tinygrad/ops.py | 2 +- 7 files changed, 17 insertions(+), 21 deletions(-) diff --git a/examples/sdxl.py b/examples/sdxl.py index cc45bd3704..aa1ad53e54 100644 --- a/examples/sdxl.py +++ b/examples/sdxl.py @@ -423,6 +423,6 @@ if __name__ == "__main__": if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 6.0 and args.width == args.height == 1024 \ and not args.weights: ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png"))) - distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item() + distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item() assert distance < 4e-3, colored(f"validation failed with {distance=}", "red") print(colored(f"output validated with {distance=}", "green")) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 8618ef5f4d..7c12727d61 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -289,6 +289,6 @@ if __name__ == "__main__": # validation! if args.prompt == default_prompt and args.steps == 5 and args.seed == 0 and args.guidance == 7.5: ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png"))) - distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item() + distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item() assert distance < 3e-4, colored(f"validation failed with {distance=}", "red") print(colored(f"output validated with {distance=}", "green")) diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index b12337abb0..111eae686a 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -75,7 +75,7 @@ class TestLinearizerDumb(unittest.TestCase): if_uops = [u for u in k.uops if u.op is UOps.IF] self.assertEqual(len(if_uops), 1) conditions = if_uops[0].src[0].sparents - self.assertLessEqual(len(conditions), 8) + self.assertLessEqual(len(conditions), 9) # this was a bug in embedding, someday we should fold this anyway def test_llama_embedding(self): diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 30a6cd6f8d..76289502e7 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -388,7 +388,7 @@ class TestLinearizerFailures(unittest.TestCase): ifs = [u for u in k.uops if u.op is UOps.IF] self.assertEqual(len(ifs), 1) #for st in k.uops.sink.src: self.assertEqual(len(st.src), 4) - self.assertLessEqual(len(ifs[0].src[0].sparents), 16) + self.assertLessEqual(len(ifs[0].src[0].sparents), 17) def test_failure_45(self): ast = LazyOp(MetaOps.KERNEL, arg=None, src=( diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index a4047ce6b7..179c677d5d 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -62,13 +62,13 @@ class TestSymbolic(unittest.TestCase): def test_cmp_simple(self): self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)") - self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((-a)<(-7))"}) + self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))"}) def test_ge(self): self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0") self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0") - self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((-a)<(-7))"}) - self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, {"((a*-1)<-3)", "((-a)<(-3))"}) + self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))"}) + self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))"}) self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1") self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1") @@ -136,7 +136,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)") def test_neg(self): - self.helper_test_variable(-Variable("a", 0, 8), -8, 0, {"(a*-1)", "(-a)"}) + self.helper_test_variable(-Variable("a", 0, 8), -8, 0, {"(a*-1)", "(a*(-1))"}) def test_add_1(self): self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, {"(1+a)", "(a+1)"}) @@ -150,7 +150,6 @@ class TestSymbolic(unittest.TestCase): def test_sub_num_1(self): self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, {"(-1+a)", "(a+(-1))"}) - @unittest.expectedFailure def test_sub_self(self): a = Variable("a", 0, 8) self.helper_test_variable(a*3-a, 0, 16, "(a*2)") @@ -333,9 +332,9 @@ class TestSymbolic(unittest.TestCase): # TODO: simplify the expression def test_div_neg_cancel(self): - self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((((-idx)+199)//(-4))+50)") - self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((((-idx)+200)//(-4))+50)") - self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((((-idx)+201)//(-4))+50)") + self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((((idx*(-1))+199)//(-4))+50)") + self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((((idx*(-1))+200)//(-4))+50)") + self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((((idx*(-1))+201)//(-4))+50)") def test_sum_div_big_const(self): gidx0 = Variable("gidx0", 0, 24) @@ -376,10 +375,10 @@ class TestSymbolic(unittest.TestCase): def test_div_neg_all_range(self): gidx = Variable("gidx", 0, 124) lidx = Variable("lidx", 0, 7) - self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "(((((-gidx)*8)+(-lidx)+999)//(-4))+250)") - self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1000)//(-4))+250)") - self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1001)//(-4))+250)") - self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1002)//(-4))+250)") + self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "((((gidx*(-8))+(lidx*(-1))+999)//(-4))+250)") + self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "((((gidx*(-8))+(lidx*(-1))+1000)//(-4))+250)") + self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "((((gidx*(-8))+(lidx*(-1))+1001)//(-4))+250)") + self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "((((gidx*(-8))+(lidx*(-1))+1002)//(-4))+250)") # NOTE: tests are not correct in symbolic def test_div_neg_then_neg(self): @@ -389,7 +388,7 @@ class TestSymbolic(unittest.TestCase): alu2 = -lidx0-lidx1 self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4") self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "(-4)") - self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "((((-lidx0)+(-lidx1)+134)//(-32))+4)") + self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "((((lidx0*(-1))+(lidx1*(-1))+134)//(-32))+4)") self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0") self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0") self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0") diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index aaa0110d5d..20582d6f3b 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -272,7 +272,6 @@ constant_folder = PatternMatcher([ (-(-NOp.var('x')), lambda x: x), # -(-x) -> x (NOp.var('x') + 0, lambda x: x), # x+0 -> x (NOp.var('x') * 1, lambda x: x), # x*1 -> x - (NOp.var('x') * -1, lambda x: -x), # x*-1 -> -x (NOp.var('x') // NOp.var('x'), lambda x: x.const(1)), # x//x -> 1 (NOp.var('x') // 1, lambda x: x), # x//1 -> x (NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x @@ -337,8 +336,6 @@ constant_folder = PatternMatcher([ (NOp.var("x") + NOp.var("x") * NOp.cvar("c0"), lambda x,c0: x*(c0.arg+1)), # x!=0 -> (bool)x (NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)), - # bool != 1 -> not bool - (NOp.var("x", dtype=dtypes.bool).ne(1), lambda x: -x), # TODO: can do the invert of this (flip alt/load) when we fix double ops (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("gate").where(NOp.var("alt"), NOp.load(NOp.var("buf"), NOp.var("idx")))), lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)), diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 323a43b490..a0d6edaf7b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -87,7 +87,7 @@ class UOp: def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,)) def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i) - def __neg__(self): return self.alu(UnaryOps.NEG) + def __neg__(self): return self*(-1) if self.dtype != dtypes.bool else self.ne(True) def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x)) def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x)) def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x))