mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
don't implicitly change dtype in assign (#14481)
broadcasting the shape is fine, but an implicit dtype cast is hard to find
This commit is contained in:
@@ -195,8 +195,10 @@ class TestAssignIssues(unittest.TestCase):
|
||||
t.shrink(((1, 3), (1, 3))).assign(Tensor.ones(2, 2))
|
||||
np.testing.assert_allclose(t.numpy(), torch_tensor.numpy())
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_assign_broadcast(self):
|
||||
# broadcasting during assign should behave like PyTorch
|
||||
# NOTE: we don't want implicit dtype casting (int64 -> float32 loses precision), so this fails
|
||||
torch_tensor = torch.zeros(3, 5)
|
||||
torch_tensor[:] = torch.arange(5)
|
||||
t = Tensor.zeros(3, 5)
|
||||
|
||||
@@ -493,6 +493,38 @@ class TestAssign(unittest.TestCase):
|
||||
assert oba1 is None and oba2 is None
|
||||
np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
|
||||
|
||||
def test_assign_dtype_mismatch(self):
    """assign must reject a source tensor whose dtype differs from the target's.

    An implicit cast here could silently lose precision, so a RuntimeError
    is expected instead of a conversion.
    """
    target = Tensor.zeros(4, dtype=dtypes.float32).contiguous().realize()
    source = Tensor([1, 2, 3, 4], dtype=dtypes.int32)
    # match the message so an unrelated RuntimeError cannot pass the test
    with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"):
        target.assign(source)
|
||||
|
||||
def test_assign_dtype_mismatch_int64_to_float32(self):
    """int64 -> float32 can silently drop precision, so assign must refuse it."""
    dst = Tensor.zeros(1, dtype=dtypes.float32).contiguous().realize()
    # 2**24 + 1 is the smallest positive integer float32 cannot represent exactly
    src = Tensor([16777217], dtype=dtypes.int64)
    with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"):
        dst.assign(src)
|
||||
|
||||
def test_assign_shape_broadcast(self):
    """When dtypes agree, a (5,) source broadcasts across a (3, 5) target."""
    dst = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
    row = Tensor([1., 2., 3., 4., 5.], dtype=dtypes.float32)
    dst.assign(row)
    dst.realize()
    # every row of the target should now hold the broadcast source values
    want = np.tile(np.array([1., 2., 3., 4., 5.]), (3, 1))
    np.testing.assert_allclose(dst.numpy(), want)
|
||||
|
||||
def test_assign_shape_broadcast_2d(self):
    """A (1, 5) source broadcasts along the leading axis to a (3, 5) target."""
    dst = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
    row2d = Tensor([[1., 2., 3., 4., 5.]], dtype=dtypes.float32)
    dst.assign(row2d)
    dst.realize()
    # the single source row is repeated across all three target rows
    want = np.tile(np.array([1., 2., 3., 4., 5.]), (3, 1))
    np.testing.assert_allclose(dst.numpy(), want)
|
||||
|
||||
def test_disk_assignment(self):
    """assign works on a disk-backed tensor and round-trips through .numpy()."""
    # NOTE(review): assumes temp() yields a writable scratch path — as in the rest of this file
    result = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}").assign(Tensor.ones(5)).numpy()
    np.testing.assert_equal(result, np.ones(5))
|
||||
@@ -587,12 +619,12 @@ class TestAssignOrdering(unittest.TestCase):
|
||||
def test_slice_write_then_full_read(self):
    """Write to slice, then read full buffer.

    The mangled diff left BOTH the pre-change (float32) and post-change
    (int32) buffer initializations in the body; keeping both would
    shadow the first init and, after this commit, the float32 buffer
    would reject the int source anyway. Only the dtype=int32 version —
    the post-commit state — is kept.
    """
    # without .realize(): orphan slice assign not triggered by .numpy()
    buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
    buf[1:3].assign(Tensor([5, 6]))
    np.testing.assert_equal(buf.numpy(), [0, 0, 0, 0])  # TODO: wrong! should be [0, 5, 6, 0]

    # with .realize(): assign executes
    buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
    buf[1:3].assign(Tensor([5, 6])).realize()
    np.testing.assert_equal(buf.numpy(), [0, 5, 6, 0])
|
||||
|
||||
@@ -674,7 +706,7 @@ class TestAssignOrdering(unittest.TestCase):
|
||||
|
||||
def test_three_buffer_chain(self):
|
||||
"""Chain: A depends on B, B depends on C - ordering matters."""
|
||||
a = Tensor.zeros(4).contiguous().realize()
|
||||
a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
||||
b = Tensor([1, 2, 3, 4]).contiguous().realize()
|
||||
c = Tensor([10, 10, 10, 10]).contiguous().realize()
|
||||
# b reads from c, a reads from b
|
||||
@@ -686,8 +718,8 @@ class TestAssignOrdering(unittest.TestCase):
|
||||
|
||||
def test_interleaved_assign_read_patterns(self):
|
||||
"""Complex interleaved pattern: write A, read A into B, write B, read B."""
|
||||
a = Tensor.zeros(4).contiguous().realize()
|
||||
b = Tensor.zeros(4).contiguous().realize()
|
||||
a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
||||
b = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
||||
|
||||
a.assign(Tensor([1, 2, 3, 4]))
|
||||
b.assign(a.contiguous()) # b should get [1,2,3,4]
|
||||
|
||||
@@ -289,8 +289,8 @@ class Tensor(OpMixin):
|
||||
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device="CPU" if is_disk else self.device, dtype=self.dtype)
|
||||
if self.uop is x.uop: return self # a self assign is a NOOP
|
||||
# broadcast x
|
||||
if least_upper_dtype(self.dtype, x.dtype) == self.dtype: x = x._broadcast_to(self.shape).cast(self.dtype)
|
||||
# broadcast x (shape only, dtype must match)
|
||||
if self.shape != x.shape: x = x._broadcast_to(self.shape)
|
||||
if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}")
|
||||
if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
|
||||
if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
|
||||
|
||||
Reference in New Issue
Block a user