move canonicalize_device to device.py (#15675)

This commit is contained in:
chenyu
2026-04-10 09:43:56 -04:00
committed by GitHub
parent 8e7fcc8ca3
commit e1334d3852
3 changed files with 6 additions and 8 deletions

View File

@@ -54,6 +54,10 @@ class _Device:
Device: _Device = _Device()
atexit.register(lambda: [Device[dn].finalize() for dn in Device._opened_devices])
def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
if not isinstance(device, (tuple, list)): return Device.canonicalize(device)
return canonical[0] if len(canonical:=tuple(Device.canonicalize(d) for d in device)) == 1 else canonical
# **************** Profile ****************
@dataclass(frozen=True)

View File

@@ -14,15 +14,10 @@ from tinygrad.mixin import OpMixin
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
from tinygrad.uop.ops import _broadcast_shape
from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
from tinygrad.device import Buffer
from tinygrad.device import Buffer, canonicalize_device
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.allocations import transform_to_call
def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
from tinygrad.device import Device
if not isinstance(device, (tuple, list)): return Device.canonicalize(device)
return canonical[0] if len(canonical:=tuple(Device.canonicalize(d) for d in device)) == 1 else canonical
# *** all in scope Tensors are here. this gets relevant UOps ***
all_tensors: dict[weakref.ref[Tensor], None] = {}

View File

@@ -6,11 +6,11 @@ from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, AddrSpace, ConstFloat, PyConst
from tinygrad.dtype import storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CAPTURE_PROCESS_REPLAY
from tinygrad.helpers import strip_parens, colored, ansilen, printable
if TYPE_CHECKING:
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.renderer import Estimates
class AxisType(Enum):
@@ -713,7 +713,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@property
def buffer(self) -> Buffer|MultiBuffer:
from tinygrad.device import Buffer, MultiBuffer
if self.op in {Ops.CONTIGUOUS, Ops.RESHAPE, Ops.DETACH, Ops.AFTER}: return self.src[0].buffer
# this buffer can process disk tensors and simple movement ops
if self is not self.base: