From eda6a73897e9564de2950b1544443738d5cd8e9d Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 5 Jan 2026 10:29:55 -0500 Subject: [PATCH] clean up canonicalize_device (#14027) centralize the type check --- tinygrad/tensor.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 0f52ac2d7c..9fce183ada 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -18,7 +18,8 @@ from tinygrad.device import Device, Buffer from tinygrad.engine.realize import run_schedule # TODO: this should be the only usage of Device -def canonicalize_device(device:str|None) -> str: return Device.canonicalize(device) +def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]: + return tuple(Device.canonicalize(d) for d in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device) # *** all in scope Tensors are here. this gets relevant UOps *** @@ -115,7 +116,7 @@ class Tensor(OpMixin): device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False): if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None _dtype:DType|None = to_dtype(dtype) if dtype is not None else None - _device:str|tuple[str, ...] = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device) + _device:str|tuple[str, ...] = canonicalize_device(device) del device, dtype # tensors can have gradients if you have called .backward @@ -373,7 +374,7 @@ class Tensor(OpMixin): """ Moves the tensor to the given device. """ - device = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device) + device = canonicalize_device(device) if device == self.device: return self if not isinstance(device, str): return self.shard(device) ret = Tensor(self.uop, device, requires_grad=self.requires_grad) @@ -399,7 +400,7 @@ class Tensor(OpMixin): """ if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor") if len(devices) == 1: return self.to(devices[0]) - devices = tuple(canonicalize_device(x) for x in devices) + devices = cast(tuple[str, ...], canonicalize_device(devices)) mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices) return Tensor(mlb, device=devices, requires_grad=self.requires_grad) @@ -495,7 +496,7 @@ class Tensor(OpMixin): dtype, shape = to_dtype(dtype) if dtype is not None else dtypes.default_float, argfix(*shape) if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}") # TODO: add test for multidevice tensor - device = tuple(canonicalize_device(d) for d in device) if isinstance(device, tuple) else canonicalize_device(device) + device = canonicalize_device(device) return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape) def empty_like(self, **kwargs) -> Tensor: @@ -577,7 +578,7 @@ class Tensor(OpMixin): if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}") if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}") if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}") - device = canonicalize_device(device) + device = cast(str, canonicalize_device(device)) # if shape has 0, return zero tensor if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)