mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
minor cleanups in linearize.py [pr] (#10735)
This commit is contained in:
@@ -72,7 +72,7 @@ class BlockContext:
|
||||
child_count: dict[UOp, int]
|
||||
block_ctxs: dict[UOp, tuple[UOp, ...]]
|
||||
child_ctxs: dict[UOp, tuple[UOp, ...]]
|
||||
def last_ctx(self, u): return ret if (ret:=self.child_ctxs.get(u)) is not None else self.block_ctxs[u]
|
||||
def last_ctx(self, u): return self.child_ctxs.get(u, self.block_ctxs[u])
|
||||
@staticmethod
|
||||
def from_sink(sink:UOp) -> BlockContext:
|
||||
# get children and all block contexts
|
||||
@@ -212,9 +212,8 @@ def remove_blockend(x:UOp):
|
||||
# if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
|
||||
if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None
|
||||
|
||||
parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]
|
||||
assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
|
||||
if len(parent_blocks) > 0:
|
||||
if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]):
|
||||
assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
|
||||
parent_block = parent_blocks[0]
|
||||
assert len(parent_blocks) == parent_block.arg.cnt
|
||||
# range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
|
||||
|
||||
Reference in New Issue
Block a user