torch and numpy dtype interop [pr] (#9224)

* torch and numpy dtype interop [pr]

* less lines

* order
This commit is contained in:
George Hotz
2025-02-24 18:26:49 +08:00
committed by GitHub
parent 24615db5f5
commit fc32ff80d6
4 changed files with 27 additions and 23 deletions

View File

@@ -3,24 +3,13 @@ from tinygrad.helpers import DEBUG, getenv, prod
TORCH_DEBUG = getenv("TORCH_DEBUG")
import torch, pathlib
torch.autograd.grad_mode.set_multithreading_enabled(False)
from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
# https://pytorch.org/docs/stable/torch.compiler_ir.html
# TODO: don't replicate this in cpp
torch_to_tiny_dtype = {
torch.float32: dtypes.float32,
torch.float64: dtypes.float64,
torch.uint8: dtypes.uint8,
torch.int8: dtypes.int8,
torch.int32: dtypes.int32,
torch.int64: dtypes.int64,
torch.bool: dtypes.bool,
}
tiny_to_torch_dtype = {v: k for k, v in torch_to_tiny_dtype.items()}
import torch.utils.cpp_extension
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, tiny_to_torch_dtype[x.dtype])
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype))
def unwrap(x:torch.Tensor) -> Tensor:
assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
return mod.unwrap(x)
@@ -66,13 +55,13 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype))
return wrap(ret)
@torch.library.impl("aten::empty.memory_format", "privateuseone")
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype or torch.get_default_dtype()])
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()))
return wrap(ret)
@torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")

View File

@@ -5,10 +5,10 @@
from tinygrad import Tensor
import torch, contextlib
from torch.utils._python_dispatch import TorchDispatchMode
from extra.torch_backend.backend import torch_to_tiny_dtype
from tinygrad.dtype import _from_torch_dtype
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
return TTensor(Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype]))
return TTensor(Tensor.empty(*size, dtype=_from_torch_dtype(dtype)))
# NOTE: if we have a way to change wrap/unwrap, these can be the same methods from backend.py
tiny_backend = {