BUFFER size is their arg (#15750)

This commit is contained in:
chenyu
2026-04-15 18:08:29 -04:00
committed by GitHub
parent be8005c5dc
commit 41421c3b48

View File

@@ -732,10 +732,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
buf = self.src[0].buffer
if isinstance(buf, MultiBuffer):
mbuf = MultiBuffer.__new__(MultiBuffer)
mbuf.bufs = [b.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize) for b in buf.bufs]
mbuf.bufs = [b.view(self.arg[0], self.dtype, self.arg[1] * self.dtype.itemsize) for b in buf.bufs]
return mbuf
assert isinstance(buf, Buffer), "must be a Buffer for BUFFER_VIEW"
return buf.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize)
return buf.view(self.arg[0], self.dtype, self.arg[1] * self.dtype.itemsize)
if self.op is Ops.MSELECT:
ret = self.src[0].buffer
assert isinstance(ret, MultiBuffer)
@@ -749,8 +749,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
assert self.src[0].op is Ops.UNIQUE, f"buffer src[0] must be UNIQUE, not {self.src[0].op}"
if (cret:=buffers.get(self)) is not None: return cret
rdtype = self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base
if isinstance(self.device, tuple): ret = MultiBuffer(self.device, self.size, rdtype).ref(1)
else: ret = Buffer(self.device, self.size, rdtype).ref(1)
if isinstance(self.device, tuple): ret = MultiBuffer(self.device, self.arg, rdtype).ref(1)
else: ret = Buffer(self.device, self.arg, rdtype).ref(1)
buffers[self] = ret
return ret
@property
@@ -1564,7 +1564,7 @@ pm_pyrender_extra = PatternMatcher([
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d:
f"UOp.new_buffer({repr(d.arg)}, {x.size}, {x.dtype}, {u.arg})"),
f"UOp.new_buffer({repr(d.arg)}, {x.arg}, {x.dtype}, {u.arg})"),
(UPat(Ops.COPY, src=(UPat(name="x"), UPat(Ops.DEVICE, name="d"))), lambda ctx,x,d: f"{ctx[x]}.copy_to_device({repr(d.arg)})"),
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda ctx,x: f"UOp(Ops.CUSTOM_FUNCTION, {x.dtype}, src={srcs(ctx, x.src)}, arg={x.arg!r})"),
(UPat(Ops.REDUCE_AXIS, name="r"), lambda ctx,r: f"{ctx[r.src[0]]}._rop({r.arg[0]}, {r.arg[1]})"),