mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-31 01:38:20 -05:00
Tensor._fromcpu -> Tensor._fromnp (#4966)
and moved to constructor with np.ndarray
This commit is contained in:
@@ -43,13 +43,6 @@ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str,
|
||||
if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
|
||||
return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
|
||||
|
||||
def _fromcpu(x: np.ndarray) -> LazyBuffer:
|
||||
ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "NPY")
|
||||
# fake realize
|
||||
ret.buffer.allocate(x)
|
||||
del ret.srcs
|
||||
return ret
|
||||
|
||||
def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
|
||||
if isinstance(x, bytes): ret, data = LazyBuffer.loadop(LoadOps.EMPTY, (len(x),), dtype, "PYTHON"), x
|
||||
else:
|
||||
@@ -123,10 +116,17 @@ class Tensor:
|
||||
if dtype is None: dtype = dtypes.from_py(data)
|
||||
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
|
||||
else: data = _frompy(data, dtype)
|
||||
elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
|
||||
elif isinstance(data, np.ndarray):
|
||||
if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
|
||||
else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
|
||||
elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
|
||||
else:
|
||||
def _fromnp(x: np.ndarray) -> LazyBuffer:
|
||||
ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "NPY")
|
||||
# fake realize
|
||||
ret.buffer.allocate(x)
|
||||
del ret.srcs
|
||||
return ret
|
||||
data = _fromnp(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
|
||||
|
||||
# by this point, it has to be a LazyBuffer
|
||||
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
|
||||
|
||||
Reference in New Issue
Block a user