make device on uop optional [pr] (#8034)

This commit is contained in:
qazal
2024-12-04 07:18:00 -05:00
committed by GitHub
parent 13eedd373b
commit b116e1511d
2 changed files with 13 additions and 1 deletions

View File

@@ -396,6 +396,15 @@ class TestUOpMethod(unittest.TestCase):
self.assertIs(x.replace(arg=None).arg, None)
with self.assertRaises(AssertionError): x.replace(field="a")
def test_device(self):
x = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), UOp.const(dtypes.int, 1)), ShapeTracker.from_shape(()))
self.assertEqual(x.device, Device.DEFAULT)
# NOTE: CONST doesn't have device
buffer, const = x.src
self.assertEqual(buffer.device, Device.DEFAULT)
self.assertEqual(const._device, None)
with self.assertRaises(AssertionError): const.device
class TestUOpStr(unittest.TestCase):
def test_uop_str(self):
a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)

View File

@@ -363,8 +363,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
buffer_num = itertools.count(0)
@staticmethod
def new_buffer(device:str, size:int, dtype:DType) -> UOp: return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype)))
@property
def device(self) -> str: return unwrap(self._device)
@functools.cached_property
def device(self) -> str: return self.arg[1][0] if self.op is Ops.BUFFER else self.src[0].device
def _device(self) -> Optional[str]:
return self.arg[1][0] if self.op is Ops.BUFFER else dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
@property
def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self