mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
derive device (dname) from UOp [pr] (#7819)
This commit is contained in:
@@ -218,6 +218,10 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemCon
|
||||
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
# convert to AST
|
||||
sink = graph_rewrite(graph_rewrite(sink, to_si, si_ctx), append_bufs, si_ctx)
|
||||
# assert buffer count limit
|
||||
if (limit:=BUF_LIMIT.get(device:=si_ctx.bufs[0].device)) is not None and len(si_ctx.bufs) >= limit:
|
||||
if DEBUG >= 3: print(sink)
|
||||
raise RuntimeError(f"Kernel for {si_ctx.metadata} exceeded the {limit} buffer count limit for {device} with {len(si_ctx.bufs)} buffers.")
|
||||
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
|
||||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in si_ctx.assign_preloads if si_ctx.sinked.get(x.buf_uop) is not None):
|
||||
@@ -415,9 +419,6 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
while queue:
|
||||
schedule.append(si:=queue.popleft())
|
||||
for b in si.outputs: del lazybufs[b].srcs # can only schedule once
|
||||
if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
|
||||
if DEBUG >= 3: print(si)
|
||||
raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
|
||||
for x in graph[si]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
@@ -353,7 +353,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
@staticmethod
def new_buffer(device:str, size:int, dtype:DType, num=-1):
  """Create an `Ops.BUFFER` UOp.

  The resulting UOp has dtype `dtype.ptr()`, no sources, and an arg of the
  form `(num, (device, size, dtype))`.
  """
  ptr_dtype = dtype.ptr()
  # arg layout: (num, (device, size, dtype))
  buffer_arg = (num, (device, size, dtype))
  return UOp(Ops.BUFFER, ptr_dtype, (), buffer_arg)
|
||||
|
||||
@functools.cached_property
def device(self) -> str:
  """Device associated with this UOp, computed once and then cached.

  - `Ops.COPY`: the UOp's `arg` is returned as the device.
  - `Ops.BUFFER`: the device is read from `arg[1][0]`, i.e. the first field of
    the `(device, size, dtype)` tuple that `new_buffer` stores at `arg[1]`.
  - any other op: the device of the first source (`src[0]`) is used.
  """
  op = self.op
  if op == Ops.COPY:
    return self.arg
  if op == Ops.BUFFER:
    return self.arg[1][0]
  # everything else takes the device of its first source
  return self.src[0].device
|
||||
@property
|
||||
def buf_uop(self) -> UOp:
|
||||
assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
|
||||
|
||||
Reference in New Issue
Block a user