fix: variable defined in assert breaks -O (#4866)

This commit is contained in:
Roelof van Dijk
2024-06-07 20:36:24 +02:00
committed by GitHub
parent 3a20cff7c2
commit 15e5a4fb26

View File

@@ -418,7 +418,8 @@ class Linearizer(Kernel):
def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
alias_buf_idxs, loaded_buffers, accs) -> List[List[UOp]]:
assert len(reduceops:=dedup(x for x in outputs if x.op in ReduceOps)) <= 1, "max one reduceop per block"
reduceops = dedup(x for x in outputs if x.op in ReduceOps)
assert len(reduceops) <= 1, "max one reduceop per block"
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
fake_reduce_idxs = [x*0 for x in reduce_idxs]