diff --git a/test/test_tensor.py b/test/test_tensor.py
index 21a868838f..f71844b233 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -892,13 +892,13 @@ class TestIdxUpcast(unittest.TestCase):
 
   @unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported")
   def test_overflow_sym(self):
-    self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
+    self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
 
   def test_regular(self):
     self.do_op_then_assert(dtypes.int, 64, 64, 64)
 
   def test_regular_sym(self):
-    self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 1, 64).bind(32))
+    self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 0, 64).bind(32))
 
   @unittest.skipIf(PTX, "PTX always convert Ops.INDEX to int64")
   def test_symfold(self):
@@ -910,7 +910,7 @@ class TestIdxUpcast(unittest.TestCase):
   @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
   def test_int64_unsupported_overflow_sym(self):
     with self.assertRaises(KeyError):
-      self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
+      self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
 
   @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
   def test_int64_unsupported_overflow(self):
diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py
index 9ee02e682b..24456a2a4d 100644
--- a/tinygrad/codegen/__init__.py
+++ b/tinygrad/codegen/__init__.py
@@ -16,7 +16,6 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
   ReduceContext, correct_load_store, pm_render
 from tinygrad.codegen.optional import get_late_rewrite_patterns
 from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
-from tinygrad.opt import pm_optimize
 
 @dataclass
 class RewriteStep:
@@ -43,10 +42,6 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[Rewri
 
 def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
   # ** lowerer (rewrite_shapetracker_with_index) **
   ret: list[RewriteStep] = []
-
-  # this is kernel.py
-  ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
-
   if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
   ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 388f7d33e8..f4cd06c5bf 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -7,7 +7,9 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
 from tinygrad.engine.schedule import ScheduleItem
+from tinygrad.opt import get_optimized_ast
 from tinygrad.codegen import full_rewrite
+from tinygrad.uop.spec import type_verify
 
 # **************** Program Creation ****************
 
@@ -25,13 +27,16 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
   """
   if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
 
+  modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast
+  if __debug__: type_verify(list(modified_ast.toposort()))
+
   # linearize
   try:
-    uops = full_rewrite(ast, renderer)
+    uops = full_rewrite(modified_ast, renderer)
   except RuntimeError:
     print("***** LINEARIZE FAILURE *****")
     print(f"ast = {ast}")
+    print(f"opts = {modified_ast.arg.applied_opts}")
     raise
 
   assert uops[-1].op is Ops.SINK, "last uop must be sink"
diff --git a/tinygrad/opt/__init__.py b/tinygrad/opt/__init__.py
index 6128b21a5b..934b9f0749 100644
--- a/tinygrad/opt/__init__.py
+++ b/tinygrad/opt/__init__.py
@@ -2,10 +2,9 @@
 from tinygrad.opt.kernel import Kernel
 from tinygrad.opt.heuristic import hand_coded_optimizations
-from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops
+from tinygrad.uop.ops import UOp
 from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv
 from tinygrad.renderer import Renderer
-from tinygrad.uop.spec import type_verify
 
 def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
   """
@@ -28,11 +27,4 @@ def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
     kb = Kernel(ast, opts=renderer)
     rawbufs = bufs_from_lin(kb, allocate=False)
     k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-  ret = k.get_optimized_ast()
-  if __debug__: type_verify(list(ret.toposort()))
-  return ret
-
-pm_optimize = PatternMatcher([
-  (UPat(Ops.SINK, name="ast"), lambda ctx,ast:
-    get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None),
-])
+  return k.get_optimized_ast()
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index add5ca012f..1dfdaeb857 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
     return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
   def variables(self) -> list[Variable]:
-    st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
+    st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort() if x.op in GroupOp.Buffer]
     return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
 
 # *** uop symbolic stuff ***