replace lazy srcs with the new uop api [pr] (#8255)

* buf_uop_view function

* srcs shouldn't exist

* fix TestTensorMetadata

---------

Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
qazal
2024-12-15 11:09:54 +02:00
committed by GitHub
parent e0aeb2e9f4
commit d05e21cb69
5 changed files with 9 additions and 14 deletions

View File

@@ -84,8 +84,8 @@ a = UOp.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE)
b = UOp.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE)
a.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
del a.srcs
del b.srcs
a = a.buf_uop_view()
b = b.buf_uop_view()
# describe the computation
out = a.alu(Ops.ADD, b)

View File

@@ -742,8 +742,8 @@ class TestTensorMetadata(unittest.TestCase):
y = Tensor.rand(3, requires_grad=True)
out = x.relu() * y.sigmoid()
self.assertEqual(out.lazydata.metadata.name, "__mul__")
self.assertEqual(out.lazydata.srcs[0].metadata.name, "relu")
self.assertEqual(out.lazydata.srcs[1].metadata.name, "sigmoid")
self.assertEqual(out.lazydata.src[0].metadata.name, "relu")
self.assertEqual(out.lazydata.src[1].metadata.name, "sigmoid")
si = create_schedule([out.lazydata])[-1]
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})

View File

@@ -69,7 +69,7 @@ def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp:
# ASSIGN uses the target buffer, otherwise we create a new buffer
else:
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = buf.replace(dtype=dtype.base, src=tuple(to_uop(x, ctx, cache) for x in buf.srcs))
op = buf.replace(dtype=dtype.base, src=tuple(to_uop(x, ctx, cache) for x in buf.src))
# track the underlying tensor uop for this op
ctx.tensor_uops[buf_uop] = [buf]
# (early) bufferize

View File

@@ -453,10 +453,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return (self.st is not None and self._device is not None and self.st.consecutive and not self.is_unrealized_const() and
not isinstance(self.dtype, ImageDType) and self.device.split(":")[0] in view_supported_devices)
@property
def srcs(self): return self.src
@srcs.deleter
def srcs(self): self.become(self.buf_uop.view(unwrap(self.st)))
@property
def lbs(self): return [self]
@property
def metadata(self): return all_metadata.get(self, None)
@@ -511,6 +507,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.BUFFER: return self
assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW}, f"buf_uop called on {self.op}"
return self.src[0].buf_uop
def buf_uop_view(self) -> UOp: return self.buf_uop.view(unwrap(self.st))
@property
def buffer(self) -> Buffer:
if self.base.realized is not None: return self.base.realized

View File

@@ -52,8 +52,7 @@ def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
# fake realize
ret.buffer.allocate(x)
del ret.srcs
return ret
return ret.buf_uop_view()
def get_shape(x) -> Tuple[int, ...]:
# NOTE: str is special because __getitem__ on a str is still a str
@@ -70,8 +69,7 @@ def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> UOp:
data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
# fake realize
ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
del ret.srcs
return ret
return ret.buf_uop_view()
def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]], dtype:DType) -> List[List[Tensor]]:
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
@@ -430,7 +428,7 @@ class Tensor(SimpleMathTrait):
r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
r.lazydata.buffer.allocate(external_ptr=ptr)
del r.lazydata.srcs # fake realize
r.lazydata.buf_uop_view()
return r
@staticmethod