mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remove Ops.THREEFRY in remove_bufferize [pr] (#15220)
This commit is contained in:
@@ -210,56 +210,54 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
|
||||
# if it's user contiguous, we never remove it
|
||||
if src.op in ALWAYS_RUN_OPS or not buf.arg.removable: return None
|
||||
|
||||
# we don't want to bufferize threefry, also causes problems because not all platforms support long
|
||||
if src.op is not Ops.THREEFRY:
|
||||
# *** here is where we compute the cost ***
|
||||
# if we return None, the bufferize is kept
|
||||
# *** here is where we compute the cost ***
|
||||
# if we return None, the bufferize is kept
|
||||
|
||||
accessed_buffers: list[UOp] = []
|
||||
indexes: list[UOp] = []
|
||||
reduces: list[UOp] = []
|
||||
def red_gate(x:UOp):
|
||||
if (x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK:
|
||||
accessed_buffers.append(x)
|
||||
return False
|
||||
if x.op is Ops.PARAM:
|
||||
accessed_buffers.append(x)
|
||||
if x.op is Ops.INDEX:
|
||||
indexes.append(x)
|
||||
if x.op is Ops.REDUCE: reduces.append(x)
|
||||
return True
|
||||
src.toposort(gate=red_gate)
|
||||
del red_gate
|
||||
accessed_buffers = dedup(accessed_buffers)
|
||||
accessed_buffers: list[UOp] = []
|
||||
indexes: list[UOp] = []
|
||||
reduces: list[UOp] = []
|
||||
def red_gate(x:UOp):
|
||||
if (x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK:
|
||||
accessed_buffers.append(x)
|
||||
return False
|
||||
if x.op is Ops.PARAM:
|
||||
accessed_buffers.append(x)
|
||||
if x.op is Ops.INDEX:
|
||||
indexes.append(x)
|
||||
if x.op is Ops.REDUCE: reduces.append(x)
|
||||
return True
|
||||
src.toposort(gate=red_gate)
|
||||
del red_gate
|
||||
accessed_buffers = dedup(accessed_buffers)
|
||||
|
||||
# if this is generated from multiple buffers, don't remove this buffer
|
||||
if len(accessed_buffers) > 3 and not (PCONTIG > 2): return None
|
||||
# if this is generated from multiple buffers, don't remove this buffer
|
||||
if len(accessed_buffers) > 3 and not (PCONTIG > 2): return None
|
||||
|
||||
# if any reduces access a buffer, don't remove this buffer
|
||||
buffer_in_reduce = False
|
||||
def buf_gate(x:UOp):
|
||||
nonlocal buffer_in_reduce
|
||||
if x.op in {Ops.PARAM, Ops.BUFFERIZE}: buffer_in_reduce = True
|
||||
return not buffer_in_reduce
|
||||
UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate)
|
||||
del buf_gate
|
||||
if buffer_in_reduce:
|
||||
if PCONTIG > 2:
|
||||
out_in_ratio = (prod(buf.shape)+1) / (sum([x.size for x in accessed_buffers])+1)
|
||||
if out_in_ratio < 10: return None
|
||||
# here we have to check the indexes, we might do a partial contig here
|
||||
local_indexes = [x for x in indexes if x.src[0].op is Ops.BUFFERIZE and x.src[0].arg.addrspace == AddrSpace.LOCAL]
|
||||
exclude_ranges = UOp.group(*[UOp.group(*x.src[1:]) for x in local_indexes]).ranges
|
||||
subs = [(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST]
|
||||
# if it's bufferized or a reduce, it's pcontig
|
||||
is_pcontig, is_subs = partition(subs, lambda x: x[0] in exclude_ranges or any([r.arg[-1] == AxisType.REDUCE for r in x[1].ranges]))
|
||||
if not len(is_subs):
|
||||
return None
|
||||
if len(is_pcontig):
|
||||
ret = src.substitute(dict(is_subs), extra_pm=pm_gate_substitute)
|
||||
return ret.bufferize(*[x[0] for x in is_pcontig], arg=BufferizeOpts(None, AddrSpace.LOCAL)).index(*[x[1] for x in is_pcontig])
|
||||
else:
|
||||
# if any reduces access a buffer, don't remove this buffer
|
||||
buffer_in_reduce = False
|
||||
def buf_gate(x:UOp):
|
||||
nonlocal buffer_in_reduce
|
||||
if x.op in {Ops.PARAM, Ops.BUFFERIZE}: buffer_in_reduce = True
|
||||
return not buffer_in_reduce
|
||||
UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate)
|
||||
del buf_gate
|
||||
if buffer_in_reduce:
|
||||
if PCONTIG > 2:
|
||||
out_in_ratio = (prod(buf.shape)+1) / (sum([x.size for x in accessed_buffers])+1)
|
||||
if out_in_ratio < 10: return None
|
||||
# here we have to check the indexes, we might do a partial contig here
|
||||
local_indexes = [x for x in indexes if x.src[0].op is Ops.BUFFERIZE and x.src[0].arg.addrspace == AddrSpace.LOCAL]
|
||||
exclude_ranges = UOp.group(*[UOp.group(*x.src[1:]) for x in local_indexes]).ranges
|
||||
subs = [(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST]
|
||||
# if it's bufferized or a reduce, it's pcontig
|
||||
is_pcontig, is_subs = partition(subs, lambda x: x[0] in exclude_ranges or any([r.arg[-1] == AxisType.REDUCE for r in x[1].ranges]))
|
||||
if not len(is_subs):
|
||||
return None
|
||||
if len(is_pcontig):
|
||||
ret = src.substitute(dict(is_subs), extra_pm=pm_gate_substitute)
|
||||
return ret.bufferize(*[x[0] for x in is_pcontig], arg=BufferizeOpts(None, AddrSpace.LOCAL)).index(*[x[1] for x in is_pcontig])
|
||||
else:
|
||||
return None
|
||||
|
||||
# if it makes it here, the bufferize is removed
|
||||
# this is the ranges replaced
|
||||
|
||||
Reference in New Issue
Block a user