mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 15:45:27 -05:00
move pow folding tests to test_schedule [pr] (#8955)
not really belongs to test_const_folding
This commit is contained in:
@@ -13,7 +13,7 @@ from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views, GroupOp
|
||||
from tinygrad.spec import type_verify, shape_spec
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
|
||||
@@ -559,11 +559,30 @@ class TestSchedule(unittest.TestCase):
|
||||
out = x.to('python')
|
||||
check_schedule(out, 0, filter_sink=False)
|
||||
|
||||
def test_pow_const_tensor_simplified(self):
|
||||
x = Tensor([1,2,3,4])
|
||||
# NOTE: this does not test ** Tensor(2) is simpler in ast than ** Tensor(2.5)
|
||||
out = x ** Tensor(2.0)
|
||||
check_schedule(out, 1)
|
||||
def _alu_from_tensor(self, t:Tensor):
|
||||
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
|
||||
self.assertEqual(len(s), 1)
|
||||
return [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
|
||||
|
||||
def test_2_pow_is_exp2(self):
|
||||
t = 2.0 ** Tensor([1.0, 2.0, 3.0])
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.EXP2])
|
||||
|
||||
def test_pow_05_is_sqrt(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** 0.5
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.SQRT])
|
||||
|
||||
def test_pow_neg_05_is_rsqrt(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** -0.5
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.RECIP, Ops.SQRT])
|
||||
|
||||
def test_pow_2_has_1_mul(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0)
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.MUL])
|
||||
|
||||
def test_pow_8_has_3_muls(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** 8
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL])
|
||||
|
||||
def test_pow_const_tensor_to_zero(self):
|
||||
x = Tensor([1,2,3,4])
|
||||
|
||||
Reference in New Issue
Block a user