st is a cached_property on UOp [run_process_replay] (#6433)

This commit is contained in:
qazal
2024-09-10 08:30:35 +08:00
committed by GitHub
parent cf64f8bb40
commit 29e63097a0
3 changed files with 21 additions and 23 deletions

View File

@@ -15,7 +15,7 @@ from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
from tinygrad.ops import graph_rewrite
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule, get_output_st, reduceop_fusor, st_fixup, ScheduleItem
from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup, ScheduleItem
from tinygrad.engine.realize import CompiledRunner, run_schedule
from test.helpers import assert_equiv_uops, ast_const, is_dtype_supported, Context, timeit
from tinygrad.lazy import LazyBuffer, view_supported_devices
@@ -1650,8 +1650,8 @@ class TestScheduleRewrite(unittest.TestCase):
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
ast = a.schedule()[0].ast
new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {}, {})
self.assertEqual(get_output_st(new_uop, {}), ShapeTracker.from_shape((4,)).reshape((4, 1)))
new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {})
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertLess(et, 1e3)
def test_no_rewrite_elementwise(self):