mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
torch and numpy dtype interop [pr] (#9224)
* torch and numpy dtype interop [pr] * less lines * order
This commit is contained in:
@@ -3,24 +3,13 @@ from tinygrad.helpers import DEBUG, getenv, prod
|
||||
TORCH_DEBUG = getenv("TORCH_DEBUG")
|
||||
import torch, pathlib
|
||||
torch.autograd.grad_mode.set_multithreading_enabled(False)
|
||||
from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
|
||||
|
||||
# https://pytorch.org/docs/stable/torch.compiler_ir.html
|
||||
|
||||
# TODO: don't replicate this in cpp
|
||||
torch_to_tiny_dtype = {
|
||||
torch.float32: dtypes.float32,
|
||||
torch.float64: dtypes.float64,
|
||||
torch.uint8: dtypes.uint8,
|
||||
torch.int8: dtypes.int8,
|
||||
torch.int32: dtypes.int32,
|
||||
torch.int64: dtypes.int64,
|
||||
torch.bool: dtypes.bool,
|
||||
}
|
||||
tiny_to_torch_dtype = {v: k for k, v in torch_to_tiny_dtype.items()}
|
||||
|
||||
import torch.utils.cpp_extension
|
||||
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
|
||||
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, tiny_to_torch_dtype[x.dtype])
|
||||
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype))
|
||||
def unwrap(x:torch.Tensor) -> Tensor:
|
||||
assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
|
||||
return mod.unwrap(x)
|
||||
@@ -66,13 +55,13 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
|
||||
@torch.library.impl("aten::empty_strided", "privateuseone")
|
||||
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
|
||||
if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
|
||||
ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
|
||||
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype))
|
||||
return wrap(ret)
|
||||
|
||||
@torch.library.impl("aten::empty.memory_format", "privateuseone")
|
||||
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
|
||||
if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
|
||||
ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype or torch.get_default_dtype()])
|
||||
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()))
|
||||
return wrap(ret)
|
||||
|
||||
@torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
from tinygrad import Tensor
|
||||
import torch, contextlib
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from extra.torch_backend.backend import torch_to_tiny_dtype
|
||||
from tinygrad.dtype import _from_torch_dtype
|
||||
|
||||
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
|
||||
return TTensor(Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype]))
|
||||
return TTensor(Tensor.empty(*size, dtype=_from_torch_dtype(dtype)))
|
||||
|
||||
# NOTE: if we have a way to change wrap/unwrap, these can be the same methods from backend.py
|
||||
tiny_backend = {
|
||||
|
||||
@@ -147,6 +147,7 @@ class dtypes:
|
||||
uints = (uint8, uint16, uint32, uint64)
|
||||
sints = (int8, int16, int32, int64)
|
||||
ints = uints + sints
|
||||
all = floats + ints + (bool,)
|
||||
|
||||
if (env_default_float := getenv("DEFAULT_FLOAT", "")):
|
||||
dtypes.default_float = getattr(dtypes, env_default_float.lower())
|
||||
@@ -197,3 +198,22 @@ truncate: dict[DType, Callable] = {dtypes.bool: bool,
|
||||
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
|
||||
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
|
||||
dtypes.int64: lambda x: ctypes.c_int64(x).value}
|
||||
|
||||
# numpy and torch dtype interop
|
||||
|
||||
def _to_np_dtype(dtype:DType) -> Optional[type]:
|
||||
import numpy as np
|
||||
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
|
||||
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
|
||||
import numpy as np
|
||||
return dtypes.fields()[np.dtype(npdtype).name]
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _to_torch_dtype(dtype:DType) -> Optional['torch.dtype']: # type: ignore [name-defined] # noqa: F821
|
||||
import numpy as np, torch
|
||||
# NOTE: torch doesn't expose this mapping with a stable API
|
||||
try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
|
||||
except TypeError: return None
|
||||
@functools.lru_cache(None)
|
||||
def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
|
||||
return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
|
||||
@@ -4,6 +4,7 @@ import time, math, itertools, functools, struct, sys, inspect, pathlib, string,
|
||||
from contextlib import ContextDecorator
|
||||
from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
|
||||
from tinygrad.engine.multi import get_multi_map
|
||||
@@ -49,12 +50,6 @@ def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str,
|
||||
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
|
||||
return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
|
||||
|
||||
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
|
||||
import numpy as np
|
||||
return dtypes.fields()[np.dtype(npdtype).name]
|
||||
def _to_np_dtype(dtype:DType) -> Optional[type]:
|
||||
import numpy as np
|
||||
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
|
||||
|
||||
def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
|
||||
ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
|
||||
|
||||
Reference in New Issue
Block a user