torch and numpy dtype interop [pr] (#9224)

* torch and numpy dtype interop [pr]

* less lines

* order
George Hotz
2025-02-24 18:26:49 +08:00
committed by GitHub
parent 24615db5f5
commit fc32ff80d6
4 changed files with 27 additions and 23 deletions

View File: extra/torch_backend/backend.py

@@ -3,24 +3,13 @@ from tinygrad.helpers import DEBUG, getenv, prod
 TORCH_DEBUG = getenv("TORCH_DEBUG")
 import torch, pathlib
 torch.autograd.grad_mode.set_multithreading_enabled(False)
+from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
-# https://pytorch.org/docs/stable/torch.compiler_ir.html
-# TODO: don't replicate this in cpp
-torch_to_tiny_dtype = {
-  torch.float32: dtypes.float32,
-  torch.float64: dtypes.float64,
-  torch.uint8: dtypes.uint8,
-  torch.int8: dtypes.int8,
-  torch.int32: dtypes.int32,
-  torch.int64: dtypes.int64,
-  torch.bool: dtypes.bool,
-}
-tiny_to_torch_dtype = {v: k for k, v in torch_to_tiny_dtype.items()}
 import torch.utils.cpp_extension
 mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
-def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, tiny_to_torch_dtype[x.dtype])
+def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype))
 def unwrap(x:torch.Tensor) -> Tensor:
   assert isinstance(x, torch.Tensor), f"x is {type(x)}, not torch.Tensor"
   return mod.unwrap(x)
@@ -66,13 +55,13 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
 @torch.library.impl("aten::empty_strided", "privateuseone")
 def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
   if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
-  ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
+  ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype))
   return wrap(ret)
 @torch.library.impl("aten::empty.memory_format", "privateuseone")
 def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
   if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
-  ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype or torch.get_default_dtype()])
+  ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()))
   return wrap(ret)
 @torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
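To make the dtype hop in these kernels concrete, here is a minimal standalone sketch (tiny_empty is a hypothetical stand-in for the registered empty.memory_format impl, minus the wrap() back into a torch.Tensor; it needs a tree that includes this commit):

import torch
from tinygrad import Tensor, dtypes
from tinygrad.dtype import _from_torch_dtype

def tiny_empty(size, dtype=None):
  # same conversion as empty_memory_format above: torch dtype (or the default) -> tinygrad dtype
  return Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()))

t = tiny_empty((2, 3), dtype=torch.int8)
assert t.dtype == dtypes.int8 and t.shape == (2, 3)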

View File

@@ -5,10 +5,10 @@
 from tinygrad import Tensor
 import torch, contextlib
 from torch.utils._python_dispatch import TorchDispatchMode
-from extra.torch_backend.backend import torch_to_tiny_dtype
+from tinygrad.dtype import _from_torch_dtype
 def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
-  return TTensor(Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype]))
+  return TTensor(Tensor.empty(*size, dtype=_from_torch_dtype(dtype)))
 # NOTE: if we have a way to change wrap/unwrap, these can be the same methods from backend.py
 tiny_backend = {
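For context on the mechanism this file is built around: TorchDispatchMode lets Python intercept every aten op below the torch.Tensor layer, which is how calls end up routed through a table like tiny_backend. A standalone sketch of that interception pattern (LogMode is illustrative, not from the repo):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogMode(TorchDispatchMode):
  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    print(f"dispatch: {func}")  # a real backend would look func up in a table like tiny_backend
    return func(*args, **(kwargs or {}))  # fall through to the stock implementation

with LogMode():
  torch.ones(2, 2) + torch.ones(2, 2)  # prints each aten op as it dispatches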

View File: tinygrad/dtype.py

@@ -147,6 +147,7 @@ class dtypes:
   uints = (uint8, uint16, uint32, uint64)
   sints = (int8, int16, int32, int64)
   ints = uints + sints
+  all = floats + ints + (bool,)
 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
   dtypes.default_float = getattr(dtypes, env_default_float.lower())
@@ -197,3 +198,22 @@ truncate: dict[DType, Callable] = {dtypes.bool: bool,
   dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
   dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
   dtypes.int64: lambda x: ctypes.c_int64(x).value}
+# numpy and torch dtype interop
+def _to_np_dtype(dtype:DType) -> Optional[type]:
+  import numpy as np
+  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
+def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
+  import numpy as np
+  return dtypes.fields()[np.dtype(npdtype).name]
+@functools.lru_cache(None)
+def _to_torch_dtype(dtype:DType) -> Optional['torch.dtype']: # type: ignore [name-defined] # noqa: F821
+  import numpy as np, torch
+  # NOTE: torch doesn't expose this mapping with a stable API
+  if (np_dtype:=_to_np_dtype(dtype)) is None: return None  # guard: np.array with dtype=None silently infers float64 (e.g. for bfloat16)
+  try: return torch.from_numpy(np.array([], dtype=np_dtype)).dtype
+  except TypeError: return None
+@functools.lru_cache(None)
+def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
+  return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
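A quick round trip through the new helpers as a sanity check (the bfloat16 line assumes the None-guard shown above, since bfloat16 has no numpy format char):

import numpy as np, torch
from tinygrad import dtypes
from tinygrad.dtype import _to_np_dtype, _from_np_dtype, _to_torch_dtype, _from_torch_dtype

assert _to_np_dtype(dtypes.int8) is np.int8
assert _from_np_dtype(np.dtype("float32")) == dtypes.float32
assert _to_torch_dtype(dtypes.float32) == torch.float32
assert _from_torch_dtype(torch.int64) == dtypes.int64
assert _to_torch_dtype(dtypes.bfloat16) is None  # no numpy equivalent, so no derived torch dtype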

View File: tinygrad/tensor.py

@@ -4,6 +4,7 @@ import time, math, itertools, functools, struct, sys, inspect, pathlib, string,
 from contextlib import ContextDecorator
 from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
+from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
 from tinygrad.engine.multi import get_multi_map
@@ -49,12 +50,6 @@ def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str,
   if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
   return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
-def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
-  import numpy as np
-  return dtypes.fields()[np.dtype(npdtype).name]
-def _to_np_dtype(dtype:DType) -> Optional[type]:
-  import numpy as np
-  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
 def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
   ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
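Since these helpers only moved, Tensor's numpy bridge behaves exactly as before; a small check (run against a tree with this commit):

import numpy as np
from tinygrad import Tensor, dtypes

t = Tensor(np.arange(6, dtype=np.int32).reshape(2, 3))  # _from_np_dtype: np.int32 -> dtypes.int32
assert t.dtype == dtypes.int32
assert t.numpy().dtype == np.dtype(np.int32)  # _to_np_dtype maps back on the way out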