mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
1/(x*c) -> (1/c)*(1/x) (#11491)
example: 2*(2*a).reciprocal() -> a.reciprocal() # TODO: bounds for reciprocal # TODO: should z3 work?
This commit is contained in:
@@ -121,7 +121,7 @@ class TestSoftmaxFusion(unittest.TestCase):
|
||||
out = (inp / div).reshape(32, 10)
|
||||
out.realize()
|
||||
|
||||
np.testing.assert_allclose(sout.numpy(), out.numpy())
|
||||
np.testing.assert_allclose(sout.numpy(), out.numpy(), atol=3e-7)
|
||||
|
||||
def test_softmax(self):
|
||||
# this is the softmax from scaled_dot_product_attention
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest, pickle, functools
|
||||
import unittest, pickle, functools, math
|
||||
import z3
|
||||
|
||||
from tinygrad.dtype import dtypes, ConstType
|
||||
@@ -29,16 +29,17 @@ class TestSymbolicPickle(unittest.TestCase):
|
||||
def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2)
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def helper_test_variable(self, v, n, m, s):
|
||||
def helper_test_variable(self, v, n, m, s, test_z3:bool=True):
|
||||
rendered, nmin, nmax = render(v)
|
||||
if isinstance(s, tuple): self.assertIn(rendered, s)
|
||||
else: self.assertEqual(rendered, s)
|
||||
self.assertEqual(nmin, n)
|
||||
self.assertEqual(nmax, m)
|
||||
solver = z3.Solver()
|
||||
z3_sink = graph_rewrite(v.sink(v.simplify()), z3_renderer, ctx=(solver, {}))
|
||||
expr, epxr_simplified = z3_sink.src[0].arg, z3_sink.src[1].arg
|
||||
self.assertEqual(solver.check(expr != epxr_simplified), z3.unsat, "simplified expression not equal to original")
|
||||
if test_z3:
|
||||
solver = z3.Solver()
|
||||
z3_sink = graph_rewrite(v.sink(v.simplify()), z3_renderer, ctx=(solver, {}))
|
||||
expr, epxr_simplified = z3_sink.src[0].arg, z3_sink.src[1].arg
|
||||
self.assertEqual(solver.check(expr != epxr_simplified), z3.unsat, "simplified expression not equal to original")
|
||||
|
||||
def test_cmp_simple(self):
|
||||
self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
|
||||
@@ -672,6 +673,12 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(numerator, 3, 390, "(a*((a*4)+-1))")
|
||||
self.helper_test_variable((numerator//denominator)<=0, 1, 1, "True")
|
||||
|
||||
def test_const_reciprocal(self):
|
||||
a = Variable("a", 1, 10, dtypes.float)
|
||||
# TODO: bounds for reciprocal
|
||||
# TODO: should z3 work?
|
||||
self.helper_test_variable(2*(2*a).reciprocal(), -math.inf, math.inf, "(1/a)", test_z3=False)
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
MIN, MAX = 0, 10
|
||||
|
||||
@@ -941,6 +941,7 @@ renderer = PatternMatcher([
|
||||
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
||||
#(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}[={x.src[1].arg}]")),
|
||||
(UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
|
||||
(UPat(Ops.RECIP, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(1/{x.src[0].arg})")),
|
||||
(UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
|
||||
(UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
|
||||
(UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
|
||||
|
||||
@@ -467,6 +467,7 @@ sym = symbolic_flat+PatternMatcher([
|
||||
if any(x.op in REMOVE_FROM_SINK for x in root.src) else None),
|
||||
((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
|
||||
((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
|
||||
((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)
|
||||
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
|
||||
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
|
||||
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
|
||||
|
||||
Reference in New Issue
Block a user