mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
cleanup dtype of tensor creation from list (#4566)
This commit is contained in:
@@ -111,9 +111,9 @@ class Tensor:
|
||||
elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
|
||||
elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
|
||||
elif isinstance(data, list):
|
||||
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
|
||||
elif d and all_int(d): dtype = dtype or dtypes.default_int
|
||||
else: dtype = dtype or dtypes.default_float
|
||||
if dtype is None:
|
||||
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
|
||||
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
|
||||
if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
|
||||
else: data = _fromcpu(np.array(data, dtype.np))
|
||||
elif isinstance(data, np.ndarray):
|
||||
|
||||
Reference in New Issue
Block a user