From 3ff390159b17b0e60355e22fedbc102a4e807d54 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 1 Feb 2026 11:48:54 -0500 Subject: [PATCH] don't implicitly change dtype in assign (#14481) broadcast shape is fine, but implicitly cast dtype is hard to find --- test/test_edgecases.py | 2 ++ test/unit/test_assign.py | 42 +++++++++++++++++++++++++++++++++++----- tinygrad/tensor.py | 4 ++-- 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/test/test_edgecases.py b/test/test_edgecases.py index 9f24524ae2..59358a5a0e 100644 --- a/test/test_edgecases.py +++ b/test/test_edgecases.py @@ -195,8 +195,10 @@ class TestAssignIssues(unittest.TestCase): t.shrink(((1, 3), (1, 3))).assign(Tensor.ones(2, 2)) np.testing.assert_allclose(t.numpy(), torch_tensor.numpy()) + @unittest.expectedFailure def test_assign_broadcast(self): # broadcasting during assign should behave like PyTorch + # NOTE: we don't want implicit dtype casting (int64 -> float32 loses precision), so this fails torch_tensor = torch.zeros(3, 5) torch_tensor[:] = torch.arange(5) t = Tensor.zeros(3, 5) diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py index 93bd9b9149..71d59a8607 100644 --- a/test/unit/test_assign.py +++ b/test/unit/test_assign.py @@ -493,6 +493,38 @@ class TestAssign(unittest.TestCase): assert oba1 is None and oba2 is None np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N))) + def test_assign_dtype_mismatch(self): + # assign should not implicitly cast dtypes - this can lose precision + a = Tensor.zeros(4, dtype=dtypes.float32).contiguous().realize() + b = Tensor([1, 2, 3, 4], dtype=dtypes.int32) + with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"): + a.assign(b) + + def test_assign_dtype_mismatch_int64_to_float32(self): + # int64 -> float32 loses precision for large values, should not be implicit + a = Tensor.zeros(1, dtype=dtypes.float32).contiguous().realize() + b = Tensor([16777217], dtype=dtypes.int64) # 2^24 + 1, not exactly representable in float32 + with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"): + a.assign(b) + + def test_assign_shape_broadcast(self): + # shape broadcasting should work when dtypes match + a = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize() + b = Tensor([1., 2., 3., 4., 5.], dtype=dtypes.float32) + a.assign(b) + a.realize() + expected = np.array([[1., 2., 3., 4., 5.]] * 3) + np.testing.assert_allclose(a.numpy(), expected) + + def test_assign_shape_broadcast_2d(self): + # broadcast (1, 5) to (3, 5) + a = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize() + b = Tensor([[1., 2., 3., 4., 5.]], dtype=dtypes.float32) + a.assign(b) + a.realize() + expected = np.array([[1., 2., 3., 4., 5.]] * 3) + np.testing.assert_allclose(a.numpy(), expected) + def test_disk_assignment(self): a = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}").assign(Tensor.ones(5)).numpy() np.testing.assert_equal(a, np.ones(5)) @@ -587,12 +619,12 @@ class TestAssignOrdering(unittest.TestCase): def test_slice_write_then_full_read(self): """Write to slice, then read full buffer.""" # without .realize(): orphan slice assign not triggered by .numpy() - buf = Tensor.zeros(4).contiguous().realize() + buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize() buf[1:3].assign(Tensor([5, 6])) np.testing.assert_equal(buf.numpy(), [0, 0, 0, 0]) # TODO: wrong! should be [0, 5, 6, 0] # with .realize(): assign executes - buf = Tensor.zeros(4).contiguous().realize() + buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize() buf[1:3].assign(Tensor([5, 6])).realize() np.testing.assert_equal(buf.numpy(), [0, 5, 6, 0]) @@ -674,7 +706,7 @@ class TestAssignOrdering(unittest.TestCase): def test_three_buffer_chain(self): """Chain: A depends on B, B depends on C - ordering matters.""" - a = Tensor.zeros(4).contiguous().realize() + a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize() b = Tensor([1, 2, 3, 4]).contiguous().realize() c = Tensor([10, 10, 10, 10]).contiguous().realize() # b reads from c, a reads from b @@ -686,8 +718,8 @@ class TestAssignOrdering(unittest.TestCase): def test_interleaved_assign_read_patterns(self): """Complex interleaved pattern: write A, read A into B, write B, read B.""" - a = Tensor.zeros(4).contiguous().realize() - b = Tensor.zeros(4).contiguous().realize() + a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize() + b = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize() a.assign(Tensor([1, 2, 3, 4])) b.assign(a.contiguous()) # b should get [1,2,3,4] diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index f7a09b4116..aa45bdfa6e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -289,8 +289,8 @@ class Tensor(OpMixin): is_disk = isinstance(self.device, str) and self.device.startswith("DISK") if not isinstance(x, Tensor): x = Tensor(x, device="CPU" if is_disk else self.device, dtype=self.dtype) if self.uop is x.uop: return self # a self assign is a NOOP - # broadcast x - if least_upper_dtype(self.dtype, x.dtype) == self.dtype: x = x._broadcast_to(self.shape).cast(self.dtype) + # broadcast x (shape only, dtype must match) + if self.shape != x.shape: x = x._broadcast_to(self.shape) if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}") if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}") if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")