remove Ops.THREEFRY in remove_bufferize [pr] (#15220)

This commit is contained in:
chenyu
2026-03-11 05:30:33 -04:00
committed by GitHub
parent 6489a6f212
commit 39b0f4bcc1

View File

@@ -210,56 +210,54 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
# if it's user contiguous, we never remove it
if src.op in ALWAYS_RUN_OPS or not buf.arg.removable: return None
# we don't want to bufferize threefry, also causes problems because not all platforms support long
if src.op is not Ops.THREEFRY:
# *** here is where we compute the cost ***
# if we return None, the bufferize is kept
# *** here is where we compute the cost ***
# if we return None, the bufferize is kept
accessed_buffers: list[UOp] = []
indexes: list[UOp] = []
reduces: list[UOp] = []
def red_gate(x:UOp):
if (x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK:
accessed_buffers.append(x)
return False
if x.op is Ops.PARAM:
accessed_buffers.append(x)
if x.op is Ops.INDEX:
indexes.append(x)
if x.op is Ops.REDUCE: reduces.append(x)
return True
src.toposort(gate=red_gate)
del red_gate
accessed_buffers = dedup(accessed_buffers)
accessed_buffers: list[UOp] = []
indexes: list[UOp] = []
reduces: list[UOp] = []
def red_gate(x:UOp):
if (x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK:
accessed_buffers.append(x)
return False
if x.op is Ops.PARAM:
accessed_buffers.append(x)
if x.op is Ops.INDEX:
indexes.append(x)
if x.op is Ops.REDUCE: reduces.append(x)
return True
src.toposort(gate=red_gate)
del red_gate
accessed_buffers = dedup(accessed_buffers)
# if this is generated from multiple buffers, don't remove this buffer
if len(accessed_buffers) > 3 and not (PCONTIG > 2): return None
# if this is generated from multiple buffers, don't remove this buffer
if len(accessed_buffers) > 3 and not (PCONTIG > 2): return None
# if any reduces access a buffer, don't remove this buffer
buffer_in_reduce = False
def buf_gate(x:UOp):
nonlocal buffer_in_reduce
if x.op in {Ops.PARAM, Ops.BUFFERIZE}: buffer_in_reduce = True
return not buffer_in_reduce
UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate)
del buf_gate
if buffer_in_reduce:
if PCONTIG > 2:
out_in_ratio = (prod(buf.shape)+1) / (sum([x.size for x in accessed_buffers])+1)
if out_in_ratio < 10: return None
# here we have to check the indexes, we might do a partial contig here
local_indexes = [x for x in indexes if x.src[0].op is Ops.BUFFERIZE and x.src[0].arg.addrspace == AddrSpace.LOCAL]
exclude_ranges = UOp.group(*[UOp.group(*x.src[1:]) for x in local_indexes]).ranges
subs = [(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST]
# if it's bufferized or a reduce, it's pcontig
is_pcontig, is_subs = partition(subs, lambda x: x[0] in exclude_ranges or any([r.arg[-1] == AxisType.REDUCE for r in x[1].ranges]))
if not len(is_subs):
return None
if len(is_pcontig):
ret = src.substitute(dict(is_subs), extra_pm=pm_gate_substitute)
return ret.bufferize(*[x[0] for x in is_pcontig], arg=BufferizeOpts(None, AddrSpace.LOCAL)).index(*[x[1] for x in is_pcontig])
else:
# if any reduces access a buffer, don't remove this buffer
buffer_in_reduce = False
def buf_gate(x:UOp):
nonlocal buffer_in_reduce
if x.op in {Ops.PARAM, Ops.BUFFERIZE}: buffer_in_reduce = True
return not buffer_in_reduce
UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate)
del buf_gate
if buffer_in_reduce:
if PCONTIG > 2:
out_in_ratio = (prod(buf.shape)+1) / (sum([x.size for x in accessed_buffers])+1)
if out_in_ratio < 10: return None
# here we have to check the indexes, we might do a partial contig here
local_indexes = [x for x in indexes if x.src[0].op is Ops.BUFFERIZE and x.src[0].arg.addrspace == AddrSpace.LOCAL]
exclude_ranges = UOp.group(*[UOp.group(*x.src[1:]) for x in local_indexes]).ranges
subs = [(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST]
# if it's bufferized or a reduce, it's pcontig
is_pcontig, is_subs = partition(subs, lambda x: x[0] in exclude_ranges or any([r.arg[-1] == AxisType.REDUCE for r in x[1].ranges]))
if not len(is_subs):
return None
if len(is_pcontig):
ret = src.substitute(dict(is_subs), extra_pm=pm_gate_substitute)
return ret.bufferize(*[x[0] for x in is_pcontig], arg=BufferizeOpts(None, AddrSpace.LOCAL)).index(*[x[1] for x in is_pcontig])
else:
return None
# if it makes it here, the bufferize is removed
# this is the ranges replaced