move COMMUTATIVE flipping to symbolic (#7507)

* move COMMUTATIVE flipping to symbolic

it cannot go with TRANSCENDENTAL

* skip LLVM
This commit is contained in:
chenyu
2024-11-03 09:03:45 -05:00
committed by GitHub
parent 50ea2105e5
commit 4617c9a565
2 changed files with 13 additions and 2 deletions

View File

@@ -40,6 +40,17 @@ class TestTranscendentalMath(unittest.TestCase):
op[1](np.array([x], dtype=_to_np_dtype(dtypes.float16))),
atol=1e-2, rtol=5e-3) # exp can have bigger rtol
@unittest.skipIf(Device.DEFAULT=="LLVM", "FIXME: LLVM might change computer")
@given(strat.sampled_from([(dtypes.float64, 709.5), (dtypes.float32, 88.7), (dtypes.float16, 11)]))
def test_exp_near_inf(self, dtype_x):
# reordering compute might give inf result
dtype, x = dtype_x
if not is_dtype_supported(dtype): return
with Context(TRANSCENDENTAL=2):
y = Tensor([x], dtype=dtype).exp().numpy()
expected = np.exp(np.array([x], dtype=_to_np_dtype(dtype)))
np.testing.assert_allclose(y, expected, rtol=5e-3)
class TestTranscendentalSchedule(unittest.TestCase):
# w/ payne_hanek_reduction (fp32)
def test_transcendental_sin_fusion(self):

View File

@@ -1039,8 +1039,6 @@ symbolic_simple = PatternMatcher([
# ** constant folding **
(Pat(Ops.ALU, name="root", src=Pat((Ops.VCONST, Ops.CONST))),
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
# ** COMMUTATIVE flipping **
*[(Pat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
(Pat.var('x', dtype=dtypes.bool) * Pat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
(Pat.var('x', dtype=dtypes.bool) + Pat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
@@ -1051,6 +1049,8 @@ symbolic_simple = PatternMatcher([
])
symbolic = symbolic_simple+PatternMatcher([
# ** COMMUTATIVE flipping **
*[(Pat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
# group like
((Pat.var("x") + Pat.var("y")) + Pat.var("x") * Pat.cvar("c"), lambda x,y,c: (x+x*c)+y),
# ** combine terms **