mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
BasicBlock2 -> BasicBlock [pr] (#10691)
This commit is contained in:
@@ -50,13 +50,13 @@ def disp(y:UOp) -> str:
|
||||
return "<NONE>"
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class BasicBlock2:
|
||||
class BasicBlock:
|
||||
lst: tuple[UOp, ...]
|
||||
ctx: tuple[UOp, ...] = ()
|
||||
end: UOp|None = None
|
||||
cnt: int = 0
|
||||
child_ctx: tuple[UOp, ...]|None = None
|
||||
def __lt__(self, _:BasicBlock2): raise RuntimeError("no comparing basic blocks")
|
||||
def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks")
|
||||
def __repr__(self):
|
||||
return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\
|
||||
f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\
|
||||
@@ -114,7 +114,7 @@ def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp
|
||||
r:UOp = ends_to_add.pop(-1)
|
||||
new_ctx = tuple([z for z in new_ctx if z is not r])
|
||||
end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
|
||||
base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock2((end_uop,), tuple(new_ctx), end=r, cnt=cnt))
|
||||
base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt))
|
||||
return base_block
|
||||
|
||||
def make_block_bottom_up(ctx:BlockContext, x:UOp):
|
||||
@@ -159,7 +159,7 @@ def make_block_bottom_up(ctx:BlockContext, x:UOp):
|
||||
srcs.append(add_blockends(base_block, new_ctx, current_ctx))
|
||||
|
||||
lst = block_reorder(lst[::-1])
|
||||
bb = BasicBlock2(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
|
||||
bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
|
||||
return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
|
||||
|
||||
block_create = PatternMatcher([
|
||||
@@ -179,7 +179,7 @@ def merge_blockends(sink:UOp) -> UOp|None:
|
||||
for k,v in blockends_to_arg.items():
|
||||
# NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
|
||||
if len(v) > 1:
|
||||
bb = BasicBlock2(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
|
||||
bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
|
||||
out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
|
||||
# NOTE: bb.ctx != u.arg.ctx can cause problems here
|
||||
for u in v: new_forks[u] = out
|
||||
@@ -221,7 +221,7 @@ def remove_blockend(x:UOp):
|
||||
# NOTE: we have to add a barrier at the start if barrier is used in the range
|
||||
if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
|
||||
late_ops = [UOp(Ops.BARRIER)] + late_ops
|
||||
arg = BasicBlock2(tuple(early_ops)+parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
|
||||
arg = BasicBlock(tuple(early_ops)+parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
|
||||
return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg)
|
||||
|
||||
block_merge = PatternMatcher([
|
||||
@@ -240,6 +240,6 @@ def finalize(sink:UOp) -> UOp:
|
||||
|
||||
if __debug__: type_verify(lst)
|
||||
|
||||
return UOp(Ops.BLOCKFINAL, arg=BasicBlock2(tuple(lst)))
|
||||
return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst)))
|
||||
|
||||
pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
|
||||
|
||||
Reference in New Issue
Block a user