minor cleanups in linearize.py [pr] (#10735)

This commit is contained in:
chenyu
2025-06-09 16:49:19 -07:00
committed by GitHub
parent 81ef879da3
commit 364b903850

View File

@@ -72,7 +72,7 @@ class BlockContext:
child_count: dict[UOp, int]
block_ctxs: dict[UOp, tuple[UOp, ...]]
child_ctxs: dict[UOp, tuple[UOp, ...]]
def last_ctx(self, u): return ret if (ret:=self.child_ctxs.get(u)) is not None else self.block_ctxs[u]
def last_ctx(self, u): return self.child_ctxs.get(u, self.block_ctxs[u])
@staticmethod
def from_sink(sink:UOp) -> BlockContext:
# get children and all block contexts
@@ -212,9 +212,8 @@ def remove_blockend(x:UOp):
# if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None
parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]
assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
if len(parent_blocks) > 0:
if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]):
assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
parent_block = parent_blocks[0]
assert len(parent_blocks) == parent_block.arg.cnt
# range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)