diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py
index 3e8f0ae17a..6d215b627a 100644
--- a/test/unit/test_disk_tensor.py
+++ b/test/unit/test_disk_tensor.py
@@ -355,9 +355,9 @@ class TestDiskTensor(unittest.TestCase):
     t[0:8].assign(Tensor([12345], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
     val = int.from_bytes(t[0:8].data(), 'little')
     self.assertEqual(val, 12345)
-    # bitcast on target with mismatched dtype raises
+    # bitcast on target with non-broadcastable dtype raises
     with self.assertRaises(RuntimeError):
-      t[0:8].bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int32))
+      t[0:4].bitcast(dtypes.int32).assign(Tensor([12345], dtype=dtypes.int64))
 
   def test_assign_cross_device(self):
     # disk assign allows cross-device (source on GPU/CPU, target on disk)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c2edc7763f..f7a09b4116 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -286,21 +286,20 @@ class Tensor(OpMixin):
     return self
 
   def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
-    # TODO: this is a hack for writing to DISK. remove with working assign
-    if isinstance(self.device, str) and self.device.startswith("DISK"):
-      if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
-      if self.dtype != x.dtype: raise RuntimeError(f"DISK assign dtype mismatch {self.dtype} != {x.dtype}")
-      self._buffer().copyin(x._data())
-      return self
-    if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
+    is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
+    if not isinstance(x, Tensor): x = Tensor(x, device="CPU" if is_disk else self.device, dtype=self.dtype)
     if self.uop is x.uop: return self  # a self assign is a NOOP
-    # NOTE: we allow cross device assign
     # broadcast x
     if least_upper_dtype(self.dtype, x.dtype) == self.dtype: x = x._broadcast_to(self.shape).cast(self.dtype)
-    assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
-    assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
-    assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
-    assert not isinstance(self.device, tuple) or self.uop.axis == x.uop.axis, f"multi assign axis mismatch {self.uop.axis} != {x.uop.axis}"
+    if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}")
+    if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
+    if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
+    if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")
+
+    # TODO: this is a hack for writing to DISK. remove with working assign
+    if is_disk:
+      self._buffer().copyin(x._data())
+      return self
     return self.replace(self._apply_uop(UOp.assign, x))
 
   def detach(self) -> Tensor: