mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-15 01:48:23 -05:00
fast rangeify works for mnist
This commit is contained in:
@@ -753,22 +753,23 @@ def apply_rangeify(ctx, x:UOp):
|
||||
for s in x.src:
|
||||
new_src = s
|
||||
if s in realize_map:
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(range_map[s]), arg=BufferizeOpts(device=s.device), tag=s.tag)
|
||||
if x in range_map: new_src = new_src.index(*range_map[x])
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag)
|
||||
if x in range_map: new_src = new_src.index(*range_map[x][0])
|
||||
elif s.op is Ops.BUFFER:
|
||||
new_src = new_src.index(*range_map[x])
|
||||
new_src = new_src.index(*range_map[x][0])
|
||||
new_srcs.append(new_src)
|
||||
# NOTE: do we need this?
|
||||
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
|
||||
|
||||
def apply_pad(ctx, x:UOp):
|
||||
realize_map, range_map = ctx
|
||||
bigwhere: UOp = functools.reduce(operator.and_, [u.src[0] for u in range_map[x] if u.op is Ops.WHERE], UOp.const(dtypes.bool, True))
|
||||
bigwhere: UOp = functools.reduce(operator.and_, [u.src[0] for u in range_map[x][0] if u.op is Ops.WHERE], UOp.const(dtypes.bool, True))
|
||||
return bigwhere.simplify().where(x.src[0], UOp.const(x.dtype, 0))
|
||||
|
||||
def fix_reduce_axis(ctx, x:UOp):
|
||||
realize_map, range_map = ctx
|
||||
new_ranges = [r for i,r in enumerate(range_map[x]) if i in x.arg[1]]
|
||||
# input ranges
|
||||
new_ranges = [r for i,r in enumerate(range_map[x][0]) if i in x.arg[1]]
|
||||
ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0], tag=x.tag)
|
||||
range_map[ret] = range_map[x]
|
||||
return ret
|
||||
@@ -807,27 +808,44 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
realize_map[x] = None
|
||||
ending_ranges[x] = False
|
||||
|
||||
# *** these are the ranges on the output ***
|
||||
|
||||
if x in realize_map:
|
||||
# if this is in the realize_map, we create new ranges (at the output)
|
||||
assert x.op not in GroupOp.Movement
|
||||
rngs = [ctx.new_range(s) for s in x.shape]
|
||||
#assert x.op not in GroupOp.Movement
|
||||
out_rngs = [ctx.new_range(s) for s in x.shape]
|
||||
elif len(consumer_map[x]) == 0:
|
||||
continue
|
||||
elif len(consumer_map[x]) > 1:
|
||||
# if this has two consumers, we have to merge the ranges and might create new ones
|
||||
all_rngs = list(zip(*[range_map[c] for c in consumer_map[x]]))
|
||||
# for RANGEIFY=1, if any differ we create all new ones
|
||||
all_all_same = all(all_same(x) for x in all_rngs)
|
||||
if all_all_same:
|
||||
# all are the same
|
||||
rngs = range_map[list(consumer_map[x])[0]]
|
||||
else:
|
||||
# create new ranges and add to realize_map
|
||||
rngs = [ctx.new_range(s) for s in x.shape]
|
||||
realize_map[x] = None
|
||||
all_rngs = list(zip(*[range_map[c][0] for c in consumer_map[x]]))
|
||||
|
||||
rngs_valids = []
|
||||
for valid_rngs in all_rngs:
|
||||
local_rngs, valids = zip(*[(r.get_idx(), r.get_valid()) for r in valid_rngs])
|
||||
# if a range has a 1 src, it's the same as UOp.const(dtypes.index, 0)
|
||||
same_rngs = [x if x.op is not Ops.RANGE or resolve(x.src[0] != 1) else UOp.const(dtypes.index, 0) for x in local_rngs]
|
||||
rngs_valids.append((local_rngs, valids, all_same(same_rngs)))
|
||||
|
||||
all_all_same = all(same_rngs for _,_,same_rngs in rngs_valids)
|
||||
out_rngs = []
|
||||
for i,(local_rngs,valids,same_rngs) in enumerate(rngs_valids):
|
||||
# we compare the ranges without their valids
|
||||
if all_all_same:
|
||||
# the new valid is the OR of all the children valids
|
||||
minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False))
|
||||
out_rngs.append(minimum_valid.where(local_rngs[0], UOp.invalid()).simplify())
|
||||
else:
|
||||
out_rngs.append(ctx.new_range(x.shape[i]))
|
||||
|
||||
# we have to realize here if there's new ranges
|
||||
if not all_all_same: realize_map[x] = None
|
||||
else:
|
||||
# if this has one consumer, we just pass it through from the consumer
|
||||
rngs = range_map[list(consumer_map[x])[0]]
|
||||
out_rngs = range_map[list(consumer_map[x])[0]][0]
|
||||
|
||||
# rngs is the input ranges
|
||||
rngs = out_rngs[:]
|
||||
|
||||
# handle REDUCE
|
||||
if x.op is Ops.REDUCE_AXIS:
|
||||
@@ -842,6 +860,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
rngs = [a if resolve(x==y, False) else a.const_like(0) for a,x,y in zip(rngs, x.src[0].shape, x.shape)]
|
||||
ending_ranges[x] = True
|
||||
if x.op is Ops.PAD:
|
||||
rngs = rngs[:]
|
||||
for i,(sh,(s,e)) in enumerate(zip(x.shape, x.arg)):
|
||||
if s == 0 and e == 0: continue
|
||||
where = UOp.const(dtypes.bool, True)
|
||||
@@ -862,8 +881,10 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
ret.append(mish % s) # NOTE: simplify will turn this to CONST
|
||||
mish //= s
|
||||
rngs = list(UOp.sink(*ret[::-1]).simplify().src)
|
||||
range_map[x] = rngs
|
||||
#for k,v in range_map.items(): print("***" if k in realize_map else " ", k.op, UOp.sink().index(*v).render())
|
||||
range_map[x] = (rngs, out_rngs)
|
||||
for k,(in_rng,out_rng) in range_map.items():
|
||||
print("***" if k in realize_map else " ", len(consumer_map[k]), k.op,
|
||||
UOp.sink().index(*in_rng).render(), " -> ", UOp.sink().index(*out_rng).render())
|
||||
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=(realize_map,range_map), bottom_up=True, name="apply rangeify")
|
||||
else:
|
||||
# NOTE: we don't use contiguous here, contiguous is a user op
|
||||
|
||||
@@ -163,9 +163,10 @@ def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int,
|
||||
for st,_,_,e in dev_events:
|
||||
if not isinstance(e, ProfilePointEvent): continue
|
||||
if e.name == "alloc":
|
||||
events.append(struct.pack("<BIIIQ", 1, int(e.ts)-start_ts, e.key, enum_str(e.arg["dtype"].name, scache), e.arg["sz"]))
|
||||
safe_sz = min(1_000_000_000_000, e.arg["sz"])
|
||||
events.append(struct.pack("<BIIIQ", 1, int(e.ts)-start_ts, e.key, enum_str(e.arg["dtype"].name, scache), safe_sz))
|
||||
dtype_size.setdefault(e.arg["dtype"].name, e.arg["dtype"].itemsize)
|
||||
temp[e.key] = nbytes = e.arg["sz"]*e.arg["dtype"].itemsize
|
||||
temp[e.key] = nbytes = safe_sz*e.arg["dtype"].itemsize
|
||||
mem += nbytes
|
||||
if mem > peak: peak = mem
|
||||
if e.name == "free":
|
||||
|
||||
Reference in New Issue
Block a user