mirror of https://github.com/tinygrad/tinygrad.git
fix indexing bug with convs
* minimal difference for ONE_POOL=1
* fix indexing bug
* improve indexing debugger
* more debugger improvements
* always for reshape
@@ -300,6 +300,12 @@ class TestRangeify(unittest.TestCase):
     w2 = Tensor.empty(12, 8, 3, 3)
     x.conv2d(w1).conv2d(w2).realize()
 
+  def test_resnet_conv2d(self):
+    x = Tensor.empty(1, 8, 32, 32)
+    w1 = Tensor.empty(8, 8, 3, 3)
+    w2 = Tensor.empty(8, 8, 1, 1)
+    x.conv2d(w1).conv2d(w2).realize()
+
   def test_xception_conv2d(self):
     # NOTE: this fusion is bad, it's recomputing the inner many times
     x = Tensor.empty(1, 4, 32, 32)
@@ -1573,6 +1573,13 @@ class TestSchedule(unittest.TestCase):
   def test_conv2d(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4)
   def test_conv2d_fused(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4)
 
+  def test_resnet_conv2d(self):
+    x = Tensor.empty(1, 8, 32, 32)
+    w1 = Tensor.empty(8, 8, 3, 3)
+    w2 = Tensor.empty(8, 8, 1, 1)
+    out = x.conv2d(w1).conv2d(w2)
+    check_schedule(out, 2)
+
   @unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong")
   def test_conv2d_half(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4, dtype=dtypes.half)
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
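The new TestSchedule case pins the kernel count for the resnet-style conv pair. Outside the test suite, roughly the same assertion can be written against the public Tensor.schedule() API; this is a hedged stand-in for the suite's check_schedule helper, not its actual implementation:

# hedged stand-in for check_schedule(out, 2): assert the lazy graph for `out`
# lowers to exactly two kernels. Tensor.schedule() is public tinygrad API;
# the real helper in test_schedule.py may filter schedule items differently.
from tinygrad import Tensor

x  = Tensor.empty(1, 8, 32, 32)
w1 = Tensor.empty(8, 8, 3, 3)
w2 = Tensor.empty(8, 8, 1, 1)
out = x.conv2d(w1).conv2d(w2)
sched = out.schedule()  # list of ScheduleItem, one per kernel to run
assert len(sched) == 2, f"expected 2 kernels, got {len(sched)}"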
@@ -239,7 +239,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
     # if the EXPAND is used to inject a range, we don't mark it as ending_ranges. otherwise we do.
     # NOTE: this doesn't actually always end a range, but this is why convs are realized, so for now we need it
     if x.op is Ops.EXPAND and all(isinstance(y, int) or y.op is not Ops.RANGE for y in x.shape):
-      ending_ranges[x] = list(UOp.sink(*[ro for ri, ro in zip(rngs, out_rngs) if ri is not ro]).ranges.keys())
+      ending_ranges[x] += list(UOp.sink(*[ro for ri, ro in zip(rngs, out_rngs) if ri is not ro]).ranges.keys())
 
     # REDUCE_AXIS creates ranges for the axes it is reducing
     if x.op is Ops.REDUCE_AXIS:
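The indexing fix is the single-character change above: with `=`, ending ranges already recorded for x by an earlier rule were overwritten whenever the EXPAND case fired; `+=` extends the list instead. A minimal sketch of the difference, assuming ending_ranges behaves like a defaultdict(list), with plain strings standing in for RANGE UOps:

from collections import defaultdict

ending_ranges = defaultdict(list)
ending_ranges["x"].append("r0")      # recorded by an earlier rule

ending_ranges["x"] = ["r1"]          # buggy `=`: "r0" is silently dropped
assert ending_ranges["x"] == ["r1"]

ending_ranges = defaultdict(list, {"x": ["r0"]})
ending_ranges["x"] += ["r1"]         # fixed `+=`: accumulates
assert ending_ranges["x"] == ["r0", "r1"]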
@@ -247,15 +247,23 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
 
     if debug:
       realized_ranges = rctx.realize_map.get(x, None)
-      disp = []
-      for i, (ri, ro) in enumerate(zip([r.render() for r in rngs], [r.render() for r in out_rngs])):
-        rng = f"{ri}" if ri == ro else f"{ri} -> {ro}"
-        if realized_ranges is not None and i in realized_ranges: rng = colored(rng, "yellow")
-        disp.append("["+rng+"]")
-      print("***" if x in rctx.realize_map else " ", len(consumer_map[x]), f"{str(x.op):20s}", ''.join(disp))
+      if x.op is Ops.RESHAPE or len(rngs) != len(out_rngs):
+        disp = render_ranges(rngs, realized=realized_ranges) + " -> " + render_ranges(out_rngs, realized=realized_ranges)
+      else:
+        disp = render_ranges(rngs, out_rngs, realized=realized_ranges)
+      print("***" if x in rctx.realize_map else " ",
+            f"{len(consumer_map[x]):2d} {str(x.op):20s} {str(x.shape):35s} {len(ending_ranges[x]):2d}", disp)
 
     # assign to the range map. rngs are the input ranges, out_rngs are the output ranges, from the x op.
     rctx.range_map[x] = (rngs, out_rngs)
 
   tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
   return tsink, rctx
+
+def render_ranges(*rngs_list, realized) -> str:
+  disp = []
+  for i, rs in enumerate(zip(*[[r.render() for r in rngs] for rngs in rngs_list])):
+    rng = rs[0] if all_same(rs) else " -> ".join(rs)
+    if realized is not None and i in realized: rng = colored(rng, "yellow")
+    disp.append("["+rng+"]")
+  return ''.join(disp)
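The debugger refactor pulls the range pretty-printing into the new render_ranges helper: it zips any number of range lists index by index, collapses columns whose entries are all the same, and highlights realized axes in yellow. A self-contained sketch of the output format; FakeRange and the inlined all_same/colored stand-ins are assumptions for illustration (in tinygrad the ranges are UOps and those helpers live in tinygrad.helpers):

from dataclasses import dataclass

def all_same(items): return all(x == items[0] for x in items)  # stand-in for tinygrad.helpers.all_same
def colored(st, color): return f"\033[93m{st}\033[0m"          # stand-in: always renders yellow

@dataclass(frozen=True)
class FakeRange:
  name: str
  def render(self): return self.name  # real ranges are UOps with a .render() method

# copied from the diff above
def render_ranges(*rngs_list, realized) -> str:
  disp = []
  for i, rs in enumerate(zip(*[[r.render() for r in rngs] for rngs in rngs_list])):
    rng = rs[0] if all_same(rs) else " -> ".join(rs)
    if realized is not None and i in realized: rng = colored(rng, "yellow")
    disp.append("["+rng+"]")
  return ''.join(disp)

rngs     = [FakeRange("r0"), FakeRange("r1")]
out_rngs = [FakeRange("r0"), FakeRange("r2")]
print(render_ranges(rngs, out_rngs, realized={1}))  # prints [r0][r1 -> r2], second column in yellow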