minor cleanup to UOp mod folding [run_process_replay] (#5895)

some walrus
This commit is contained in:
chenyu
2024-08-03 21:38:44 -04:00
committed by GitHub
parent dad8e72ee9
commit 59315ffc78

View File

@@ -85,22 +85,22 @@ def _get_add_chain(x:UOp):
def mod_folding(x:UOp, c:int) -> Optional[UOp]:
# simplify x in x % c
# None means no change
ret, something_changed = [], False
remainder, something_changed = [], False
for u in _get_add_chain(x):
if u.op is UOps.CONST and u.arg%c!=u.arg:
if u.arg%c != 0: ret.append(u.const(u.arg%c))
if u.op is UOps.CONST and (r:=u.arg%c) != u.arg:
if r: remainder.append(u.const(r))
something_changed = True
elif u.op is UOps.ALU and u.arg is BinaryOps.MUL:
if (u0:=u.src[0]).op is UOps.CONST and u0.arg%c!=u0.arg:
if u0.arg%c != 0: ret.append(u.src[1] if (r:=u0.arg%c)==1 else u.const(r)*u.src[1])
if (u0:=u.src[0]).op is UOps.CONST and (r:=u0.arg%c) != u0.arg:
if r: remainder.append(u.src[1] if r==1 else u.const(r)*u.src[1])
something_changed = True
elif (u1:=u.src[1]).op is UOps.CONST and u1.arg%c!=u1.arg:
if u1.arg%c != 0: ret.append(u.src[0] if (r:=u1.arg%c)==1 else u.src[0]*u.const(r))
elif (u1:=u.src[1]).op is UOps.CONST and (r:=u1.arg%c) != u1.arg:
if r: remainder.append(u.src[0] if r==1 else u.src[0]*u.const(r))
something_changed = True
else: ret.append(u)
else: ret.append(u)
else: remainder.append(u)
else: remainder.append(u)
if not something_changed: return None
return functools.reduce(operator.add, ret) if ret else x.const(0)
return functools.reduce(operator.add, remainder) if remainder else x.const(0)
# ***** transcendental *****