torch and numpy dtype interop [pr] (#9224)

* torch and numpy dtype interop [pr]

* less lines

* order
This commit is contained in:
George Hotz
2025-02-24 18:26:49 +08:00
committed by GitHub
parent 24615db5f5
commit fc32ff80d6
4 changed files with 27 additions and 23 deletions

View File

@@ -147,6 +147,7 @@ class dtypes:
uints = (uint8, uint16, uint32, uint64)
sints = (int8, int16, int32, int64)
ints = uints + sints
all = floats + ints + (bool,)
if (env_default_float := getenv("DEFAULT_FLOAT", "")):
dtypes.default_float = getattr(dtypes, env_default_float.lower())
@@ -197,3 +198,22 @@ truncate: dict[DType, Callable] = {dtypes.bool: bool,
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
dtypes.int64: lambda x: ctypes.c_int64(x).value}
# numpy and torch dtype interop
def _to_np_dtype(dtype:DType) -> Optional[type]:
import numpy as np
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
import numpy as np
return dtypes.fields()[np.dtype(npdtype).name]
@functools.lru_cache(None)
def _to_torch_dtype(dtype:DType) -> Optional['torch.dtype']: # type: ignore [name-defined] # noqa: F821
import numpy as np, torch
# NOTE: torch doesn't expose this mapping with a stable API
try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
except TypeError: return None
@functools.lru_cache(None)
def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]