Revert "optimize in rewrite (#11516)" (#11517)

This reverts commit 3b777a9e05.
Author: George Hotz
Date: 2025-08-05 15:39:07 -07:00
Committed by: GitHub
Parent: 3b777a9e05
Commit: 4dabdf7c6d
5 changed files with 12 additions and 20 deletions
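
In short: #11516 had folded kernel optimization into the codegen rewrite pipeline as a pm_optimize PatternMatcher step registered in get_rewrites_for_renderer. This revert removes that step and its import, restores the explicit get_optimized_ast call inside get_program (re-adding the type_verify debug check and the applied-opts failure print), and rolls back the matching changes to the TestIdxUpcast variable bounds and to UOp.variables().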

[File 1 of 5]

@@ -892,13 +892,13 @@ class TestIdxUpcast(unittest.TestCase):
   @unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported")
   def test_overflow_sym(self):
-    self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
+    self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
   def test_regular(self):
     self.do_op_then_assert(dtypes.int, 64, 64, 64)
   def test_regular_sym(self):
-    self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 1, 64).bind(32))
+    self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 0, 64).bind(32))
   @unittest.skipIf(PTX, "PTX always convert Ops.INDEX to int64")
   def test_symfold(self):
@@ -910,7 +910,7 @@ class TestIdxUpcast(unittest.TestCase):
   @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
   def test_int64_unsupported_overflow_sym(self):
     with self.assertRaises(KeyError):
-      self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
+      self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
   @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
   def test_int64_unsupported_overflow(self):
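
The only delta in these tests is the Variable's lower bound (1 before the revert, 0 after), which sets the symbolic index range do_op_then_assert must prove safe against int32 overflow. A minimal sketch of the Variable API the tests exercise, assuming only the names visible in this diff:

    from tinygrad.uop.ops import UOp

    # a symbolic dimension ranging over [0, 2048]; at this size a flat index can
    # exceed int32, so index math must be upcast to int64 (what these tests assert)
    dim3 = UOp.variable("dim3", 0, 2048)   # DEFINE_VAR named "dim3"
    bound = dim3.bind(32)                  # bound to the concrete value 32 for this run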

[File 2 of 5]

@@ -16,7 +16,6 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing,
   ReduceContext, correct_load_store, pm_render
 from tinygrad.codegen.optional import get_late_rewrite_patterns
 from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
-from tinygrad.opt import pm_optimize
 @dataclass
 class RewriteStep:
@@ -43,10 +42,6 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]:
 def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
   # ** lowerer (rewrite_shapetracker_with_index) **
   ret: list[RewriteStep] = []
-  # this is kernel.py
-  ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
   if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
   ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
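
This hunk is the heart of the revert: the "optimize ast" step no longer runs as part of the rewrite pipeline. For orientation, a sketch of the RewriteStep shape the calls above imply; only the four fields actually used in this diff are grounded, the types and defaults are assumptions:

    from dataclasses import dataclass
    from typing import Callable, Optional
    from tinygrad.uop.ops import PatternMatcher

    @dataclass
    class RewriteStep:
      pm: PatternMatcher               # the PatternMatcher this step applies
      ctx: Optional[Callable] = None   # context factory, e.g. ctx=lambda _: opts
      name: Optional[str] = None       # debug/VIZ label, e.g. "lowerer"
      bottom_up: bool = False          # traversal order, e.g. bottom_up=True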

[File 3 of 5]

@@ -7,7 +7,9 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
 from tinygrad.engine.schedule import ScheduleItem
+from tinygrad.opt import get_optimized_ast
 from tinygrad.codegen import full_rewrite
+from tinygrad.uop.spec import type_verify
 # **************** Program Creation ****************
@@ -25,13 +27,16 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
"""
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast
if __debug__: type_verify(list(modified_ast.toposort()))
# linearize
try:
uops = full_rewrite(ast, renderer)
uops = full_rewrite(modified_ast, renderer)
except RuntimeError:
print("***** LINEARIZE FAILURE *****")
print(f"ast = {ast}")
print(f"opts = {modified_ast.arg.applied_opts}")
raise
assert uops[-1].op is Ops.SINK, "last uop must be sink"
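
Post-revert, get_program owns optimization again: it computes modified_ast up front, and both type_verify and the linearizer run on that graph. Condensed from the '+' lines above:

    # optimize unless the AST is already optimized (arg set, nothing left to apply)
    modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast
    if __debug__: type_verify(list(modified_ast.toposort()))  # skipped under python -O
    uops = full_rewrite(modified_ast, renderer)  # linearize the optimized AST, not the raw one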

[File 4 of 5]

@@ -2,10 +2,9 @@
 from tinygrad.opt.kernel import Kernel
 from tinygrad.opt.heuristic import hand_coded_optimizations
-from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops
+from tinygrad.uop.ops import UOp
 from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv
 from tinygrad.renderer import Renderer
-from tinygrad.uop.spec import type_verify
 def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
   """
@@ -28,11 +27,4 @@ def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
       kb = Kernel(ast, opts=renderer)
       rawbufs = bufs_from_lin(kb, allocate=False)
       k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-  ret = k.get_optimized_ast()
-  if __debug__: type_verify(list(ret.toposort()))
-  return ret
-pm_optimize = PatternMatcher([
-  (UPat(Ops.SINK, name="ast"), lambda ctx,ast:
-    get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None),
-])
+  return k.get_optimized_ast()
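
The deleted pm_optimize shows the mechanism #11516 relied on: a PatternMatcher pairs a UPat with a function returning a replacement UOp, or None to leave the node untouched. A minimal sketch of that pattern API, assuming only names from this diff (the no-op hook itself is hypothetical):

    from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp

    # hypothetical hook: fires on every SINK, rewrites nothing (None = no rewrite).
    # ctx is whatever the rewrite was invoked with; for pm_optimize it was the Renderer.
    def _on_sink(ctx, ast:UOp): return None

    pm_example = PatternMatcher([(UPat(Ops.SINK, name="ast"), _on_sink)])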

[File 5 of 5]

@@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
     return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
   def variables(self) -> list[Variable]:
-    st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
+    st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort() if x.op in GroupOp.Buffer]
     return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
 # *** uop symbolic stuff ***
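
The restored variables() gathers ShapeTracker Variables from buffer ops (st_arg on GroupOp.Buffer) rather than from VIEW ops, then unions them with the graph's DEFINE_VARs, unbinding any bound ones. A toy illustration of the bound vs. unbound forms this method reconciles, using only the API shown above:

    from tinygrad.uop.ops import UOp

    v = UOp.variable("dim3", 0, 64)   # unbound: a DEFINE_VAR with range [0, 64]
    b = v.bind(32)                    # bound: pairs the var with the value 32
    var, val = b.unbind()             # back to (v, 32); variables() keeps var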