Mirror of https://github.com/tinygrad/tinygrad.git, last synced 2026-04-29 03:00:14 -04:00.
Start multioutput support
This commit is contained in:
@@ -2414,5 +2414,13 @@ class TestUOpBecome(unittest.TestCase):
|
||||
b.shrink(((0,4),)).assign(a_view).realize()
|
||||
self.assertListEqual(b.tolist(), [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
|
||||
class TestScheduleMultioutput(unittest.TestCase):
  """Scheduler tests for kernels that produce more than one output buffer."""
  def test_simplest_multioutput(self):
    # One realized source tensor; two consumers share the same reduction.
    src = Tensor.ones(256, 256).contiguous().realize()
    reduced = src.sum(axis=1)
    plus_one = reduced + 1
    plus_two = reduced + 2
    # Both outputs are expected to land in a single scheduled kernel.
    check_schedule([plus_one, plus_two], 1)
|
||||
|
||||
# Allow running this test file directly; verbosity=2 prints each test name as it runs.
if __name__ == '__main__':
  unittest.main(verbosity=2)
|
||||
|
||||
@@ -171,6 +171,7 @@ SPEC = ContextVar("SPEC", 0)
|
||||
# Global feature flags for the scheduler/codegen pipeline.
# NOTE(review): ContextVar is project-defined; presumably these defaults can be
# overridden via environment variables of the same name — confirm in helpers.py.
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify
REAL_SUBSTITUTE = ContextVar("REAL_SUBSTITUTE", 0)
DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0)  # enables debug output in the rangeify pass (see run_rangeify's debug flag)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Metadata:
|
||||
|
||||
@@ -141,6 +141,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
|
||||
|
||||
@profile_matches
|
||||
def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
if debug: print("**************************")
|
||||
rctx = IndexingContext()
|
||||
|
||||
# get ops to realize
|
||||
@@ -252,5 +253,27 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
# assign to the range map. rngs are the input ranges, out_rngs are the output ranges, from the x op.
|
||||
rctx.range_map[x] = (rngs, out_rngs)
|
||||
|
||||
# second forward pass to fuse children
|
||||
replaced_ranges = {}
|
||||
for x in tsink.toposort():
|
||||
if x not in rctx.realize_map: continue
|
||||
out_rngs = rctx.range_map[x][1]
|
||||
_realize_axis = rctx.realize_map[x]
|
||||
consumers = [rctx.range_map[u][0] for u in consumer_map[x] if u in rctx.range_map]
|
||||
if len(consumers) < 2: continue
|
||||
assert all(len(out_rngs) == len(rr) for rr in consumers)
|
||||
for i,c in enumerate(zip(*consumers)):
|
||||
out_rng = out_rngs[i]
|
||||
# check if they are all simple ranges
|
||||
if not all(y.op is Ops.RANGE and y.vmax == out_rng.vmax for y in c): continue
|
||||
for r in c: replaced_ranges[r] = out_rngs[i]
|
||||
_realize_axis.remove(i)
|
||||
if len(_realize_axis) == 0: del rctx.realize_map[x]
|
||||
else: rctx.realize_map[x] = _realize_axis
|
||||
|
||||
# do all the replaces
|
||||
for k,(v0,v1) in rctx.range_map.items():
|
||||
rctx.range_map[k] = (tuple(x.substitute(replaced_ranges) for x in v0), tuple(x.substitute(replaced_ranges) for x in v1))
|
||||
|
||||
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
|
||||
return tsink, rctx
|
||||
|
||||
@@ -5,6 +5,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
|
||||
from tinygrad.uop.symbolic import symbolic_flat
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, REAL_SUBSTITUTE
|
||||
from tinygrad.helpers import DEBUG_RANGEIFY
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
|
||||
from tinygrad.codegen.opt import Opt
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
||||
@@ -498,7 +499,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
tsink = graph_rewrite(tsink, earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
|
||||
|
||||
# convert movement ops to ranges
|
||||
tsink, rctx = run_rangeify(tsink, getenv("DEBUG_RANGEIFY", 0))
|
||||
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
|
||||
|
||||
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
|
||||
tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding
|
||||
|
||||
Reference in New Issue
Block a user