mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
@@ -18,7 +18,8 @@ from tinygrad.device import Device, Buffer
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
|
||||
# TODO: this should be the only usage of Device
|
||||
def canonicalize_device(device:str|None) -> str: return Device.canonicalize(device)
|
||||
def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
|
||||
return tuple(Device.canonicalize(d) for d in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
|
||||
|
||||
# *** all in scope Tensors are here. this gets relevant UOps ***
|
||||
|
||||
@@ -115,7 +116,7 @@ class Tensor(OpMixin):
|
||||
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
|
||||
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
|
||||
_dtype:DType|None = to_dtype(dtype) if dtype is not None else None
|
||||
_device:str|tuple[str, ...] = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device)
|
||||
_device:str|tuple[str, ...] = canonicalize_device(device)
|
||||
del device, dtype
|
||||
|
||||
# tensors can have gradients if you have called .backward
|
||||
@@ -373,7 +374,7 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
Moves the tensor to the given device.
|
||||
"""
|
||||
device = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device)
|
||||
device = canonicalize_device(device)
|
||||
if device == self.device: return self
|
||||
if not isinstance(device, str): return self.shard(device)
|
||||
ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
|
||||
@@ -399,7 +400,7 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
|
||||
if len(devices) == 1: return self.to(devices[0])
|
||||
devices = tuple(canonicalize_device(x) for x in devices)
|
||||
devices = cast(tuple[str, ...], canonicalize_device(devices))
|
||||
mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
|
||||
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
|
||||
|
||||
@@ -495,7 +496,7 @@ class Tensor(OpMixin):
|
||||
dtype, shape = to_dtype(dtype) if dtype is not None else dtypes.default_float, argfix(*shape)
|
||||
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
|
||||
# TODO: add test for multidevice tensor
|
||||
device = tuple(canonicalize_device(d) for d in device) if isinstance(device, tuple) else canonicalize_device(device)
|
||||
device = canonicalize_device(device)
|
||||
return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
|
||||
|
||||
def empty_like(self, **kwargs) -> Tensor:
|
||||
@@ -577,7 +578,7 @@ class Tensor(OpMixin):
|
||||
if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
|
||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
|
||||
device = canonicalize_device(device)
|
||||
device = cast(str, canonicalize_device(device))
|
||||
|
||||
# if shape has 0, return zero tensor
|
||||
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user