mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
move COMMUTATIVE flipping to symbolic (#7507)
* move COMMUTATIVE flipping to symbolic it cannot go with TRANSCENDENTAL * skip LLVM
This commit is contained in:
@@ -40,6 +40,17 @@ class TestTranscendentalMath(unittest.TestCase):
|
||||
op[1](np.array([x], dtype=_to_np_dtype(dtypes.float16))),
|
||||
atol=1e-2, rtol=5e-3) # exp can have bigger rtol
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT=="LLVM", "FIXME: LLVM might change computer")
|
||||
@given(strat.sampled_from([(dtypes.float64, 709.5), (dtypes.float32, 88.7), (dtypes.float16, 11)]))
|
||||
def test_exp_near_inf(self, dtype_x):
|
||||
# reordering compute might give inf result
|
||||
dtype, x = dtype_x
|
||||
if not is_dtype_supported(dtype): return
|
||||
with Context(TRANSCENDENTAL=2):
|
||||
y = Tensor([x], dtype=dtype).exp().numpy()
|
||||
expected = np.exp(np.array([x], dtype=_to_np_dtype(dtype)))
|
||||
np.testing.assert_allclose(y, expected, rtol=5e-3)
|
||||
|
||||
class TestTranscendentalSchedule(unittest.TestCase):
|
||||
# w/ payne_hanek_reduction (fp32)
|
||||
def test_transcendental_sin_fusion(self):
|
||||
|
||||
@@ -1039,8 +1039,6 @@ symbolic_simple = PatternMatcher([
|
||||
# ** constant folding **
|
||||
(Pat(Ops.ALU, name="root", src=Pat((Ops.VCONST, Ops.CONST))),
|
||||
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
|
||||
# ** COMMUTATIVE flipping **
|
||||
*[(Pat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(Pat.var('x', dtype=dtypes.bool) * Pat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
|
||||
(Pat.var('x', dtype=dtypes.bool) + Pat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
|
||||
@@ -1051,6 +1049,8 @@ symbolic_simple = PatternMatcher([
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+PatternMatcher([
|
||||
# ** COMMUTATIVE flipping **
|
||||
*[(Pat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
|
||||
# group like
|
||||
((Pat.var("x") + Pat.var("y")) + Pat.var("x") * Pat.cvar("c"), lambda x,y,c: (x+x*c)+y),
|
||||
# ** combine terms **
|
||||
|
||||
Reference in New Issue
Block a user