mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
bottom up earliest rewrites (#14587)
* better * bottom up earliest rewrites * fix
This commit is contained in:
@@ -5,7 +5,7 @@ from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TR
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import panic
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
||||
@@ -18,15 +18,9 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads
|
||||
from tinygrad.codegen.opt.postrange import apply_opts, make_images
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
|
||||
pm_syntactic_sugar = PatternMatcher([
|
||||
# INDEX on ptr INDEX concats them
|
||||
(UPat(Ops.INDEX, name="i1").f(Ops.INDEX, name="i2", allow_any_len=True),
|
||||
lambda i1,i2: i2.replace(src=i1.src+i2.src[1:]) if isinstance(i1.dtype, PtrDType) and not isinstance(i2.dtype, PtrDType) else None),
|
||||
])
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
if ren is None: ren = Renderer()
|
||||
|
||||
|
||||
@@ -14,6 +14,12 @@ from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTI
|
||||
import sys
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
pm_syntactic_sugar = PatternMatcher([
|
||||
# INDEX on ptr INDEX concats them
|
||||
(UPat(Ops.INDEX, name="i1").f(Ops.INDEX, name="i2", allow_any_len=True),
|
||||
lambda i1,i2: i2.replace(src=i1.src+i2.src[1:]) if isinstance(i1.dtype, PtrDType) and not isinstance(i2.dtype, PtrDType) else None),
|
||||
])
|
||||
|
||||
# movement op on INDEX as a PatternMatcher
|
||||
pm_mops = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
|
||||
@@ -564,7 +570,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
uop_list: list[UOp] = []
|
||||
tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
|
||||
|
||||
tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
|
||||
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites+replace_contiguous, ctx={}, bottom_up=True, name="earliest rewrites")
|
||||
|
||||
# convert movement ops to ranges
|
||||
tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))
|
||||
|
||||
@@ -542,6 +542,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def base(self) -> UOp:
|
||||
if self.op in GroupOp.Movement: return self.src[0].base
|
||||
if self.op is Ops.MULTI: return self.src[0].base # MULTI is really a VIEW
|
||||
if self.op is Ops.DETACH: return self.src[0].base # DETACH can't change base
|
||||
return self
|
||||
|
||||
# like gep, but might return an integer
|
||||
|
||||
Reference in New Issue
Block a user