diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index bd9231e157..5a731d2454 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -50,13 +50,13 @@ def disp(y:UOp) -> str: return "" @dataclass(frozen=True, eq=False) -class BasicBlock2: +class BasicBlock: lst: tuple[UOp, ...] ctx: tuple[UOp, ...] = () end: UOp|None = None cnt: int = 0 child_ctx: tuple[UOp, ...]|None = None - def __lt__(self, _:BasicBlock2): raise RuntimeError("no comparing basic blocks") + def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks") def __repr__(self): return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\ f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\ @@ -114,7 +114,7 @@ def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp r:UOp = ends_to_add.pop(-1) new_ctx = tuple([z for z in new_ctx if z is not r]) end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)) - base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock2((end_uop,), tuple(new_ctx), end=r, cnt=cnt)) + base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt)) return base_block def make_block_bottom_up(ctx:BlockContext, x:UOp): @@ -159,7 +159,7 @@ def make_block_bottom_up(ctx:BlockContext, x:UOp): srcs.append(add_blockends(base_block, new_ctx, current_ctx)) lst = block_reorder(lst[::-1]) - bb = BasicBlock2(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx) + bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx) return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb) block_create = PatternMatcher([ @@ -179,7 +179,7 @@ def merge_blockends(sink:UOp) -> UOp|None: for k,v in blockends_to_arg.items(): # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails if len(v) > 1: - bb = BasicBlock2(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v)) + bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v)) out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb) # NOTE: bb.ctx != u.arg.ctx can cause problems here for u in v: new_forks[u] = out @@ -221,7 +221,7 @@ def remove_blockend(x:UOp): # NOTE: we have to add a barrier at the start if barrier is used in the range if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE: late_ops = [UOp(Ops.BARRIER)] + late_ops - arg = BasicBlock2(tuple(early_ops)+parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt) + arg = BasicBlock(tuple(early_ops)+parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt) return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg) block_merge = PatternMatcher([ @@ -240,6 +240,6 @@ def finalize(sink:UOp) -> UOp: if __debug__: type_verify(lst) - return UOp(Ops.BLOCKFINAL, arg=BasicBlock2(tuple(lst))) + return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst))) pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])