mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
[ready] perf: simpler Tensor init (#1679)
Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
@@ -38,7 +38,6 @@ class Tensor:
|
||||
training: ClassVar[bool] = False
|
||||
no_grad: ClassVar[bool] = False
|
||||
default_type: ClassVar[DType] = dtypes.float32
|
||||
|
||||
def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
|
||||
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
||||
device = Device.canonicalize(device)
|
||||
@@ -51,25 +50,18 @@ class Tensor:
|
||||
|
||||
# internal variables used for autograd graph construction
|
||||
self._ctx: Optional[Function] = None
|
||||
if isinstance(data, LazyBuffer):
|
||||
assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
|
||||
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
|
||||
return
|
||||
|
||||
if isinstance(data, (int, float)):
|
||||
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
|
||||
elif isinstance(data, (int, float)):
|
||||
self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
|
||||
return
|
||||
|
||||
if data.__class__ is list:
|
||||
elif data.__class__ is list:
|
||||
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
|
||||
data = np.array(data, dtype=(dtype or Tensor.default_type).np)
|
||||
|
||||
if isinstance(data, np.ndarray):
|
||||
data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
|
||||
elif isinstance(data, np.ndarray):
|
||||
data = LazyBuffer.fromCPU(data)
|
||||
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
|
||||
return
|
||||
else: raise RuntimeError(f"can't create Tensor from {data}")
|
||||
|
||||
raise RuntimeError(f"can't create Tensor from {data}")
|
||||
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
|
||||
|
||||
Reference in New Issue
Block a user