remove more tags stuff (#14927)

* remove more tags stuff

* remove more

* unique consts aren't needed post tensor
This commit is contained in:
George Hotz
2026-02-21 12:51:53 +08:00
committed by GitHub
parent 0c0d07d330
commit 6533250246
4 changed files with 19 additions and 64 deletions

View File

@@ -81,6 +81,13 @@ pm_early_transform_tensor_graph = PatternMatcher([
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# handle size 0
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
# early fixup const copy (TODO: is this wrong if there's a pad?)
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
])
pm_remove_unique_consts = PatternMatcher([
  # drop the UNIQUE src from a CONST, leaving only the DEVICE: equal consts on the
  # same device then hash/compare equal, so they dedupe to a single UOp
  # (NOTE: despite the old comment, this does NOT substitute a LUNIQUE — it removes the src entirely)
  (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
])
def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
@@ -108,5 +115,5 @@ def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
replace_uop = s
while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0]
buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape)
big_sink = graph_rewrite(big_sink, _remove_all_tags, name="remove tags")
big_sink = graph_rewrite(big_sink, _remove_all_tags+pm_remove_unique_consts, name="remove tags")
return big_sink, buffer_map

View File

@@ -73,13 +73,6 @@ def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], li
ctx[2][0] += 1
return ret
def replace_input_const(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
  """Normalize a CONST for the schedule-cache key.

  Swaps the CONST's UNIQUE src for a LUNIQUE numbered by the running counter in
  ctx[3][0], memoizing the rewrite in ctx[0] so repeated sightings of the same
  CONST map to the same normalized UOp.
  """
  cache, counter = ctx[0], ctx[3]
  cached = cache.get(b)
  if cached is not None: return cached
  # first sighting: mint a LUNIQUE with the next local id, then bump the counter
  normalized = b.replace(src=(UOp(Ops.LUNIQUE, arg=counter[0]), b.src[1]))
  counter[0] += 1
  cache[b] = normalized
  return normalized
def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
var, val = b.src[0], b.src[1].arg
assert var.expr not in ctx[1] or ctx[1][var.expr] == val, f"bind mismatch on {var}, {ctx[1][var.expr]} != {val}"
@@ -89,8 +82,6 @@ def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]],
# normalize the graph into a schedule-cache key: strip identities (UNIQUE ids,
# BIND values) so graphs that differ only in those hit the same cache entry
# NOTE(review): this diff rendering interleaves removed and kept lines; the
# CONST entry below appears to be the one deleted by this commit
pm_pre_sched_cache = PatternMatcher([
  # replace BUFFER with PARAM for cache key normalization
  (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
  # replace UNIQUE with LUNIQUE for CONST cache key normalization
  (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_const),
  # strip value from BIND for cache key normalization, so different values hit same cache
  (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind),
])
@@ -102,8 +93,6 @@ def create_new_buffer(ctx:dict[UOp, UOp], b:UOp):
pm_post_sched_cache = PatternMatcher([
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
# restore CONST back to original CONST
(UPat(Ops.CONST, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)),
# restore PARAM back to original BUFFER
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)),
# restore BIND value stripped in pm_pre_sched_cache
@@ -128,13 +117,12 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
if SPEC: type_verify(big_sink, tensor_spec)
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")
big_sink = get_rangeify(big_sink_cache)
pre_schedule, buf_uops_sink = create_schedule(big_sink)
pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache))
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
else:
# schedule cache hit
pre_schedule, buf_uops_sink = sc_ret
del big_sink_cache
del big_sink, big_sink_cache
# replace all the PARAMs/LUNIQUEs back (single graph_rewrite for everything)
input_buffers_inverse = {v:k for k,v in input_buffers.items()}

View File

@@ -72,7 +72,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
# None in the device assigns it a number later
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts)
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
new_srcs.append(new_src)
# NOTE: do we need this?
@@ -88,7 +88,7 @@ def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
  # lower a REDUCE_AXIS into a REDUCE whose reduced ranges are explicit srcs
  # input ranges: keep only the ranges whose axis index is listed in x.arg[1]
  new_ranges = [r for i,r in enumerate(ctx.range_map[x][0]) if i in x.arg[1]]
  # NOTE(review): the next two assignments are the old/new versions of one line
  # from the diff rendering; the second (tag dropped) shadows the first
  ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0], tag=x.tag)
  ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0])
  # the new REDUCE takes over the original node's range-map entry
  ctx.range_map[ret] = ctx.range_map[x]
  return ret

View File

