From 3b777a9e05925916bd2f0e71fa76d2d2fb646400 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 5 Aug 2025 15:33:26 -0700 Subject: [PATCH] optimize in rewrite (#11516) * changes * fix test uops * dim shouldn't be 0 * huh, why did that one not save --- test/test_tensor.py | 6 +++--- tinygrad/codegen/__init__.py | 5 +++++ tinygrad/engine/realize.py | 7 +------ tinygrad/opt/__init__.py | 12 ++++++++++-- tinygrad/uop/ops.py | 2 +- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/test/test_tensor.py b/test/test_tensor.py index f71844b233..21a868838f 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -892,13 +892,13 @@ class TestIdxUpcast(unittest.TestCase): @unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported") def test_overflow_sym(self): - self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32)) + self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32)) def test_regular(self): self.do_op_then_assert(dtypes.int, 64, 64, 64) def test_regular_sym(self): - self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 0, 64).bind(32)) + self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 1, 64).bind(32)) @unittest.skipIf(PTX, "PTX always convert Ops.INDEX to int64") def test_symfold(self): @@ -910,7 +910,7 @@ class TestIdxUpcast(unittest.TestCase): @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported") def test_int64_unsupported_overflow_sym(self): with self.assertRaises(KeyError): - self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32)) + self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32)) @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported") def test_int64_unsupported_overflow(self): diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 24456a2a4d..9ee02e682b 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -16,6 +16,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin ReduceContext, correct_load_store, pm_render from tinygrad.codegen.optional import get_late_rewrite_patterns from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext +from tinygrad.opt import pm_optimize @dataclass class RewriteStep: @@ -42,6 +43,10 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[Rewri def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]: # ** lowerer (rewrite_shapetracker_with_index) ** ret: list[RewriteStep] = [] + + # this is kernel.py + ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast")) + if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize")) ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True)) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index f4cd06c5bf..388f7d33e8 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -7,9 +7,7 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer from tinygrad.device import Device, Buffer from tinygrad.renderer import Renderer, ProgramSpec, Estimates from tinygrad.engine.schedule import ScheduleItem -from tinygrad.opt import get_optimized_ast from tinygrad.codegen import full_rewrite -from tinygrad.uop.spec import type_verify # **************** Program Creation **************** @@ -27,16 +25,13 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec: """ if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST") - modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast - if __debug__: type_verify(list(modified_ast.toposort())) # linearize try: - uops = full_rewrite(modified_ast, renderer) + uops = full_rewrite(ast, renderer) except RuntimeError: print("***** LINEARIZE FAILURE *****") print(f"ast = {ast}") - print(f"opts = {modified_ast.arg.applied_opts}") raise assert uops[-1].op is Ops.SINK, "last uop must be sink" diff --git a/tinygrad/opt/__init__.py b/tinygrad/opt/__init__.py index 934b9f0749..6128b21a5b 100644 --- a/tinygrad/opt/__init__.py +++ b/tinygrad/opt/__init__.py @@ -2,9 +2,10 @@ from tinygrad.opt.kernel import Kernel from tinygrad.opt.heuristic import hand_coded_optimizations -from tinygrad.uop.ops import UOp +from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv from tinygrad.renderer import Renderer +from tinygrad.uop.spec import type_verify def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp: """ @@ -27,4 +28,11 @@ def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp: kb = Kernel(ast, opts=renderer) rawbufs = bufs_from_lin(kb, allocate=False) k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))) - return k.get_optimized_ast() + ret = k.get_optimized_ast() + if __debug__: type_verify(list(ret.toposort())) + return ret + +pm_optimize = PatternMatcher([ + (UPat(Ops.SINK, name="ast"), lambda ctx,ast: + get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None), +]) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 1dfdaeb857..add5ca012f 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR]) return bound_vars.union(set([x for x in all_vars if x not in bound_var_base])) def variables(self) -> list[Variable]: - st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort() if x.op in GroupOp.Buffer] + st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW] return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg) # *** uop symbolic stuff ***