# Patch summary (this preamble sits before the first `diff --git` header and is not part of any hunk):
#
# tinygrad/engine/schedule.py
#   * full_ast_rewrite: right after the sink is converted to an AST, look up BUF_LIMIT for the
#     device of si_ctx.bufs[0]. If a limit is registered (`is not None`) and len(si_ctx.bufs) >= limit,
#     print the sink when DEBUG >= 3 and raise RuntimeError naming the kernel metadata, limit, device and count.
#   * create_schedule_with_vars: drop the equivalent check from the scheduling loop. The removed check
#     keyed on si.outputs[0].device and used a truthiness test on the limit instead of `is not None`.
#
# tinygrad/ops.py
#   * class UOp: replace a blank line with a cached `device` property. Ops.COPY returns self.arg,
#     Ops.BUFFER returns self.arg[1][0], and every other op defers to self.src[0].device.
#
# NOTE(review): this patch was stored with its line breaks and indentation collapsed. Line structure
# below is reconstructed from the Python syntax and the hunk line counts; the exact leading whitespace
# of context lines and the blank context line at the top of the ops.py hunk are assumptions.
# Verify against the target tree (or regenerate the diff from the commit) before applying.
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index f30fe52d3c..7d48afcda1 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -218,6 +218,10 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemCon
   sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
   # convert to AST
   sink = graph_rewrite(graph_rewrite(sink, to_si, si_ctx), append_bufs, si_ctx)
+  # assert buffer count limit
+  if (limit:=BUF_LIMIT.get(device:=si_ctx.bufs[0].device)) is not None and len(si_ctx.bufs) >= limit:
+    if DEBUG >= 3: print(sink)
+    raise RuntimeError(f"Kernel for {si_ctx.metadata} exceeded the {limit} buffer count limit for {device} with {len(si_ctx.bufs)} buffers.")
   # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
   if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
       and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in si_ctx.assign_preloads if si_ctx.sinked.get(x.buf_uop) is not None):
@@ -415,9 +419,6 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
   while queue:
     schedule.append(si:=queue.popleft())
     for b in si.outputs: del lazybufs[b].srcs # can only schedule once
-    if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
-      if DEBUG >= 3: print(si)
-      raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
     for x in graph[si]:
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index ff04578389..50ee312c0e 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -353,7 +353,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 
   @staticmethod
   def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype)))
-
+  @functools.cached_property
+  def device(self) -> str:
+    match self.op:
+      case Ops.COPY: return self.arg
+      case Ops.BUFFER: return self.arg[1][0]
+      case _: return self.src[0].device
   @property
   def buf_uop(self) -> UOp:
     assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"