mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Allow Tensor(tuple) (#911)
This commit is contained in:
@@ -34,10 +34,10 @@ class Tensor:
|
||||
no_grad: ClassVar[bool] = False
|
||||
default_type: ClassVar[DType] = dtypes.float32
|
||||
|
||||
def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
|
||||
def __init__(self, data:Union[int, float, list, tuple, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
|
||||
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
||||
device = Device.canonicalize(device)
|
||||
if isinstance(data, list):
|
||||
if isinstance(data, (list, tuple)):
|
||||
data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
|
||||
|
||||
if isinstance(data, LazyBuffer):
|
||||
|
||||
Reference in New Issue
Block a user