use backward_slice in reduce_mul_chain [pr] (#15186)

This commit is contained in:
chenyu
2026-03-08 21:44:53 -04:00
committed by GitHub
parent 25e82a9aca
commit 82f7734501

View File

@@ -342,8 +342,8 @@ def reduce_mul_chain(r:UOp) -> UOp|None:
if r.dtype != r.src[0].dtype: return None
inside, outside = [], []
for m in r.src[0].split_uop(Ops.MUL):
m_parents = m.toposort()
if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m)
m_parents = m.backward_slice
if m not in r.src[1:] and all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m)
else: inside.append(m)
if len(outside) == 0: return None
return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)