fast rangeify works for mnist

This commit is contained in:
George Hotz
2025-10-08 12:58:40 +08:00
parent e35152a004
commit 53eb2af4ce
2 changed files with 44 additions and 22 deletions

View File

@@ -753,22 +753,23 @@ def apply_rangeify(ctx, x:UOp):
for s in x.src:
new_src = s
if s in realize_map:
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(range_map[s]), arg=BufferizeOpts(device=s.device), tag=s.tag)
if x in range_map: new_src = new_src.index(*range_map[x])
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag)
if x in range_map: new_src = new_src.index(*range_map[x][0])
elif s.op is Ops.BUFFER:
new_src = new_src.index(*range_map[x])
new_src = new_src.index(*range_map[x][0])
new_srcs.append(new_src)
# NOTE: do we need this?
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
def apply_pad(ctx, x:UOp):
realize_map, range_map = ctx
bigwhere: UOp = functools.reduce(operator.and_, [u.src[0] for u in range_map[x] if u.op is Ops.WHERE], UOp.const(dtypes.bool, True))
bigwhere: UOp = functools.reduce(operator.and_, [u.src[0] for u in range_map[x][0] if u.op is Ops.WHERE], UOp.const(dtypes.bool, True))
return bigwhere.simplify().where(x.src[0], UOp.const(x.dtype, 0))
def fix_reduce_axis(ctx, x:UOp):
realize_map, range_map = ctx
new_ranges = [r for i,r in enumerate(range_map[x]) if i in x.arg[1]]
# input ranges
new_ranges = [r for i,r in enumerate(range_map[x][0]) if i in x.arg[1]]
ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0], tag=x.tag)
range_map[ret] = range_map[x]
return ret
@@ -807,27 +808,44 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
realize_map[x] = None
ending_ranges[x] = False
# *** these are the ranges on the output ***
if x in realize_map:
# if this is in the realize_map, we create new ranges (at the output)
assert x.op not in GroupOp.Movement
rngs = [ctx.new_range(s) for s in x.shape]
#assert x.op not in GroupOp.Movement
out_rngs = [ctx.new_range(s) for s in x.shape]
elif len(consumer_map[x]) == 0:
continue
elif len(consumer_map[x]) > 1:
# if this has two consumers, we have to merge the ranges and might create new ones
all_rngs = list(zip(*[range_map[c] for c in consumer_map[x]]))
# for RANGEIFY=1, if any differ we create all new ones
all_all_same = all(all_same(x) for x in all_rngs)
if all_all_same:
# all are the same
rngs = range_map[list(consumer_map[x])[0]]
else:
# create new ranges and add to realize_map
rngs = [ctx.new_range(s) for s in x.shape]
realize_map[x] = None
all_rngs = list(zip(*[range_map[c][0] for c in consumer_map[x]]))
rngs_valids = []
for valid_rngs in all_rngs:
local_rngs, valids = zip(*[(r.get_idx(), r.get_valid()) for r in valid_rngs])
# if a range has a 1 src, it's the same as UOp.const(dtypes.index, 0)
same_rngs = [x if x.op is not Ops.RANGE or resolve(x.src[0] != 1) else UOp.const(dtypes.index, 0) for x in local_rngs]
rngs_valids.append((local_rngs, valids, all_same(same_rngs)))
all_all_same = all(same_rngs for _,_,same_rngs in rngs_valids)
out_rngs = []
for i,(local_rngs,valids,same_rngs) in enumerate(rngs_valids):
# we compare the ranges without their valids
if all_all_same:
# the new valid is the OR of all the children valids
minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False))
out_rngs.append(minimum_valid.where(local_rngs[0], UOp.invalid()).simplify())
else:
out_rngs.append(ctx.new_range(x.shape[i]))
# we have to realize here if there's new ranges
if not all_all_same: realize_map[x] = None
else:
# if this has one consumer, we just pass it through from the consumer
rngs = range_map[list(consumer_map[x])[0]]
out_rngs = range_map[list(consumer_map[x])[0]][0]
# rngs is the input ranges
rngs = out_rngs[:]
# handle REDUCE
if x.op is Ops.REDUCE_AXIS:
@@ -842,6 +860,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
rngs = [a if resolve(x==y, False) else a.const_like(0) for a,x,y in zip(rngs, x.src[0].shape, x.shape)]
ending_ranges[x] = True
if x.op is Ops.PAD:
rngs = rngs[:]
for i,(sh,(s,e)) in enumerate(zip(x.shape, x.arg)):
if s == 0 and e == 0: continue
where = UOp.const(dtypes.bool, True)
@@ -862,8 +881,10 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
ret.append(mish % s) # NOTE: simplify will turn this to CONST
mish //= s
rngs = list(UOp.sink(*ret[::-1]).simplify().src)
range_map[x] = rngs
#for k,v in range_map.items(): print("***" if k in realize_map else " ", k.op, UOp.sink().index(*v).render())
range_map[x] = (rngs, out_rngs)
for k,(in_rng,out_rng) in range_map.items():
print("***" if k in realize_map else " ", len(consumer_map[k]), k.op,
UOp.sink().index(*in_rng).render(), " -> ", UOp.sink().index(*out_rng).render())
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=(realize_map,range_map), bottom_up=True, name="apply rangeify")
else:
# NOTE: we don't use contiguous here, contiguous is a user op

View File

@@ -163,9 +163,10 @@ def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int,
for st,_,_,e in dev_events:
if not isinstance(e, ProfilePointEvent): continue
if e.name == "alloc":
events.append(struct.pack("<BIIIQ", 1, int(e.ts)-start_ts, e.key, enum_str(e.arg["dtype"].name, scache), e.arg["sz"]))
safe_sz = min(1_000_000_000_000, e.arg["sz"])
events.append(struct.pack("<BIIIQ", 1, int(e.ts)-start_ts, e.key, enum_str(e.arg["dtype"].name, scache), safe_sz))
dtype_size.setdefault(e.arg["dtype"].name, e.arg["dtype"].itemsize)
temp[e.key] = nbytes = e.arg["sz"]*e.arg["dtype"].itemsize
temp[e.key] = nbytes = safe_sz*e.arg["dtype"].itemsize
mem += nbytes
if mem > peak: peak = mem
if e.name == "free":