fix indexing bug with convs

* minimal difference for ONE_POOL=1

* fix indexing bug

* improve indexing debugger

* more debugger improvements

* always for reshape
This commit is contained in:
George Hotz
2025-11-07 16:45:19 -08:00
committed by GitHub
parent 6a509da7f3
commit ffb9e8396f
3 changed files with 28 additions and 7 deletions

View File

@@ -300,6 +300,12 @@ class TestRangeify(unittest.TestCase):
w2 = Tensor.empty(12, 8, 3, 3)
x.conv2d(w1).conv2d(w2).realize()
def test_resnet_conv2d(self):
  """Rangeify a resnet-style pair of chained convs (3x3 then 1x1)."""
  inp = Tensor.empty(1, 8, 32, 32)
  conv1_w = Tensor.empty(8, 8, 3, 3)   # 3x3 conv weights
  conv2_w = Tensor.empty(8, 8, 1, 1)   # 1x1 conv weights
  inp.conv2d(conv1_w).conv2d(conv2_w).realize()
def test_xception_conv2d(self):
# NOTE: this fusion is bad, it's recomputing the inner many times
x = Tensor.empty(1, 4, 32, 32)

View File

@@ -1573,6 +1573,13 @@ class TestSchedule(unittest.TestCase):
def test_conv2d(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4)
def test_conv2d_fused(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4)
def test_resnet_conv2d(self):
  """A 3x3 conv chained into a 1x1 conv should schedule as exactly 2 kernels."""
  inp = Tensor.empty(1, 8, 32, 32)
  first_w = Tensor.empty(8, 8, 3, 3)
  second_w = Tensor.empty(8, 8, 1, 1)
  result = inp.conv2d(first_w).conv2d(second_w)
  check_schedule(result, 2)
@unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong")
def test_conv2d_half(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4, dtype=dtypes.half)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")

View File

@@ -239,7 +239,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# if the EXPAND is used to inject a range, we don't mark it as ending_ranges. otherwise we do.
# NOTE: this doesn't actually always end a range, but this is why convs are realized, so for now we need it
if x.op is Ops.EXPAND and all(isinstance(y, int) or y.op is not Ops.RANGE for y in x.shape):
ending_ranges[x] = list(UOp.sink(*[ro for ri, ro in zip(rngs, out_rngs) if ri is not ro]).ranges.keys())
ending_ranges[x] += list(UOp.sink(*[ro for ri, ro in zip(rngs, out_rngs) if ri is not ro]).ranges.keys())
# REDUCE_AXIS creates ranges for the axes it is reducing
if x.op is Ops.REDUCE_AXIS:
@@ -247,15 +247,23 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if debug:
realized_ranges = rctx.realize_map.get(x, None)
disp = []
for i, (ri, ro) in enumerate(zip([r.render() for r in rngs], [r.render() for r in out_rngs])):
rng = f"{ri}" if ri == ro else f"{ri} -> {ro}"
if realized_ranges is not None and i in realized_ranges: rng = colored(rng, "yellow")
disp.append("["+rng+"]")
print("***" if x in rctx.realize_map else " ", len(consumer_map[x]), f"{str(x.op):20s}", ''.join(disp))
if x.op is Ops.RESHAPE or len(rngs) != len(out_rngs):
disp = render_ranges(rngs, realized=realized_ranges) + " -> " + render_ranges(out_rngs, realized=realized_ranges)
else:
disp = render_ranges(rngs, out_rngs, realized=realized_ranges)
print("***" if x in rctx.realize_map else " ",
f"{len(consumer_map[x]):2d} {str(x.op):20s} {str(x.shape):35s} {len(ending_ranges[x]):2d}", disp)
# assign to the range map. rngs are the input ranges, out_rngs are the output ranges, from the x op.
rctx.range_map[x] = (rngs, out_rngs)
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
return tsink, rctx
def render_ranges(*rngs_list, realized) -> str:
  """Render one or more parallel lists of ranges as bracketed segments.

  Each positional argument is a list of range UOps. Position i across all the
  lists is rendered as one "[...]" segment: just the rendered range when every
  list agrees at that position, otherwise the renderings joined with " -> ".
  Positions whose index appears in `realized` are highlighted yellow.
  """
  rendered = [[r.render() for r in rngs] for rngs in rngs_list]
  segments = []
  for idx, parts in enumerate(zip(*rendered)):
    if all_same(parts):
      seg = parts[0]
    else:
      seg = " -> ".join(parts)
    # highlight realized axes so they stand out in the debug dump
    if realized is not None and idx in realized:
      seg = colored(seg, "yellow")
    segments.append(f"[{seg}]")
  return ''.join(segments)