clean up canonicalize_device (#14027)

centralize the type check
chenyu
2026-01-05 10:29:55 -05:00
committed by GitHub
parent ce464b147a
commit eda6a73897

@@ -18,7 +18,8 @@ from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule

 # TODO: this should be the only usage of Device
-def canonicalize_device(device:str|None) -> str: return Device.canonicalize(device)
+def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
+  return tuple(Device.canonicalize(d) for d in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)

 # *** all in scope Tensors are here. this gets relevant UOps ***
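The helper now owns the tuple/list dispatch that every call site below used to duplicate. A minimal standalone sketch of the new behavior, with _canon as a hypothetical stand-in for Device.canonicalize (the real one also resolves None to the default backend):

def _canon(d: str|None) -> str:
  # hypothetical stand-in for Device.canonicalize
  return (d if d is not None else "CPU").upper()

def canonicalize_device(device: str|tuple|list|None) -> str|tuple[str, ...]:
  # collections are canonicalized element-wise and always come back as a tuple
  return tuple(_canon(d) for d in device) if isinstance(device, (tuple, list)) else _canon(device)

assert canonicalize_device("cpu") == "CPU"                        # single device stays a str
assert canonicalize_device(["cpu", "cpu:1"]) == ("CPU", "CPU:1")  # list normalizes to tuple[str, ...]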
@@ -115,7 +116,7 @@ class Tensor(OpMixin):
              device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
     if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}"  # keep it on the disk if device is None
     _dtype:DType|None = to_dtype(dtype) if dtype is not None else None
-    _device:str|tuple[str, ...] = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device)
+    _device:str|tuple[str, ...] = canonicalize_device(device)
     del device, dtype
     # tensors can have gradients if you have called .backward
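With the branch folded into the helper, __init__ no longer special-cases tuple/list devices itself. Constructor usage is unchanged; e.g., assuming a CPU backend is available:

from tinygrad import Tensor
t = Tensor([1, 2, 3], device="cpu")  # helper canonicalizes the name
print(t.device)                      # "CPU"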
@@ -373,7 +374,7 @@ class Tensor(OpMixin):
"""
Moves the tensor to the given device.
"""
device = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device)
device = canonicalize_device(device)
if device == self.device: return self
if not isinstance(device, str): return self.shard(device)
ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
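Because the helper returns a tuple for tuple input, the isinstance check below it still routes multi-device moves through shard. A usage sketch, assuming two CPU:N devices are available:

from tinygrad import Tensor
t = Tensor.ones(4).to(("CPU:1", "CPU:2"))  # tuple device is not a str, so this shards
print(t.device)                            # a tuple of the two canonical device names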
@@ -399,7 +400,7 @@ class Tensor(OpMixin):
"""
if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
if len(devices) == 1: return self.to(devices[0])
devices = tuple(canonicalize_device(x) for x in devices)
devices = cast(tuple[str, ...], canonicalize_device(devices))
mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
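The cast here is for the type checker only: passing a tuple into canonicalize_device always yields tuple[str, ...] at runtime, but the declared return type is the union str|tuple[str, ...], which the checker cannot narrow from the argument. Reusing the sketch from the first hunk:

from typing import cast
# without the cast, assigning the str|tuple[str, ...] union to a tuple[str, ...]
# variable fails type checking, even though it is always a tuple at runtime here
devices: tuple[str, ...] = cast(tuple[str, ...], canonicalize_device(("CPU:1", "CPU:2")))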
@@ -495,7 +496,7 @@ class Tensor(OpMixin):
     dtype, shape = to_dtype(dtype) if dtype is not None else dtypes.default_float, argfix(*shape)
     if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
     # TODO: add test for multidevice tensor
-    device = tuple(canonicalize_device(d) for d in device) if isinstance(device, tuple) else canonicalize_device(device)
+    device = canonicalize_device(device)
     return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)

   def empty_like(self, **kwargs) -> Tensor:
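A side effect worth noting: the old inline version here only unpacked tuple devices, so delegating to the helper also picks up list handling. Typical single-device usage (multi-device empty remains untested per the TODO above):

from tinygrad import Tensor
t = Tensor.empty(2, 3, device="cpu")  # canonicalized to "CPU"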
@@ -577,7 +578,7 @@ class Tensor(OpMixin):
     if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
     if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
     if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
-    device = canonicalize_device(device)
+    device = cast(str, canonicalize_device(device))
     # if shape has 0, return zero tensor
     if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
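Here the isinstance guard above already rejects non-str devices, so the helper can only return a str; cast(str, ...) just records that fact for the type checker. For example:

from tinygrad import Tensor
t = Tensor.rand(2, 2, device="cpu")           # fine: single device
Tensor.rand(2, 2, device=("CPU:1", "CPU:2"))  # raises ValueError: rand only supports single device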