add UOp.new_buffer [pr] (#7742)

This commit is contained in:
qazal
2024-11-17 10:44:52 +02:00
committed by GitHub
parent a15a900415
commit eeb222f98b
2 changed files with 5 additions and 2 deletions

View File

@@ -46,7 +46,7 @@ class ScheduleContext:
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata
var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
assigns: Set[UOp] = field(default_factory=set) # this holds all the UOps.BUFFERs we ASSIGN to in this schedule
assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
lazybufs: Dict[Buffer, LazyBuffer] = field(default_factory=dict) # this is a lookup for the LazyBuffers we need to mark as realized
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp],
@@ -69,7 +69,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[UOp, Dict[U
if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0)
# everything else has BUFFER
ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(Ops.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype))))
ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp.new_buffer(b.device, b.size, b.dtype, num=len(ctx.buf_uops)))
# if the buffer is already realized we just load it
if buf.is_realized(): return UOp(Ops.VIEW, dtype, (ubuf,), buf.st)
# everything else needs sources

View File

@@ -350,6 +350,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** uop Buffer stuff ***
@staticmethod
def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype)))
@property
def buf_uop(self) -> UOp:
assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"