diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 1aeeb7fa58..1a0b446b91 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -342,8 +342,8 @@ def reduce_mul_chain(r:UOp) -> UOp|None: if r.dtype != r.src[0].dtype: return None inside, outside = [], [] for m in r.src[0].split_uop(Ops.MUL): - m_parents = m.toposort() - if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m) + m_parents = m.backward_slice + if m not in r.src[1:] and all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m) else: inside.append(m) if len(outside) == 0: return None return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)