diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py index be8550d006..01d874cc5d 100644 --- a/test/test_softmax_fusion.py +++ b/test/test_softmax_fusion.py @@ -121,7 +121,7 @@ class TestSoftmaxFusion(unittest.TestCase): out = (inp / div).reshape(32, 10) out.realize() - np.testing.assert_allclose(sout.numpy(), out.numpy()) + np.testing.assert_allclose(sout.numpy(), out.numpy(), atol=3e-7) def test_softmax(self): # this is the softmax from scaled_dot_product_attention diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 8568d157f6..b121749b78 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -import unittest, pickle, functools +import unittest, pickle, functools, math import z3 from tinygrad.dtype import dtypes, ConstType @@ -29,16 +29,17 @@ class TestSymbolicPickle(unittest.TestCase): def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2) class TestSymbolic(unittest.TestCase): - def helper_test_variable(self, v, n, m, s): + def helper_test_variable(self, v, n, m, s, test_z3:bool=True): rendered, nmin, nmax = render(v) if isinstance(s, tuple): self.assertIn(rendered, s) else: self.assertEqual(rendered, s) self.assertEqual(nmin, n) self.assertEqual(nmax, m) - solver = z3.Solver() - z3_sink = graph_rewrite(v.sink(v.simplify()), z3_renderer, ctx=(solver, {})) - expr, epxr_simplified = z3_sink.src[0].arg, z3_sink.src[1].arg - self.assertEqual(solver.check(expr != epxr_simplified), z3.unsat, "simplified expression not equal to original") + if test_z3: + solver = z3.Solver() + z3_sink = graph_rewrite(v.sink(v.simplify()), z3_renderer, ctx=(solver, {})) + expr, epxr_simplified = z3_sink.src[0].arg, z3_sink.src[1].arg + self.assertEqual(solver.check(expr != epxr_simplified), z3.unsat, "simplified expression not equal to original") def test_cmp_simple(self): self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)") @@ -672,6 +673,12 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(numerator, 3, 390, "(a*((a*4)+-1))") self.helper_test_variable((numerator//denominator)<=0, 1, 1, "True") + def test_const_reciprocal(self): + a = Variable("a", 1, 10, dtypes.float) + # TODO: bounds for reciprocal + # TODO: should z3 work? + self.helper_test_variable(2*(2*a).reciprocal(), -math.inf, math.inf, "(1/a)", test_z3=False) + class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): MIN, MAX = 0, 10 diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 8935499a6d..36408af9c9 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -941,6 +941,7 @@ renderer = PatternMatcher([ (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), #(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}[={x.src[1].arg}]")), (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")), + (UPat(Ops.RECIP, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(1/{x.src[0].arg})")), (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")), (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")), (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")), diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 2c295d0af9..4bce9221fc 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -467,6 +467,7 @@ sym = symbolic_flat+PatternMatcher([ if any(x.op in REMOVE_FROM_SINK for x in root.src) else None), ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()), + ((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x) (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x) (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)), (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),