From fde7a40bb093b151fd4afaf3ebe11f094b7af3dc Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 24 Feb 2026 20:49:55 -0500 Subject: [PATCH] allow dtype mismatched assign on disk (#14993) reverted #14473, that was a bad idea. also added a test that safe_save only has copy --- test/unit/test_assign.py | 4 ++-- test/unit/test_disk_tensor.py | 18 ++++++++++-------- tinygrad/nn/state.py | 2 +- tinygrad/tensor.py | 7 ++----- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py index b39b52c23f..2e7674ca7c 100644 --- a/test/unit/test_assign.py +++ b/test/unit/test_assign.py @@ -485,10 +485,10 @@ class TestAssign(unittest.TestCase): np.testing.assert_allclose(c.numpy(), [4.0, 3.0, 3.0, 4.0]) def test_assign_bitcast_different_size(self): - # different-size bitcast creates a new tensor, not a view, so assign doesn't modify the original + # assign to a shape-changing bitcast view (only works on DISK currently) a = Tensor([0]*8, dtype=dtypes.uint8).realize() a.bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int64)).realize() - np.testing.assert_equal(a.numpy(), [0]*8) + np.testing.assert_equal(a.numpy(), [0]*8) # TODO: should be [57, 48, 0, 0, 0, 0, 0, 0] (little-endian 12345) @unittest.skip("don't use output buffer, and mismatch dtype no longer supported") def test_cast_assignment(self): diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index f5b8718715..83ea56a5e7 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -74,13 +74,13 @@ class TestRawDiskBuffer(unittest.TestCase): _test_bitcasted(t, dtypes.float32, 0.0) _test_bitcasted(t, dtypes.uint32, 0) # pi in float16 stored via int16 - t.assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16).bitcast(dtypes.uint8)).realize() + t.bitcast(dtypes.uint16).assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16)).realize() _test_bitcasted(t, dtypes.float16, 3.140625) _test_bitcasted(t, dtypes.float32, 50.064727) _test_bitcasted(t, dtypes.uint16, 0x4248) _test_bitcasted(t, dtypes.uint32, 0x42484248) # pi in float32 stored via float32 - t.assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32).bitcast(dtypes.uint8)).realize() + t.bitcast(dtypes.float32).assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32)).realize() _test_bitcasted(t, dtypes.float32, 3.1415927) _test_bitcasted(t, dtypes.uint32, 0x40490FDB) # doesn't suport normal cast @@ -178,6 +178,13 @@ class TestSafetensors(TempDirTestCase): import json assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world' + def test_safe_save_only_copy(self): + from tinygrad.helpers import GlobalCounters + t = Tensor.rand(10, 10).realize() + GlobalCounters.reset() + safe_save({"t": t}, self.tmp("test_copy.safetensors")) + assert GlobalCounters.global_ops == 0, f"safe_save should have no compute, got {GlobalCounters.global_ops} ops" + def test_save_all_dtypes(self): for dtype in dedup(DTYPES_DICT.values()): if dtype in [dtypes.bfloat16]: continue # not supported in numpy @@ -357,15 +364,10 @@ class TestDiskTensor(TempDirTestCase): def test_assign_with_bitcast(self): # bitcast assign is used in safe_save for writing header length - # bitcast on source side works, bitcast on target side raises t = Tensor.empty(16, device=f"disk:{self.tmp('dt_assign_bitcast')}", dtype=dtypes.uint8) - # correct way: bitcast the source to match target dtype - t[0:8].assign(Tensor([12345], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8)) + t[0:8].bitcast(dtypes.int64).assign([12345]) val = int.from_bytes(t[0:8].data(), 'little') self.assertEqual(val, 12345) - # bitcast on target with non-broadcastable dtype raises - with self.assertRaises(RuntimeError): - t[0:4].bitcast(dtypes.int32).assign(Tensor([12345], dtype=dtypes.int64)) def test_assign_to_bitcast_view(self): # assign float values to a float32 view of a uint8 disk buffer (used by safe_save) diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 7df92bf95d..5af4d250b7 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -78,7 +78,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No j += "\x20"*(round_up(len(j),8)-len(j)) pathlib.Path(fn).unlink(missing_ok=True) t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}") - t[0:8].assign(Tensor([len(j)], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8)) + t[0:8].bitcast(dtypes.int64).assign([len(j)]) t[8:8+len(j)].assign(list(j.encode('utf-8'))) for k,v in safe_load(t).items(): v.assign(tensors[k]) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 924769b655..4247b70776 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -316,7 +316,7 @@ class Tensor(OpMixin): if self.shape != x.shape: x = x._broadcast_to(self.shape) if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}") if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}") - if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}") + if not is_disk and self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}") if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}") # TODO: this is a hack for writing to DISK. remove with working assign @@ -3569,10 +3569,7 @@ class Tensor(OpMixin): def bitcast(self, dtype:DTypeLike) -> Tensor: """ - Bitcasts `self` to the given `dtype`. - - When the target dtype has the same itemsize, this is a view of the same memory. - When itemsizes differ, the last dimension is adjusted and a new Tensor is created. + Bitcasts `self` to the given `dtype` of the same itemsize. `self` must not require a gradient.