From fc32ff80d64951744d4321cfec3c74db008c925e Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 24 Feb 2025 18:26:49 +0800
Subject: [PATCH] torch and numpy dtype interop [pr] (#9224)

* torch and numpy dtype interop [pr]

* less lines

* order
---
 extra/torch_backend/backend.py  | 19 ++++---------------
 extra/torch_backend/backend2.py |  4 ++--
 tinygrad/dtype.py               | 20 ++++++++++++++++++++
 tinygrad/tensor.py              |  7 +------
 4 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index c199643a10..6f138c936d 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -3,24 +3,13 @@ from tinygrad.helpers import DEBUG, getenv, prod
 TORCH_DEBUG = getenv("TORCH_DEBUG")
 import torch, pathlib
 torch.autograd.grad_mode.set_multithreading_enabled(False)
+from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype

 # https://pytorch.org/docs/stable/torch.compiler_ir.html
-# TODO: don't replicate this in cpp
-torch_to_tiny_dtype = {
-  torch.float32: dtypes.float32,
-  torch.float64: dtypes.float64,
-  torch.uint8: dtypes.uint8,
-  torch.int8: dtypes.int8,
-  torch.int32: dtypes.int32,
-  torch.int64: dtypes.int64,
-  torch.bool: dtypes.bool,
-}
-tiny_to_torch_dtype = {v: k for k, v in torch_to_tiny_dtype.items()}
-
 import torch.utils.cpp_extension
 mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
-def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, tiny_to_torch_dtype[x.dtype])
+def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype))
 def unwrap(x:torch.Tensor) -> Tensor:
   assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
   return mod.unwrap(x)
@@ -66,13 +55,13 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
 @torch.library.impl("aten::empty_strided", "privateuseone")
 def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
   if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
-  ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
+  ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype))
   return wrap(ret)

 @torch.library.impl("aten::empty.memory_format", "privateuseone")
 def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
   if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
-  ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype or torch.get_default_dtype()])
+  ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()))
   return wrap(ret)

 @torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
diff --git a/extra/torch_backend/backend2.py b/extra/torch_backend/backend2.py
index 0138c98792..9ce3bb233d 100644
--- a/extra/torch_backend/backend2.py
+++ b/extra/torch_backend/backend2.py
@@ -5,10 +5,10 @@ from tinygrad import Tensor
 import torch, contextlib
 from torch.utils._python_dispatch import TorchDispatchMode

-from extra.torch_backend.backend import torch_to_tiny_dtype
+from tinygrad.dtype import _from_torch_dtype

 def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
-  return TTensor(Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype]))
+  return TTensor(Tensor.empty(*size, dtype=_from_torch_dtype(dtype)))

 # NOTE: if we have a way to change wrap/unwrap, these can be the same methods from backend.py
 tiny_backend = {
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 746262a3fc..2259b51f34 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -147,6 +147,7 @@ class dtypes:
   uints = (uint8, uint16, uint32, uint64)
   sints = (int8, int16, int32, int64)
   ints = uints + sints
+  all = floats + ints + (bool,)

 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
   dtypes.default_float = getattr(dtypes, env_default_float.lower())
@@ -197,3 +198,22 @@ truncate: dict[DType, Callable] = {dtypes.bool: bool,
   dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
   dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
   dtypes.int64: lambda x: ctypes.c_int64(x).value}
+
+# numpy and torch dtype interop
+
+def _to_np_dtype(dtype:DType) -> Optional[type]:
+  import numpy as np
+  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
+def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
+  import numpy as np
+  return dtypes.fields()[np.dtype(npdtype).name]
+
+@functools.lru_cache(None)
+def _to_torch_dtype(dtype:DType) -> Optional['torch.dtype']: # type: ignore [name-defined] # noqa: F821
+  import numpy as np, torch
+  # NOTE: torch doesn't expose this mapping with a stable API
+  try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
+  except TypeError: return None
+@functools.lru_cache(None)
+def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
+  return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
\ No newline at end of file
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index ee310e3074..fa0853143e 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -4,6 +4,7 @@ import time, math, itertools, functools, struct, sys, inspect, pathlib, string,
 from contextlib import ContextDecorator
 from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
+from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
 from tinygrad.engine.multi import get_multi_map
@@ -49,12 +50,6 @@ def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str,
   if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
   return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)

-def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
-  import numpy as np
-  return dtypes.fields()[np.dtype(npdtype).name]
-def _to_np_dtype(dtype:DType) -> Optional[type]:
-  import numpy as np
-  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
 def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
   ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
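
Note: the snippet below is a minimal usage sketch, not part of the patch, illustrating how the
interop helpers moved into tinygrad/dtype.py behave. It assumes numpy and torch are installed;
exact torch dtype coverage depends on the installed torch version.

# _to_torch_dtype maps a tinygrad DType to a torch dtype by round-tripping an empty numpy
# array through torch.from_numpy; _from_torch_dtype inverts that mapping over dtypes.all.
# The numpy helpers use DType.fmt and dtypes.fields(), as in the patch above.
import numpy as np, torch
from tinygrad.dtype import dtypes, _to_np_dtype, _from_np_dtype, _to_torch_dtype, _from_torch_dtype

assert _to_torch_dtype(dtypes.float32) == torch.float32
assert _from_torch_dtype(torch.int64) == dtypes.int64
assert _to_np_dtype(dtypes.bool) == np.bool_
assert _from_np_dtype(np.dtype(np.uint8)) == dtypes.uint8
# dtypes without a torch equivalent (version-dependent) map to None instead of raising
print(_to_torch_dtype(dtypes.uint16))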