derive device (dname) from UOp [pr] (#7819)

This commit is contained in:
qazal
2024-11-21 03:38:22 -05:00
committed by GitHub
parent 75c082b883
commit 877b440fde
2 changed files with 10 additions and 4 deletions

View File

@@ -218,6 +218,10 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemCon
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
# convert to AST
sink = graph_rewrite(graph_rewrite(sink, to_si, si_ctx), append_bufs, si_ctx)
# assert buffer count limit
if (limit:=BUF_LIMIT.get(device:=si_ctx.bufs[0].device)) is not None and len(si_ctx.bufs) >= limit:
if DEBUG >= 3: print(sink)
raise RuntimeError(f"Kernel for {si_ctx.metadata} exceeded the {limit} buffer count limit for {device} with {len(si_ctx.bufs)} buffers.")
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in si_ctx.assign_preloads if si_ctx.sinked.get(x.buf_uop) is not None):
@@ -415,9 +419,6 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
while queue:
schedule.append(si:=queue.popleft())
for b in si.outputs: del lazybufs[b].srcs # can only schedule once
if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
if DEBUG >= 3: print(si)
raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
for x in graph[si]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)

View File

@@ -353,7 +353,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@staticmethod
def new_buffer(device:str, size:int, dtype:DType, num=-1):
  """Create an `Ops.BUFFER` UOp with no sources.

  The resulting UOp has dtype `dtype.ptr()` and carries
  `(num, (device, size, dtype))` as its arg.
  """
  buffer_arg = (num, (device, size, dtype))
  return UOp(Ops.BUFFER, dtype.ptr(), (), buffer_arg)
@functools.cached_property
def device(self) -> str:
  """Return the device name derived from this UOp.

  - `Ops.COPY`: the arg itself is returned.
  - `Ops.BUFFER`: the first element of `arg[1]` is returned (BUFFER args are
    built as `(num, (device, size, dtype))`).
  - anything else: the lookup is delegated to the first source UOp.

  The result is memoized per instance by `functools.cached_property`.
  """
  # equality (not identity) mirrors how value patterns in `match` compare
  if self.op == Ops.COPY: return self.arg
  if self.op == Ops.BUFFER:
    buffer_spec = self.arg[1]
    return buffer_spec[0]
  # no device info on this op: walk down through the first source
  first_src = self.src[0]
  return first_src.device
@property
def buf_uop(self) -> UOp:
assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"