torch and numpy dtype interop [pr] (#9224)
* torch and numpy dtype interop [pr]
* less lines
* order
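For context, this change drops the backend-local torch_to_tiny_dtype table and pulls the conversion helpers from tinygrad.dtype instead, so the torch and numpy interop share one mapping. A minimal sketch of what such helpers could look like, assuming a plain bidirectional lookup table; this is an illustration, not the code tinygrad actually ships:

# illustration only: plausible shape of the shared helpers, assuming a simple lookup table
import torch
from tinygrad import dtypes

_torch_table = {
  torch.float32: dtypes.float32, torch.float64: dtypes.float64,
  torch.uint8: dtypes.uint8, torch.int8: dtypes.int8,
  torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.bool: dtypes.bool,
}
def _from_torch_dtype(d): return _torch_table[d]  # torch dtype -> tinygrad dtype
def _to_torch_dtype(d): return {v: k for k, v in _torch_table.items()}[d]  # tinygrad dtype -> torch dtype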
@@ -3,24 +3,13 @@ from tinygrad.helpers import DEBUG, getenv, prod
 TORCH_DEBUG = getenv("TORCH_DEBUG")
 import torch, pathlib
 torch.autograd.grad_mode.set_multithreading_enabled(False)
+from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype

 # https://pytorch.org/docs/stable/torch.compiler_ir.html

-# TODO: don't replicate this in cpp
-torch_to_tiny_dtype = {
-  torch.float32: dtypes.float32,
-  torch.float64: dtypes.float64,
-  torch.uint8: dtypes.uint8,
-  torch.int8: dtypes.int8,
-  torch.int32: dtypes.int32,
-  torch.int64: dtypes.int64,
-  torch.bool: dtypes.bool,
-}
-tiny_to_torch_dtype = {v: k for k, v in torch_to_tiny_dtype.items()}
-
 import torch.utils.cpp_extension
 mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
-def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, tiny_to_torch_dtype[x.dtype])
+def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype))
 def unwrap(x:torch.Tensor) -> Tensor:
   assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
   return mod.unwrap(x)
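A hedged usage sketch of the wrap/unwrap pair after this change, assuming the wrapped_tensor.cpp extension compiles and mod.wrap/mod.unwrap behave as registered above:

# sketch: dtype is carried through the wrap via _to_torch_dtype
from tinygrad import Tensor, dtypes
import torch

t = Tensor.ones(2, 3, dtype=dtypes.int32)
torch_t = wrap(t)                    # torch.Tensor backed by the tinygrad Tensor
assert torch_t.dtype == torch.int32  # _to_torch_dtype(dtypes.int32) -> torch.int32
t2 = unwrap(torch_t)                 # back to a tinygrad Tensor
assert t2.dtype == dtypes.int32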
@@ -66,13 +55,13 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
 @torch.library.impl("aten::empty_strided", "privateuseone")
 def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
   if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
-  ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
+  ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype))
   return wrap(ret)

 @torch.library.impl("aten::empty.memory_format", "privateuseone")
 def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
   if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
-  ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype or torch.get_default_dtype()])
+  ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()))
   return wrap(ret)

 @torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
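The "dtype or torch.get_default_dtype()" fallback in empty.memory_format mirrors torch's own factory behavior: an explicit dtype wins, otherwise the global default applies. A small standalone illustration, assuming the default dtype has not been changed from torch.float32:

import torch

def pick_dtype(dtype=None):
  # same fallback as empty.memory_format above: explicit dtype wins, else the global default
  return dtype or torch.get_default_dtype()

assert pick_dtype() == torch.float32         # default session
assert pick_dtype(torch.int8) == torch.int8  # explicit dtype passed through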
@@ -5,10 +5,10 @@
 from tinygrad import Tensor
 import torch, contextlib
 from torch.utils._python_dispatch import TorchDispatchMode
-from extra.torch_backend.backend import torch_to_tiny_dtype
+from tinygrad.dtype import _from_torch_dtype

 def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
-  return TTensor(Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype]))
+  return TTensor(Tensor.empty(*size, dtype=_from_torch_dtype(dtype)))

 # NOTE: if we have a way to change wrap/unwrap, these can be the same methods from backend.py
 tiny_backend = {
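This second file goes through TorchDispatchMode rather than the C++ wrapper. As a rough sketch of that pattern (the actual TTensor and tiny_backend contents are not shown in this diff, and the handler table below is hypothetical), a dispatch mode intercepts aten calls and can route them to per-op handlers:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

# hypothetical handler table in the spirit of tiny_backend: aten op name -> handler
handlers = {"aten.add.Tensor": lambda a, b: a + b}

class InterceptMode(TorchDispatchMode):
  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    name = str(func)                     # e.g. "aten.add.Tensor"
    if name in handlers: return handlers[name](*args, **kwargs)
    return func(*args, **kwargs)         # anything unhandled falls through to torch

with InterceptMode():
  out = torch.ones(2) + torch.ones(2)    # the add is routed through __torch_dispatch__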