From 6533250246ad5df9318e1ca60ddebc8f7c250048 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 21 Feb 2026 12:51:53 +0800 Subject: [PATCH] remove more tags stuff (#14927) * remove more tags stuff * remove more * unique consts aren't needed post tensor --- tinygrad/engine/allocations.py | 9 +++++- tinygrad/engine/schedule.py | 16 ++-------- tinygrad/schedule/indexing.py | 4 +-- tinygrad/schedule/rangeify.py | 54 +++++----------------------------- 4 files changed, 19 insertions(+), 64 deletions(-) diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py index e5e1031bf2..c79bf0e70d 100644 --- a/tinygrad/engine/allocations.py +++ b/tinygrad/engine/allocations.py @@ -81,6 +81,13 @@ pm_early_transform_tensor_graph = PatternMatcher([ lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), # handle size 0 (UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None), + # early fixup const copy (TODO: is this wrong if there's a pad?) 
+ (UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None), +]) + +pm_remove_unique_consts = PatternMatcher([ + # drop the UNIQUE source from CONST, keeping only DEVICE (unique consts aren't needed post tensor) + (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))), ]) def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]: @@ -108,5 +115,5 @@ def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]: replace_uop = s while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0] buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape) - big_sink = graph_rewrite(big_sink, _remove_all_tags, name="remove tags") + big_sink = graph_rewrite(big_sink, _remove_all_tags+pm_remove_unique_consts, name="remove tags") return big_sink, buffer_map diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index d3fe1ab683..026de189f5 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -73,13 +73,6 @@ def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], li ctx[2][0] += 1 return ret -def replace_input_const(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp): - if (ret:=ctx[0].get(b, None)) is None: - # replace UNIQUE with LUNIQUE for CONST cache key normalization - ctx[0][b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=ctx[3][0]), b.src[1])) - ctx[3][0] += 1 - return ret - def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp): var, val = b.src[0], b.src[1].arg assert var.expr not in ctx[1] or ctx[1][var.expr] == val, f"bind mismatch on {var}, {ctx[1][var.expr]} != {val}" @@ -89,8 +82,6 @@ def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], pm_pre_sched_cache = PatternMatcher([ # replace BUFFER with PARAM for cache key normalization (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), 
name="b"), replace_input_buffer), - # replace UNIQUE with LUNIQUE for CONST cache key normalization - (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_const), # strip value from BIND for cache key normalization, so different values hit same cache (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind), ]) @@ -102,8 +93,6 @@ def create_new_buffer(ctx:dict[UOp, UOp], b:UOp): pm_post_sched_cache = PatternMatcher([ # create new BUFFERs for LUNIQUE BUFFERs from rangeify (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer), - # restore CONST back to original CONST - (UPat(Ops.CONST, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)), # restore PARAM back to original BUFFER (UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)), # restore BIND value stripped in pm_pre_sched_cache @@ -128,13 +117,12 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li # verify Tensors match the spec (on big_sink, we only need to do this if cache misses) if SPEC: type_verify(big_sink, tensor_spec) big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm") - big_sink = get_rangeify(big_sink_cache) - pre_schedule, buf_uops_sink = create_schedule(big_sink) + pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache)) if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink) else: # schedule cache hit pre_schedule, buf_uops_sink = sc_ret - del big_sink_cache + del big_sink, big_sink_cache # replace all the PARAMs/LUNIQUEs back (single graph_rewrite for everything) input_buffers_inverse = {v:k for k,v in input_buffers.items()} diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 88191258cd..15919caa5d 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -72,7 +72,7 @@ def 
create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): # None in the device assigns it a number later opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \ BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable) - new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None) + new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts) if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges]) new_srcs.append(new_src) # NOTE: do we need this? @@ -88,7 +88,7 @@ def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp): def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp): # input ranges new_ranges = [r for i,r in enumerate(ctx.range_map[x][0]) if i in x.arg[1]] - ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0], tag=x.tag) + ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0]) ctx.range_map[ret] = ctx.range_map[x] return ret diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 03fd1eba62..68a5d48878 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -33,20 +33,6 @@ pm_mops = PatternMatcher([ # ***************** # 0. 
do some cleanup rewrites, mostly copied from the old stuff -def assign_to_contiguous(assign:UOp, target:UOp, src:UOp): - if (t := target.base).op is Ops.PARAM or (t.op is Ops.MSTACK and all(s.op is Ops.PARAM for s in t.src)): return None - # partial view of unrealized graph: insert CONTIGUOUS at base to realize it - if target is not t and target.op_in_backward_slice_with_self(Ops.SHRINK): - if t.op is Ops.CONTIGUOUS: return None - mops: list[UOp] = [] - while target.op in GroupOp.Movement: - mops.append(target) - target = target.src[0] - new_target = t.f(Ops.CONTIGUOUS) - for m in reversed(mops): new_target = m.replace(src=(new_target,)+m.src[1:]) - return assign.replace(src=(new_target, src)) - return src.f(Ops.CONTIGUOUS) - def fix_assign_hazard(assign:UOp, target:UOp, src:UOp): # PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set()) @@ -86,7 +72,7 @@ def split_reduceop(reduce:UOp, x:UOp): return splitted.r(*reduce.arg).contiguous().r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape) mop_cleanup = PatternMatcher([ - # merge adjacent RESHAPES, safe because they are not tagged + # merge adjacent RESHAPES (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))), ]) @@ -108,9 +94,6 @@ earliest_rewrites = mop_cleanup+PatternMatcher([ # resolve calls (UPat(Ops.CALL, name="c"), resolve_call), - # remove CONTIGUOUS if the source is already contiguous - (UPat(Ops.RESHAPE, src=(UPat((Ops.PARAM, Ops.CONTIGUOUS)), UPat()), name="r").f(Ops.CONTIGUOUS), lambda r: r), - # split_reduceop (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop), @@ -124,11 +107,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([ # ** copy rules ** - # early fixup const copy - (UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: 
c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None), - # COPY and source size need to match - # TODO: expand after copy creates issues with tagging (UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None), @@ -141,17 +120,14 @@ earliest_rewrites = mop_cleanup+PatternMatcher([ (UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat()), name="src"))), lambda target, src: src), # move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype)) - (UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"), - lambda assign, target, src: target.assign(src.bitcast(target.dtype))), + (UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))), + lambda target, src: target.assign(src.bitcast(target.dtype))), # if assign target is itself an ASSIGN chain, canonicalize to the original buffer target (UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain), - # assign only to buffer, otherwise make it a CONTIGUOUS - (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.PARAM}, name="target"), UPat(name="src")), name="assign"), assign_to_contiguous), - - # make source contiguous if it has hazardous movement ops on the dest buffer - (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard), + # make source contiguous if it has hazardous movement ops on the dest buffer + (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard), ]) # ***************** @@ -396,7 +372,6 @@ class LocalAddBufferContext: map:dict = field(default_factory=dict) vars:dict = field(default_factory=dict) range:int = 0 - parent_tags:list = field(default_factory=list) opts:tuple|None = None def 
debuf(ctx:LocalAddBufferContext, buf:UOp): @@ -458,12 +433,6 @@ rangeify_codegen = PatternMatcher([ # TODO: this can be moved into codegen? (UPat(Ops.NOOP, name="x"), lambda x: x.src[0]), - # add loads to non ptr indexes - # TODO: this can be moved into codegen? - #(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg")) - # .f(Ops.INDEX, name="idx", allow_any_len=True), - # lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), - # fix broadcast dtype (UPat(Ops.AFTER, name="a").broadcast(name="b"), lambda a,b: a.broadcast(len(b.src))), (UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True), @@ -475,15 +444,6 @@ rangeify_codegen = PatternMatcher([ idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))), ]) -def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp): - if x.tag is None or x.tag == (): return None - if isinstance(x.tag, tuple): ctx.parent_tags += list(x.tag) - return x.replace(tag=None) - -pm_remove_tags = PatternMatcher([ - (UPat(GroupOp.All, name="x"), remove_metadata_tags), -]) - pm_add_range_tags = PatternMatcher([ (UPat(Ops.RANGE, name="x"), lambda x: x.rtag(())), ]) @@ -494,7 +454,7 @@ def split_store(x:UOp) -> UOp|None: # local kernel rewrite lctx = LocalAddBufferContext() - ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True) + ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True) # SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops if ret.op is Ops.STORE: stored = ret.src[1] @@ -522,7 +482,7 @@ def get_rangeify(sink:UOp) -> UOp: tsink = graph_rewrite(tsink, 
symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers") - if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify") + if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Rangeify") # bufferize -> store lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1