mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
assign cleanup [pr] (#14479)
share more code path between disk and non-disk. also raise RuntimeError instead of Assert for mismatches
This commit is contained in:
@@ -355,9 +355,9 @@ class TestDiskTensor(unittest.TestCase):
|
||||
t[0:8].assign(Tensor([12345], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
|
||||
val = int.from_bytes(t[0:8].data(), 'little')
|
||||
self.assertEqual(val, 12345)
|
||||
# bitcast on target with mismatched dtype raises
|
||||
# bitcast on target with non-broadcastable dtype raises
|
||||
with self.assertRaises(RuntimeError):
|
||||
t[0:8].bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int32))
|
||||
t[0:4].bitcast(dtypes.int32).assign(Tensor([12345], dtype=dtypes.int64))
|
||||
|
||||
def test_assign_cross_device(self):
|
||||
# disk assign allows cross-device (source on GPU/CPU, target on disk)
|
||||
|
||||
@@ -286,21 +286,20 @@ class Tensor(OpMixin):
|
||||
return self
|
||||
|
||||
def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
|
||||
# TODO: this is a hack for writing to DISK. remove with working assign
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
|
||||
if self.dtype != x.dtype: raise RuntimeError(f"DISK assign dtype mismatch {self.dtype} != {x.dtype}")
|
||||
self._buffer().copyin(x._data())
|
||||
return self
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
|
||||
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device="CPU" if is_disk else self.device, dtype=self.dtype)
|
||||
if self.uop is x.uop: return self # a self assign is a NOOP
|
||||
# NOTE: we allow cross device assign
|
||||
# broadcast x
|
||||
if least_upper_dtype(self.dtype, x.dtype) == self.dtype: x = x._broadcast_to(self.shape).cast(self.dtype)
|
||||
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
|
||||
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
|
||||
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
|
||||
assert not isinstance(self.device, tuple) or self.uop.axis == x.uop.axis, f"multi assign axis mismatch {self.uop.axis} != {x.uop.axis}"
|
||||
if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}")
|
||||
if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
|
||||
if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
|
||||
if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")
|
||||
|
||||
# TODO: this is a hack for writing to DISK. remove with working assign
|
||||
if is_disk:
|
||||
self._buffer().copyin(x._data())
|
||||
return self
|
||||
return self.replace(self._apply_uop(UOp.assign, x))
|
||||
|
||||
def detach(self) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user