don't implicitly change dtype in assign (#14481)

Broadcasting the shape is fine, but an implicit dtype cast is hard to spot.
This commit is contained in:
chenyu
2026-02-01 11:48:54 -05:00
committed by GitHub
parent 2111762a48
commit 3ff390159b
3 changed files with 41 additions and 7 deletions

View File

@@ -195,8 +195,10 @@ class TestAssignIssues(unittest.TestCase):
t.shrink(((1, 3), (1, 3))).assign(Tensor.ones(2, 2))
np.testing.assert_allclose(t.numpy(), torch_tensor.numpy())
@unittest.expectedFailure
def test_assign_broadcast(self):
# broadcasting during assign should behave like PyTorch
# NOTE: we don't want implicit dtype casting (int64 -> float32 loses precision), so this fails
torch_tensor = torch.zeros(3, 5)
torch_tensor[:] = torch.arange(5)
t = Tensor.zeros(3, 5)

View File

@@ -493,6 +493,38 @@ class TestAssign(unittest.TestCase):
assert oba1 is None and oba2 is None
np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
def test_assign_dtype_mismatch(self):
  """Assigning an int32 tensor into a realized float32 tensor must raise.

  assign is expected not to cast dtypes implicitly, since such a cast can lose precision.
  """
  dst = Tensor.zeros(4, dtype=dtypes.float32).contiguous().realize()
  src = Tensor([1, 2, 3, 4], dtype=dtypes.int32)
  # callable form: the assign call itself is what has to raise
  self.assertRaisesRegex(RuntimeError, "assign dtype mismatch", dst.assign, src)
def test_assign_dtype_mismatch_int64_to_float32(self):
  """Assigning an int64 tensor into a realized float32 tensor must raise.

  int64 -> float32 loses precision for large values, so it should never happen implicitly.
  """
  large_value = 16777217  # 2^24 + 1, not exactly representable in float32
  dst = Tensor.zeros(1, dtype=dtypes.float32).contiguous().realize()
  src = Tensor([large_value], dtype=dtypes.int64)
  # callable form: the assign call itself is what has to raise
  self.assertRaisesRegex(RuntimeError, "assign dtype mismatch", dst.assign, src)
def test_assign_shape_broadcast(self):
  """A (5,) float32 tensor assigned into a (3, 5) float32 tensor fills every row.

  Shape broadcasting should work as long as the dtypes match.
  """
  row = [1., 2., 3., 4., 5.]
  dst = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
  dst.assign(Tensor(row, dtype=dtypes.float32))
  dst.realize()
  # each of the 3 rows of the destination should equal the assigned row
  np.testing.assert_allclose(dst.numpy(), np.array([row] * 3))
def test_assign_shape_broadcast_2d(self):
  """A (1, 5) float32 tensor assigned into a (3, 5) float32 tensor fills every row."""
  row = [1., 2., 3., 4., 5.]
  dst = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
  # the source carries an explicit leading axis of size 1: broadcast (1, 5) to (3, 5)
  dst.assign(Tensor([row], dtype=dtypes.float32))
  dst.realize()
  np.testing.assert_allclose(dst.numpy(), np.array([row] * 3))
def test_disk_assignment(self):
  """Assigning ones into an empty disk-device tensor reads back as five ones."""
  disk_tensor = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}")
  read_back = disk_tensor.assign(Tensor.ones(5)).numpy()
  np.testing.assert_equal(read_back, np.ones(5))
@@ -587,12 +619,12 @@ class TestAssignOrdering(unittest.TestCase):
def test_slice_write_then_full_read(self):
"""Write to slice, then read full buffer."""
# without .realize(): orphan slice assign not triggered by .numpy()
buf = Tensor.zeros(4).contiguous().realize()
buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
buf[1:3].assign(Tensor([5, 6]))
np.testing.assert_equal(buf.numpy(), [0, 0, 0, 0]) # TODO: wrong! should be [0, 5, 6, 0]
# with .realize(): assign executes
buf = Tensor.zeros(4).contiguous().realize()
buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
buf[1:3].assign(Tensor([5, 6])).realize()
np.testing.assert_equal(buf.numpy(), [0, 5, 6, 0])
@@ -674,7 +706,7 @@ class TestAssignOrdering(unittest.TestCase):
def test_three_buffer_chain(self):
"""Chain: A depends on B, B depends on C - ordering matters."""
a = Tensor.zeros(4).contiguous().realize()
a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
b = Tensor([1, 2, 3, 4]).contiguous().realize()
c = Tensor([10, 10, 10, 10]).contiguous().realize()
# b reads from c, a reads from b
@@ -686,8 +718,8 @@ class TestAssignOrdering(unittest.TestCase):
def test_interleaved_assign_read_patterns(self):
"""Complex interleaved pattern: write A, read A into B, write B, read B."""
a = Tensor.zeros(4).contiguous().realize()
b = Tensor.zeros(4).contiguous().realize()
a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
b = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
a.assign(Tensor([1, 2, 3, 4]))
b.assign(a.contiguous()) # b should get [1,2,3,4]

View File

@@ -289,8 +289,8 @@ class Tensor(OpMixin):
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
if not isinstance(x, Tensor): x = Tensor(x, device="CPU" if is_disk else self.device, dtype=self.dtype)
if self.uop is x.uop: return self # a self assign is a NOOP
# broadcast x
if least_upper_dtype(self.dtype, x.dtype) == self.dtype: x = x._broadcast_to(self.shape).cast(self.dtype)
# broadcast x (shape only, dtype must match)
if self.shape != x.shape: x = x._broadcast_to(self.shape)
if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}")
if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")