assert assign dtype mismatch for disk [pr] (#14473)

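The DISK assign hack is generally wrong: it silently accepted a source whose dtype differed from the target. Assigning to a DISK tensor now raises on a dtype mismatch, so callers must bitcast the source to the target dtype before assign.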
This commit is contained in:
chenyu
2026-01-31 17:08:54 -05:00
committed by GitHub
parent ced886f26c
commit b38fc43b07
4 changed files with 16 additions and 9 deletions

View File

@@ -469,10 +469,10 @@ class TestAssign(unittest.TestCase):
np.testing.assert_allclose(a.numpy(), [1.0, 2.0, 3.0, 4.0]) # TODO: should be [4.0, 3.0, 2.0, 1.0]
def test_assign_bitcast_different_size(self):
# assign to a shape-changing bitcast view (only works on DISK currently)
# different-size bitcast creates a new tensor, not a view, so assign doesn't modify the original
a = Tensor([0]*8, dtype=dtypes.uint8).realize()
a.bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int64)).realize()
np.testing.assert_equal(a.numpy(), [0]*8) # TODO: should be [57, 48, 0, 0, 0, 0, 0, 0] (little-endian 12345)
np.testing.assert_equal(a.numpy(), [0]*8)
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
def test_cast_assignment(self):

View File

@@ -59,13 +59,13 @@ class TestRawDiskBuffer(unittest.TestCase):
_test_bitcasted(t, dtypes.float32, 0.0)
_test_bitcasted(t, dtypes.uint32, 0)
# pi in float16 stored via int16
t.bitcast(dtypes.uint16).assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16)).realize()
t.assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16).bitcast(dtypes.uint8)).realize()
_test_bitcasted(t, dtypes.float16, 3.140625)
_test_bitcasted(t, dtypes.float32, 50.064727)
_test_bitcasted(t, dtypes.uint16, 0x4248)
_test_bitcasted(t, dtypes.uint32, 0x42484248)
# pi in float32 stored via float32
t.bitcast(dtypes.float32).assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32)).realize()
t.assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32).bitcast(dtypes.uint8)).realize()
_test_bitcasted(t, dtypes.float32, 3.1415927)
_test_bitcasted(t, dtypes.uint32, 0x40490FDB)
# doesn't support normal cast
@@ -348,13 +348,16 @@ class TestDiskTensor(unittest.TestCase):
def test_assign_with_bitcast(self):
# bitcast assign is used in safe_save for writing header length
# this tests the synchronous disk assign hack handles bitcast correctly
# bitcast on source side works, bitcast on target side raises
pathlib.Path(temp(fn:="dt_assign_bitcast")).unlink(missing_ok=True)
t = Tensor.empty(16, device=f"disk:{temp(fn)}", dtype=dtypes.uint8)
t[0:8].bitcast(dtypes.int64).assign([12345])
# verify the data was written correctly
# correct way: bitcast the source to match target dtype
t[0:8].assign(Tensor([12345], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
val = int.from_bytes(t[0:8].data(), 'little')
self.assertEqual(val, 12345)
# bitcast on target with mismatched dtype raises
with self.assertRaises(RuntimeError):
t[0:8].bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int32))
def test_assign_cross_device(self):
# disk assign allows cross-device (source on GPU/CPU, target on disk)

View File

@@ -78,7 +78,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No
j += "\x20"*(round_up(len(j),8)-len(j))
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:8].bitcast(dtypes.int64).assign([len(j)])
t[0:8].assign(Tensor([len(j)], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
t[8:8+len(j)].assign(list(j.encode('utf-8')))
for k,v in safe_load(t).items(): v.assign(tensors[k])

View File

@@ -289,6 +289,7 @@ class Tensor(OpMixin):
# TODO: this is a hack for writing to DISK. remove with working assign
if isinstance(self.device, str) and self.device.startswith("DISK"):
if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
if self.dtype != x.dtype: raise RuntimeError(f"DISK assign dtype mismatch {self.dtype} != {x.dtype}")
self._buffer().copyin(x._data())
return self
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
@@ -3868,7 +3869,10 @@ class Tensor(OpMixin):
def bitcast(self, dtype:DTypeLike) -> Tensor:
"""
Bitcasts `self` to the given `dtype` of the same itemsize.
Bitcasts `self` to the given `dtype`.
When the target dtype has the same itemsize, this is a view of the same memory.
When itemsizes differ, the last dimension is adjusted and a new Tensor is created.
`self` must not require a gradient.