Allow Tensor(tuple) (#911)

This commit is contained in:
Alexey Zaytsev
2023-06-04 13:48:19 +07:00
committed by GitHub
parent afd0be8a9c
commit d429553730

View File

@@ -34,10 +34,10 @@ class Tensor:
no_grad: ClassVar[bool] = False
default_type: ClassVar[DType] = dtypes.float32
def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
def __init__(self, data:Union[int, float, list, tuple, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
device = Device.canonicalize(device)
if isinstance(data, list):
if isinstance(data, (list, tuple)):
data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
if isinstance(data, LazyBuffer):