From eeb222f98b32db1fbba93d7ec4d3b43298d20bf1 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 17 Nov 2024 10:44:52 +0200 Subject: [PATCH] add UOp.new_buffer [pr] (#7742) --- tinygrad/engine/schedule.py | 4 ++-- tinygrad/ops.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 87c2e11941..9b4d6c5d93 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -46,7 +46,7 @@ class ScheduleContext: buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value - assigns: Set[UOp] = field(default_factory=set) # this holds all the UOps.BUFFERs we ASSIGN to in this schedule + assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule lazybufs: Dict[Buffer, LazyBuffer] = field(default_factory=dict) # this is a lookup for the LazyBuffers we need to mark as realized def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], @@ -69,7 +69,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[UOp, Dict[U if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()]) return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0) # everything else has BUFFER - ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(Ops.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype)))) + ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp.new_buffer(b.device, b.size, b.dtype, num=len(ctx.buf_uops))) # if the buffer is already realized we just load it if buf.is_realized(): return UOp(Ops.VIEW, dtype, (ubuf,), buf.st) # everything else needs sources diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 41d2911a8e..e6a8deb5e6 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -350,6 +350,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** uop Buffer stuff *** + @staticmethod + def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype))) + @property def buf_uop(self) -> UOp: assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"