diff --git a/test/null/test_tensor.py b/test/null/test_tensor.py index 84fdae1a5a..d288e94b54 100644 --- a/test/null/test_tensor.py +++ b/test/null/test_tensor.py @@ -170,5 +170,9 @@ class TestRand(unittest.TestCase): Tensor.rand(2**17, 2**17).schedule() Tensor.rand(2**17, 2**17).schedule() +class TestTensorDevice(unittest.TestCase): + def test_create_from_single_device_tuple(self): + (Tensor([1.0], device=(Device.DEFAULT,)) + Tensor([2.0])).realize() + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b23c254fb7..9b12b07f5d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -20,7 +20,8 @@ from tinygrad.engine.allocations import transform_to_call # TODO: this should be the only usage of Device def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]: - return tuple(Device.canonicalize(d) for d in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device) + if not isinstance(device, (tuple, list)): return Device.canonicalize(device) + return canonical[0] if len(canonical:=tuple(Device.canonicalize(d) for d in device)) == 1 else canonical # *** all in scope Tensors are here. this gets relevant UOps ***