fix rangeify elu fusion for openpilot (#12341)

* fix rangeify elu fusion for openpilot

* flip the metadata

* copy over permuted contiguous support

* this is correct

* update that
This commit is contained in:
George Hotz
2025-09-30 11:41:52 +08:00
committed by GitHub
parent d95d018bb5
commit f522e83a02
3 changed files with 32 additions and 6 deletions

View File

@@ -97,6 +97,11 @@ class TestRangeify(unittest.TestCase):
w1 = Tensor.empty(8, 4, 3, 3)
x.conv2d(w1).realize()
def test_conv2d_elu(self):
  # Smoke test: a conv2d whose output goes through elu() must realize.
  # There are no value assertions; the test passes if realize() completes.
  inp = Tensor.empty(1, 4, 32, 32)
  weight = Tensor.empty(8, 4, 3, 3)
  activated = inp.conv2d(weight).elu()
  activated.realize()
def test_conv2d_t(self):
x = Tensor.empty(1, 4, 32, 32)
w1 = Tensor.empty(8, 4, 3, 3)

View File

@@ -42,7 +42,7 @@ class TestWinograd(unittest.TestCase):
out = Tensor.conv2d(x,w, padding=1)
out.mean().backward()
backward_schedule = Tensor.schedule(x.grad, w.grad)
self.assertEqual(len(backward_schedule), 6 if RANGEIFY else 9)
self.assertEqual(len(backward_schedule), 4 if RANGEIFY else 9)
def test_counters(self):
IC, OC, X, Y = 4,4,9,9

View File

@@ -258,11 +258,16 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
end_ranges = []
idx_ranges = []
# NOTE: locals aren't working, so we only fully bufferize here (unless RANGEIFY > 1)
all_all_same = all(all_same(r) for r in all_rngs)
for i,valid_rngs in enumerate(all_rngs):
rngs_valids = []
for valid_rngs in all_rngs:
rngs, valids = zip(*[(r.get_idx(), r.get_valid()) for r in valid_rngs])
# a RANGE whose src[0] is 1 is treated the same as UOp.const(dtypes.index, 0)
same_rngs = [x if x.op is not Ops.RANGE or resolve(x.src[0] != 1) else UOp.const(dtypes.index, 0) for x in rngs]
rngs_valids.append((rngs, valids, all_same(same_rngs)))
all_all_same = all(same_rngs for _,_,same_rngs in rngs_valids)
for i,(rngs,valids,same_rngs) in enumerate(rngs_valids):
# we compare the ranges without their valids
if all_same(rngs) and (all_all_same or RANGEIFY > 1):
if same_rngs and (all_all_same or RANGEIFY > 1):
# the new valid is the OR of all the children valids
minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False))
out_rngs.append(minimum_valid.where(rngs[0], UOp.invalid()).simplify())
@@ -576,7 +581,7 @@ def split_store(ctx:list[UOp], x:UOp):
# NOTE: the hack for COPY is here
ret = ret.sink() if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None]))))
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
return x.as_buf().assign(kernel)
@@ -593,12 +598,28 @@ add_tags = PatternMatcher([
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND}.union(GroupOp.Movement), name="x"), tag_uop),
])
# support for using a contiguous permuted view instead of the parent view if one exists
# modified from kernelize.py to not use ShapeTracker
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  """Record `contig`, re-expressed through the movement ops of `src`, as a stand-in for `src.base`.

  Walks from `src` towards `src.base` by following `src[0]` at every step.
  Each op on the way is applied to `contig` in adjusted form:

  * PERMUTE: `contig` is permuted by `argsort` of the op's arg
    (presumably the inverse permutation -- confirm against `argsort`).
  * RESHAPE: `contig` is reshaped to the shape of the op's input.

  Any other op on the chain aborts the walk: the function returns None and
  `ctx` is left untouched. When the walk reaches the base, the transformed
  `contig` is stored in `ctx` under `src.base`.
  """
  base = src.base
  node = src
  while node is not base:
    if node.op is Ops.PERMUTE:
      contig = contig.permute(argsort(node.arg))
    elif node.op is Ops.RESHAPE:
      contig = contig.reshape(node.src[0].shape)
    else:
      # only PERMUTE and RESHAPE are handled; give up on anything else
      return None
    node = node.src[0]
  ctx[base] = contig
# Rewrite rules that share a dict ctx mapping a base UOp to a replacement UOp.
replace_contiguous = PatternMatcher([
# a CONTIGUOUS whose only source is a movement op: found_contiguous may record
# an entry in ctx keyed by that source's base (it returns None in every case)
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous),
# any ALU op: swap each source that has an entry in ctx for the recorded value;
# yields None when no source changed (presumably "no rewrite" -- confirm against PatternMatcher)
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
uop_list: list[UOp] = []
tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
tsink = graph_rewrite(tsink, earliest_rewrites, name="earliest rewrites")
tsink = graph_rewrite(tsink, earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
realize_map: dict[UOp, UOp] = {}
graph_rewrite(tsink, do_realize, ctx=realize_map, name="Input Graph")
# NOTE: we don't use contiguous here, contiguous is a user op