assign cleanup [pr] (#14479)

share more code path between disk and non-disk. also raise RuntimeError instead of Assert for mismatches
This commit is contained in:
chenyu
2026-02-01 09:10:22 -05:00
committed by GitHub
parent da500dbe06
commit 5705398a1f
2 changed files with 13 additions and 14 deletions

View File

@@ -355,9 +355,9 @@ class TestDiskTensor(unittest.TestCase):
t[0:8].assign(Tensor([12345], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
val = int.from_bytes(t[0:8].data(), 'little')
self.assertEqual(val, 12345)
# bitcast on target with mismatched dtype raises
# bitcast on target with non-broadcastable dtype raises
with self.assertRaises(RuntimeError):
t[0:8].bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int32))
t[0:4].bitcast(dtypes.int32).assign(Tensor([12345], dtype=dtypes.int64))
def test_assign_cross_device(self):
# disk assign allows cross-device (source on GPU/CPU, target on disk)

View File

@@ -286,21 +286,20 @@ class Tensor(OpMixin):
return self
def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
# TODO: this is a hack for writing to DISK. remove with working assign
if isinstance(self.device, str) and self.device.startswith("DISK"):
if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
if self.dtype != x.dtype: raise RuntimeError(f"DISK assign dtype mismatch {self.dtype} != {x.dtype}")
self._buffer().copyin(x._data())
return self
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
if not isinstance(x, Tensor): x = Tensor(x, device="CPU" if is_disk else self.device, dtype=self.dtype)
if self.uop is x.uop: return self # a self assign is a NOOP
# NOTE: we allow cross device assign
# broadcast x
if least_upper_dtype(self.dtype, x.dtype) == self.dtype: x = x._broadcast_to(self.shape).cast(self.dtype)
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
assert not isinstance(self.device, tuple) or self.uop.axis == x.uop.axis, f"multi assign axis mismatch {self.uop.axis} != {x.uop.axis}"
if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}")
if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")
# TODO: this is a hack for writing to DISK. remove with working assign
if is_disk:
self._buffer().copyin(x._data())
return self
return self.replace(self._apply_uop(UOp.assign, x))
def detach(self) -> Tensor: