move pow folding tests to test_schedule [pr] (#8955)

not really belongs to test_const_folding
This commit is contained in:
chenyu
2025-02-07 12:51:43 -05:00
committed by GitHub
parent c2b4c43edb
commit cfd28517df
2 changed files with 26 additions and 35 deletions

View File

@@ -13,7 +13,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views, GroupOp
from tinygrad.spec import type_verify, shape_spec
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
@@ -559,11 +559,30 @@ class TestSchedule(unittest.TestCase):
out = x.to('python')
check_schedule(out, 0, filter_sink=False)
def test_pow_const_tensor_simplified(self):
x = Tensor([1,2,3,4])
# NOTE: this does not test ** Tensor(2) is simpler in ast than ** Tensor(2.5)
out = x ** Tensor(2.0)
check_schedule(out, 1)
def _alu_from_tensor(self, t:Tensor):
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
self.assertEqual(len(s), 1)
return [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
def test_2_pow_is_exp2(self):
t = 2.0 ** Tensor([1.0, 2.0, 3.0])
self.assertEqual(self._alu_from_tensor(t), [Ops.EXP2])
def test_pow_05_is_sqrt(self):
t = Tensor([1.0, 2.0, 3.0]) ** 0.5
self.assertEqual(self._alu_from_tensor(t), [Ops.SQRT])
def test_pow_neg_05_is_rsqrt(self):
t = Tensor([1.0, 2.0, 3.0]) ** -0.5
self.assertEqual(self._alu_from_tensor(t), [Ops.RECIP, Ops.SQRT])
def test_pow_2_has_1_mul(self):
t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0)
self.assertEqual(self._alu_from_tensor(t), [Ops.MUL])
def test_pow_8_has_3_muls(self):
t = Tensor([1.0, 2.0, 3.0]) ** 8
self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL])
def test_pow_const_tensor_to_zero(self):
x = Tensor([1,2,3,4])