@@ -33,20 +33,6 @@ pm_mops = PatternMatcher([
# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff
def assign_to_contiguous(assign:UOp, target:UOp, src:UOp):
  # only a PARAM base (or an MSTACK whose srcs are all PARAMs) is a real assignable
  # buffer; in that case leave the ASSIGN alone
  if (t := target.base).op is Ops.PARAM or (t.op is Ops.MSTACK and all(s.op is Ops.PARAM for s in t.src)): return None
  # partial view of unrealized graph: insert CONTIGUOUS at base to realize it
  if target is not t and target.op_in_backward_slice_with_self(Ops.SHRINK):
    if t.op is Ops.CONTIGUOUS: return None  # base already realized, nothing to do
    # peel the movement ops off the target, remembering them outermost-first
    mops: list[UOp] = []
    while target.op in GroupOp.Movement:
      mops.append(target)
      target = target.src[0]
    # realize the base, then replay the movement ops on top of the realized base
    new_target = t.f(Ops.CONTIGUOUS)
    for m in reversed(mops): new_target = m.replace(src=(new_target,)+m.src[1:])
    return assign.replace(src=(new_target, src))
  # not assigning to a buffer: demote the ASSIGN to a CONTIGUOUS of the source
  return src.f(Ops.CONTIGUOUS)
def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
@@ -86,7 +72,7 @@ def split_reduceop(reduce:UOp, x:UOp):
return splitted.r(*reduce.arg).contiguous().r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
mop_cleanup = PatternMatcher([
  # merge adjacent RESHAPEs: reshape(reshape(x, a), b) -> reshape(x, b)
  (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),
])
@@ -108,9 +94,6 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# resolve calls
(UPat(Ops.CALL, name="c"), resolve_call),
# remove CONTIGUOUS if the source is already contiguous
(UPat(Ops.RESHAPE, src=(UPat((Ops.PARAM, Ops.CONTIGUOUS)), UPat()), name="r").f(Ops.CONTIGUOUS), lambda r: r),
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
@@ -124,11 +107,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# ** copy rules **
# early fixup const copy
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
# COPY and source size need to match
# TODO: expand after copy creates issues with tagging
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None),
@@ -141,17 +120,14 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat()), name="src"))), lambda target, src: src),
# move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"),
lambda assign, target, src: target.assign(src.bitcast(target.dtype))),
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
lambda target, src: target.assign(src.bitcast(target.dtype))),
# if assign target is itself an ASSIGN chain, canonicalize to the original buffer target
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
# assign only to buffer, otherwise make it a CONTIGUOUS
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.PARAM}, name="target"), UPat(name="src")), name="assign"), assign_to_contiguous),
# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
])
# *****************
@@ -396,7 +372,6 @@ class LocalAddBufferContext:
map:dict = field(default_factory=dict)
vars:dict = field(default_factory=dict)
range:int = 0
parent_tags:list = field(default_factory=list)
opts:tuple|None = None
def debuf(ctx:LocalAddBufferContext, buf:UOp):
@@ -458,12 +433,6 @@ rangeify_codegen = PatternMatcher([
# TODO: this can be moved into codegen?
(UPat(Ops.NOOP, name="x"), lambda x: x.src[0]),
# add loads to non ptr indexes
# TODO: this can be moved into codegen?
#(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))
# .f(Ops.INDEX, name="idx", allow_any_len=True),
# lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
# fix broadcast dtype
(UPat(Ops.AFTER, name="a").broadcast(name="b"), lambda a,b: a.broadcast(len(b.src))),
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
@@ -475,15 +444,6 @@ rangeify_codegen = PatternMatcher([
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
])
def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp):
  """Strip the metadata tag from x, recording tuple tags on the context.

  Returns None (no rewrite) when x carries no tag or an empty tuple tag;
  otherwise returns x with its tag cleared, after appending any tuple tag's
  entries to ctx.parent_tags.
  """
  tag = x.tag
  if tag is None or tag == (): return None
  # tuple tags hold metadata entries worth remembering for the parent kernel
  if isinstance(tag, tuple): ctx.parent_tags.extend(tag)
  return x.replace(tag=None)
pm_remove_tags = PatternMatcher([
  # strip the metadata tag from every op, collecting tuple tags via remove_metadata_tags
  (UPat(GroupOp.All, name="x"), remove_metadata_tags),
])
pm_add_range_tags = PatternMatcher([
  # give every RANGE an empty tag tuple via rtag (presumably so later passes can
  # distinguish seen/processed ranges — confirm against rtag's definition)
  (UPat(Ops.RANGE, name="x"), lambda x: x.rtag(())),
])
@@ -494,7 +454,7 @@ def split_store(x:UOp) -> UOp|None:
# local kernel rewrite
lctx = LocalAddBufferContext()
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True)
# SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops
if ret.op is Ops.STORE: stored = ret.src[1]
@@ -522,7 +482,7 @@ def get_rangeify(sink:UOp) -> UOp:
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Rangeify")
# bufferize -> store
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1