diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index bb972e767e..a802eb705e 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -753,22 +753,23 @@ def apply_rangeify(ctx, x:UOp): for s in x.src: new_src = s if s in realize_map: - new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(range_map[s]), arg=BufferizeOpts(device=s.device), tag=s.tag) - if x in range_map: new_src = new_src.index(*range_map[x]) + new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag) + if x in range_map: new_src = new_src.index(*range_map[x][0]) elif s.op is Ops.BUFFER: - new_src = new_src.index(*range_map[x]) + new_src = new_src.index(*range_map[x][0]) new_srcs.append(new_src) # NOTE: do we need this? return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None def apply_pad(ctx, x:UOp): realize_map, range_map = ctx - bigwhere: UOp = functools.reduce(operator.and_, [u.src[0] for u in range_map[x] if u.op is Ops.WHERE], UOp.const(dtypes.bool, True)) + bigwhere: UOp = functools.reduce(operator.and_, [u.src[0] for u in range_map[x][0] if u.op is Ops.WHERE], UOp.const(dtypes.bool, True)) return bigwhere.simplify().where(x.src[0], UOp.const(x.dtype, 0)) def fix_reduce_axis(ctx, x:UOp): realize_map, range_map = ctx - new_ranges = [r for i,r in enumerate(range_map[x]) if i in x.arg[1]] + # input ranges + new_ranges = [r for i,r in enumerate(range_map[x][0]) if i in x.arg[1]] ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0], tag=x.tag) range_map[ret] = range_map[x] return ret @@ -807,27 +808,44 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: realize_map[x] = None ending_ranges[x] = False + # *** these are the ranges on the output *** + if x in realize_map: # if this is in the realize_map, we create new ranges (at the output) - assert x.op not in GroupOp.Movement - rngs = [ctx.new_range(s) for s in x.shape] + #assert x.op not in GroupOp.Movement + out_rngs = [ctx.new_range(s) for s in x.shape] elif len(consumer_map[x]) == 0: continue elif len(consumer_map[x]) > 1: # if this has two consumers, we have to merge the ranges and might create new ones - all_rngs = list(zip(*[range_map[c] for c in consumer_map[x]])) - # for RANGEIFY=1, if any differ we create all new ones - all_all_same = all(all_same(x) for x in all_rngs) - if all_all_same: - # all are the same - rngs = range_map[list(consumer_map[x])[0]] - else: - # create new ranges and add to realize_map - rngs = [ctx.new_range(s) for s in x.shape] - realize_map[x] = None + all_rngs = list(zip(*[range_map[c][0] for c in consumer_map[x]])) + + rngs_valids = [] + for valid_rngs in all_rngs: + local_rngs, valids = zip(*[(r.get_idx(), r.get_valid()) for r in valid_rngs]) + # if a range has a 1 src, it's the same as UOp.const(dtypes.index, 0) + same_rngs = [x if x.op is not Ops.RANGE or resolve(x.src[0] != 1) else UOp.const(dtypes.index, 0) for x in local_rngs] + rngs_valids.append((local_rngs, valids, all_same(same_rngs))) + + all_all_same = all(same_rngs for _,_,same_rngs in rngs_valids) + out_rngs = [] + for i,(local_rngs,valids,same_rngs) in enumerate(rngs_valids): + # we compare the ranges without their valids + if all_all_same: + # the new valid is the OR of all the children valids + minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False)) + out_rngs.append(minimum_valid.where(local_rngs[0], UOp.invalid()).simplify()) + else: + out_rngs.append(ctx.new_range(x.shape[i])) + + # we have to realize here if there's new ranges + if not all_all_same: realize_map[x] = None else: # if this has one consumer, we just pass it through from the consumer - rngs = range_map[list(consumer_map[x])[0]] + out_rngs = range_map[list(consumer_map[x])[0]][0] + + # rngs is the input ranges + rngs = out_rngs[:] # handle REDUCE if x.op is Ops.REDUCE_AXIS: @@ -842,6 +860,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: rngs = [a if resolve(x==y, False) else a.const_like(0) for a,x,y in zip(rngs, x.src[0].shape, x.shape)] ending_ranges[x] = True if x.op is Ops.PAD: + rngs = rngs[:] for i,(sh,(s,e)) in enumerate(zip(x.shape, x.arg)): if s == 0 and e == 0: continue where = UOp.const(dtypes.bool, True) @@ -862,8 +881,10 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: ret.append(mish % s) # NOTE: simplify will turn this to CONST mish //= s rngs = list(UOp.sink(*ret[::-1]).simplify().src) - range_map[x] = rngs - #for k,v in range_map.items(): print("***" if k in realize_map else " ", k.op, UOp.sink().index(*v).render()) + range_map[x] = (rngs, out_rngs) + for k,(in_rng,out_rng) in range_map.items(): + print("***" if k in realize_map else " ", len(consumer_map[k]), k.op, + UOp.sink().index(*in_rng).render(), " -> ", UOp.sink().index(*out_rng).render()) tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=(realize_map,range_map), bottom_up=True, name="apply rangeify") else: # NOTE: we don't use contiguous here, contiguous is a user op diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 22941527c4..de4b9994f0 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -163,9 +163,10 @@ def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, for st,_,_,e in dev_events: if not isinstance(e, ProfilePointEvent): continue if e.name == "alloc": - events.append(struct.pack(" peak: peak = mem if e.name == "free":