start multioutput support

This commit is contained in:
George Hotz
2025-10-20 11:17:00 +08:00
parent ba593f7b98
commit aecd51f54a
4 changed files with 34 additions and 1 deletions

View File

@@ -2414,5 +2414,13 @@ class TestUOpBecome(unittest.TestCase):
b.shrink(((0,4),)).assign(a_view).realize()
self.assertListEqual(b.tolist(), [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
class TestScheduleMultioutput(unittest.TestCase):
def test_simplest_multioutput(self):
a = Tensor.ones(256, 256).contiguous().realize()
r = a.sum(axis=1)
b = r+1
c = r+2
check_schedule([b, c], 1)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -171,6 +171,7 @@ SPEC = ContextVar("SPEC", 0)
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify
REAL_SUBSTITUTE = ContextVar("REAL_SUBSTITUTE", 0)
DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0)
@dataclass(frozen=True)
class Metadata:

View File

@@ -141,6 +141,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
@profile_matches
def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if debug: print("**************************")
rctx = IndexingContext()
# get ops to realize
@@ -252,5 +253,27 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# assign to the range map. rngs are the input ranges, out_rngs are the output ranges, from the x op.
rctx.range_map[x] = (rngs, out_rngs)
# second forward pass to fuse children
replaced_ranges = {}
for x in tsink.toposort():
if x not in rctx.realize_map: continue
out_rngs = rctx.range_map[x][1]
_realize_axis = rctx.realize_map[x]
consumers = [rctx.range_map[u][0] for u in consumer_map[x] if u in rctx.range_map]
if len(consumers) < 2: continue
assert all(len(out_rngs) == len(rr) for rr in consumers)
for i,c in enumerate(zip(*consumers)):
out_rng = out_rngs[i]
# check if they are all simple ranges
if not all(y.op is Ops.RANGE and y.vmax == out_rng.vmax for y in c): continue
for r in c: replaced_ranges[r] = out_rngs[i]
_realize_axis.remove(i)
if len(_realize_axis) == 0: del rctx.realize_map[x]
else: rctx.realize_map[x] = _realize_axis
# do all the replaces
for k,(v0,v1) in rctx.range_map.items():
rctx.range_map[k] = (tuple(x.substitute(replaced_ranges) for x in v0), tuple(x.substitute(replaced_ranges) for x in v1))
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
return tsink, rctx

View File

@@ -5,6 +5,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
from tinygrad.uop.symbolic import symbolic_flat
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, REAL_SUBSTITUTE
from tinygrad.helpers import DEBUG_RANGEIFY
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
@@ -498,7 +499,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
tsink = graph_rewrite(tsink, earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
# convert movement ops to ranges
tsink, rctx = run_rangeify(tsink, getenv("DEBUG_RANGEIFY", 0))
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding