[ready] perf: simpler Tensor init (#1679)

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
Roelof van Dijk
2023-08-28 04:18:03 +02:00
committed by GitHub
parent b66f54e379
commit 50f669e43b

View File

@@ -38,7 +38,6 @@ class Tensor:
training: ClassVar[bool] = False
no_grad: ClassVar[bool] = False
default_type: ClassVar[DType] = dtypes.float32
def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
device = Device.canonicalize(device)
@@ -51,25 +50,18 @@ class Tensor:
# internal variables used for autograd graph construction
self._ctx: Optional[Function] = None
if isinstance(data, LazyBuffer):
assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
return
if isinstance(data, (int, float)):
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, (int, float)):
self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
return
if data.__class__ is list:
elif data.__class__ is list:
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
data = np.array(data, dtype=(dtype or Tensor.default_type).np)
if isinstance(data, np.ndarray):
data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
elif isinstance(data, np.ndarray):
data = LazyBuffer.fromCPU(data)
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
return
else: raise RuntimeError(f"can't create Tensor from {data}")
raise RuntimeError(f"can't create Tensor from {data}")
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
def __repr__(self):
return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"