diff --git a/tinygrad/device.py b/tinygrad/device.py index 770a842a3c..e7c1aad366 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -54,6 +54,10 @@ class _Device: Device: _Device = _Device() atexit.register(lambda: [Device[dn].finalize() for dn in Device._opened_devices]) +def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]: + if not isinstance(device, (tuple, list)): return Device.canonicalize(device) + return canonical[0] if len(canonical:=tuple(Device.canonicalize(d) for d in device)) == 1 else canonical + # **************** Profile **************** @dataclass(frozen=True) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index da45cd4c14..3123cf95d9 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -14,15 +14,10 @@ from tinygrad.mixin import OpMixin from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable from tinygrad.uop.ops import _broadcast_shape from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars -from tinygrad.device import Buffer +from tinygrad.device import Buffer, canonicalize_device from tinygrad.engine.realize import run_schedule from tinygrad.engine.allocations import transform_to_call -def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]: - from tinygrad.device import Device - if not isinstance(device, (tuple, list)): return Device.canonicalize(device) - return canonical[0] if len(canonical:=tuple(Device.canonicalize(d) for d in device)) == 1 else canonical - # *** all in scope Tensors are here. this gets relevant UOps *** all_tensors: dict[weakref.ref[Tensor], None] = {} diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index f7a89ad3be..eac45c59f5 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -6,11 +6,11 @@ from enum import Enum, auto from tinygrad.uop import Ops, GroupOp from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, AddrSpace, ConstFloat, PyConst from tinygrad.dtype import storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar +from tinygrad.device import Buffer, MultiBuffer from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CAPTURE_PROCESS_REPLAY from tinygrad.helpers import strip_parens, colored, ansilen, printable if TYPE_CHECKING: - from tinygrad.device import Buffer, MultiBuffer from tinygrad.renderer import Estimates class AxisType(Enum): @@ -713,7 +713,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass): @property def buffer(self) -> Buffer|MultiBuffer: - from tinygrad.device import Buffer, MultiBuffer if self.op in {Ops.CONTIGUOUS, Ops.RESHAPE, Ops.DETACH, Ops.AFTER}: return self.src[0].buffer # this buffer can process disk tensors and simple movement ops if self is not self.base: