diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py index e0cda3eb20..c931a89782 100644 --- a/tinygrad/opt/kernel.py +++ b/tinygrad/opt/kernel.py @@ -4,18 +4,15 @@ from dataclasses import dataclass from collections import defaultdict from typing import Optional, cast, Final, Callable, Sequence -from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, print_uops -from tinygrad.uop.ops import PatternMatcher, smax +from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, smax from tinygrad.uop.spec import type_verify, ast_spec from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps from tinygrad.dtype import ImageDType -from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, unwrap -from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, AMX +from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, all_int, to_function_name, unwrap, DEBUG, TC_SELECT, TC_OPT, AMX from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import strides_for_shape, get_contraction from tinygrad.kernelize.kernelize import view_left -from tinygrad.codegen import full_rewrite class KernelOptError(Exception): pass @@ -530,42 +527,13 @@ class Kernel: del fixup_ast return graph_rewrite(fixed_ast, view_left, name="fixup optimized AST") - # **** this is the lowerer **** + # TODO: update the tests and delete these methods - @track_rewrites() - def linearize(self, name_override:Optional[str]=None, ast_transform:Optional[Callable]=None) -> Kernel: - # display the AST - if getenv("VIZ"): graph_rewrite(self.ast, PatternMatcher([]), name="View Base AST") - - modified_ast = self.get_optimized_ast(name_override) - if ast_transform is not None: modified_ast = ast_transform(self, modified_ast) - - if DEBUG >= 3: - print(self.name) - if DEBUG >= 5: print(self.ast) - for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]): - print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s} {str(st.real_strides()):30s}", - str(st) if DEBUG >= 4 else "") - print(self.applied_opts) - if DEBUG >= 5: print(modified_ast) - # verify AST matches the spec after applying opts - if __debug__: type_verify(list(modified_ast.toposort())) - # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this - #if __debug__: type_verify(list(modified_ast.toposort()), ast_spec) - - try: - self.uops:list[UOp] = full_rewrite(modified_ast, self.opts) - except RuntimeError: - print("***** LINEARIZE FAILURE *****") - print(f"ast = {self.ast}") - print(f"opts = {self.applied_opts}") - raise - if DEBUG >= 6: print_uops(self.uops) + def linearize(self): + self.to_program() return self - - def to_program(self, name_override:Optional[str]=None, ast_transform:Optional[Callable]=None) -> ProgramSpec: - self.linearize(name_override, ast_transform) - assert self.uops[-1].op is Ops.SINK, "last uop must be sink" - src = self.opts.render(self.uops) - return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, - global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None) + def to_program(self, name_override:Optional[str]=None) -> ProgramSpec: + from tinygrad.engine.realize import get_program + ret = get_program(self.get_optimized_ast(name_override), self.opts) + self.uops = ret.uops + return ret