From 39b0f4bcc161d82bbb33bc086ac9fba954029793 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 11 Mar 2026 05:30:33 -0400 Subject: [PATCH] remove Ops.THREEFRY in remove_bufferize [pr] (#15220) --- tinygrad/schedule/rangeify.py | 90 +++++++++++++++++------------------ 1 file changed, 44 insertions(+), 46 deletions(-) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 9be14b4573..d12bf1823a 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -210,56 +210,54 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp): # if it's user contiguous, we never remove it if src.op in ALWAYS_RUN_OPS or not buf.arg.removable: return None - # we don't want to bufferize threefry, also causes problems because not all platforms support long - if src.op is not Ops.THREEFRY: - # *** here is where we compute the cost *** - # if we return None, the bufferize is kept + # *** here is where we compute the cost *** + # if we return None, the bufferize is kept - accessed_buffers: list[UOp] = [] - indexes: list[UOp] = [] - reduces: list[UOp] = [] - def red_gate(x:UOp): - if (x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK: - accessed_buffers.append(x) - return False - if x.op is Ops.PARAM: - accessed_buffers.append(x) - if x.op is Ops.INDEX: - indexes.append(x) - if x.op is Ops.REDUCE: reduces.append(x) - return True - src.toposort(gate=red_gate) - del red_gate - accessed_buffers = dedup(accessed_buffers) + accessed_buffers: list[UOp] = [] + indexes: list[UOp] = [] + reduces: list[UOp] = [] + def red_gate(x:UOp): + if (x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK: + accessed_buffers.append(x) + return False + if x.op is Ops.PARAM: + accessed_buffers.append(x) + if x.op is Ops.INDEX: + indexes.append(x) + if x.op is Ops.REDUCE: reduces.append(x) + return True + src.toposort(gate=red_gate) + del red_gate + accessed_buffers = dedup(accessed_buffers) - # if this is generated from multiple buffers, don't remove this buffer - if len(accessed_buffers) > 3 and not (PCONTIG > 2): return None + # if this is generated from multiple buffers, don't remove this buffer + if len(accessed_buffers) > 3 and not (PCONTIG > 2): return None - # if any reduces access a buffer, don't remove this buffer - buffer_in_reduce = False - def buf_gate(x:UOp): - nonlocal buffer_in_reduce - if x.op in {Ops.PARAM, Ops.BUFFERIZE}: buffer_in_reduce = True - return not buffer_in_reduce - UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate) - del buf_gate - if buffer_in_reduce: - if PCONTIG > 2: - out_in_ratio = (prod(buf.shape)+1) / (sum([x.size for x in accessed_buffers])+1) - if out_in_ratio < 10: return None - # here we have to check the indexes, we might do a partial contig here - local_indexes = [x for x in indexes if x.src[0].op is Ops.BUFFERIZE and x.src[0].arg.addrspace == AddrSpace.LOCAL] - exclude_ranges = UOp.group(*[UOp.group(*x.src[1:]) for x in local_indexes]).ranges - subs = [(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST] - # if it's bufferized or a reduce, it's pcontig - is_pcontig, is_subs = partition(subs, lambda x: x[0] in exclude_ranges or any([r.arg[-1] == AxisType.REDUCE for r in x[1].ranges])) - if not len(is_subs): - return None - if len(is_pcontig): - ret = src.substitute(dict(is_subs), extra_pm=pm_gate_substitute) - return ret.bufferize(*[x[0] for x in is_pcontig], arg=BufferizeOpts(None, AddrSpace.LOCAL)).index(*[x[1] for x in is_pcontig]) - else: + # if any reduces access a buffer, don't remove this buffer + buffer_in_reduce = False + def buf_gate(x:UOp): + nonlocal buffer_in_reduce + if x.op in {Ops.PARAM, Ops.BUFFERIZE}: buffer_in_reduce = True + return not buffer_in_reduce + UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate) + del buf_gate + if buffer_in_reduce: + if PCONTIG > 2: + out_in_ratio = (prod(buf.shape)+1) / (sum([x.size for x in accessed_buffers])+1) + if out_in_ratio < 10: return None + # here we have to check the indexes, we might do a partial contig here + local_indexes = [x for x in indexes if x.src[0].op is Ops.BUFFERIZE and x.src[0].arg.addrspace == AddrSpace.LOCAL] + exclude_ranges = UOp.group(*[UOp.group(*x.src[1:]) for x in local_indexes]).ranges + subs = [(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST] + # if it's bufferized or a reduce, it's pcontig + is_pcontig, is_subs = partition(subs, lambda x: x[0] in exclude_ranges or any([r.arg[-1] == AxisType.REDUCE for r in x[1].ranges])) + if not len(is_subs): return None + if len(is_pcontig): + ret = src.substitute(dict(is_subs), extra_pm=pm_gate_substitute) + return ret.bufferize(*[x[0] for x in is_pcontig], arg=BufferizeOpts(None, AddrSpace.LOCAL)).index(*[x[1] for x in is_pcontig]) + else: + return None # if it makes it here, the bufferize is removed # this is the ranges replaced