mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
make device on uop optional [pr] (#8034)
This commit is contained in:
@@ -396,6 +396,15 @@ class TestUOpMethod(unittest.TestCase):
|
||||
self.assertIs(x.replace(arg=None).arg, None)
|
||||
with self.assertRaises(AssertionError): x.replace(field="a")
|
||||
|
||||
def test_device(self):
|
||||
x = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), UOp.const(dtypes.int, 1)), ShapeTracker.from_shape(()))
|
||||
self.assertEqual(x.device, Device.DEFAULT)
|
||||
# NOTE: CONST doesn't have device
|
||||
buffer, const = x.src
|
||||
self.assertEqual(buffer.device, Device.DEFAULT)
|
||||
self.assertEqual(const._device, None)
|
||||
with self.assertRaises(AssertionError): const.device
|
||||
|
||||
class TestUOpStr(unittest.TestCase):
|
||||
def test_uop_str(self):
|
||||
a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)
|
||||
|
||||
@@ -363,8 +363,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
buffer_num = itertools.count(0)
|
||||
@staticmethod
|
||||
def new_buffer(device:str, size:int, dtype:DType) -> UOp: return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype)))
|
||||
@property
|
||||
def device(self) -> str: return unwrap(self._device)
|
||||
@functools.cached_property
|
||||
def device(self) -> str: return self.arg[1][0] if self.op is Ops.BUFFER else self.src[0].device
|
||||
def _device(self) -> Optional[str]:
|
||||
return self.arg[1][0] if self.op is Ops.BUFFER else dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
|
||||
@property
|
||||
def buf_uop(self) -> UOp:
|
||||
if self.op is Ops.BUFFER: return self
|
||||
|
||||
Reference in New Issue
Block a user