mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-08 13:45:50 -05:00
torch and numpy dtype interop [pr] (#9224)
* torch and numpy dtype interop [pr] * less lines * order
This commit is contained in:
@@ -147,6 +147,7 @@ class dtypes:
|
||||
uints = (uint8, uint16, uint32, uint64)
|
||||
sints = (int8, int16, int32, int64)
|
||||
ints = uints + sints
|
||||
all = floats + ints + (bool,)
|
||||
|
||||
if (env_default_float := getenv("DEFAULT_FLOAT", "")):
|
||||
dtypes.default_float = getattr(dtypes, env_default_float.lower())
|
||||
@@ -197,3 +198,22 @@ truncate: dict[DType, Callable] = {dtypes.bool: bool,
|
||||
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
|
||||
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
|
||||
dtypes.int64: lambda x: ctypes.c_int64(x).value}
|
||||
|
||||
# numpy and torch dtype interop
|
||||
|
||||
def _to_np_dtype(dtype:DType) -> Optional[type]:
|
||||
import numpy as np
|
||||
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
|
||||
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
|
||||
import numpy as np
|
||||
return dtypes.fields()[np.dtype(npdtype).name]
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _to_torch_dtype(dtype:DType) -> Optional['torch.dtype']: # type: ignore [name-defined] # noqa: F821
|
||||
import numpy as np, torch
|
||||
# NOTE: torch doesn't expose this mapping with a stable API
|
||||
try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
|
||||
except TypeError: return None
|
||||
@functools.lru_cache(None)
|
||||
def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
|
||||
return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
|
||||
Reference in New Issue
Block a user