construct Tensor from python list/tuple directly (#4947)
* construct Tensor from python list/tuple directly, no numpy. annoying that half memoryview is 3.12 feature...
* simpler, and test
* flat already
* simpler
* cute
* 10% faster
* 5%
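At the API surface nothing changes; only the construction path for python lists/tuples stops going through np.array. A quick usage sketch of the behavior (assuming the post-commit tree, and that Tensor.tolist is available for readback, as it was around this time):

from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]])  # shape and dtype inferred in pure python, no numpy in the path
print(t.shape, t.dtype)       # (2, 2) and the default int dtype
print(t.tolist())             # [[1, 2], [3, 4]]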
test/test_schedule.py
@@ -445,7 +445,7 @@ class TestSchedule(unittest.TestCase):
   def test_double_from(self):
     x = Tensor([1,2,3,4])
-    out = x.to('npy')
+    out = x.to('python')
     check_schedule(out, 0, filter_loadops=False)

   def test_pow_const_tensor_simplified(self):
test/unit/test_helpers.py
@@ -1,6 +1,7 @@
 import unittest
 from PIL import Image
-from tinygrad.helpers import Context, ContextVar, merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction
+from tinygrad.helpers import Context, ContextVar
+from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
 from tinygrad.shape.symbolic import Variable, NumNode

 VARIABLE = ContextVar("VARIABLE", 0)
@@ -240,5 +241,17 @@ class TestGetContraction(unittest.TestCase):
     r = get_contraction((1,1,1,1), (1,1,1,1))
     self.assertEqual(r, [[], [], [], [0,1,2,3]])

+class TestGetShape(unittest.TestCase):
+  def test_get_shape(self):
+    assert get_shape(2) == ()
+    assert get_shape([]) == (0,)
+    assert get_shape([[]]) == (1, 0)
+    assert get_shape([[1, 2]]) == (1, 2)
+    assert get_shape([[1, 2], (3, 4)]) == (2, 2)
+
+  def test_inhomogeneous_shape(self):
+    with self.assertRaises(ValueError): get_shape([[], [1]])
+    with self.assertRaises(ValueError): get_shape([[1, [2]], [1]])
+
 if __name__ == '__main__':
   unittest.main()
tinygrad/dtype.py
@@ -50,11 +50,11 @@ class dtypes:
   def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
   @staticmethod
   def from_py(x) -> DType:
-    if isinstance(x, (list, tuple)): return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
-    if isinstance(x, float): return dtypes.default_float
-    # NOTE: isinstance(True, int) is True in python, so check bool before int
-    if isinstance(x, bool): return dtypes.bool
-    if isinstance(x, int): return dtypes.default_int
+    if x.__class__ is float: return dtypes.default_float
+    if x.__class__ is int: return dtypes.default_int
+    if x.__class__ is bool: return dtypes.bool
+    # put this in the last is faster because there are more items than lists/tuples to check
+    if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
     raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
   def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
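Promotion over nested data works by taking max() over the element dtypes, which relies on DType ordering by priority (bool below int below float). A sketch of the expected results, assuming default dtypes; the empty-container fallback to default_float is right there in the code above:

from tinygrad.dtype import dtypes

assert dtypes.from_py(True) == dtypes.bool
assert dtypes.from_py([True, 2]) == dtypes.default_int    # bool promotes up to int via max()
assert dtypes.from_py([1, 2.0]) == dtypes.default_float   # int promotes up to float
assert dtypes.from_py([]) == dtypes.default_float         # empty list falls back to default_float

Note the `x.__class__ is bool` exact-type check also makes the removed bool-before-int NOTE unnecessary: True.__class__ is bool, not int, so the order of the scalar checks no longer matters.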
tinygrad/helpers.py
@@ -55,6 +55,12 @@ def get_child(obj, key):
   else: obj = getattr(obj, k)
   return obj

+def get_shape(x) -> Tuple[int, ...]:
+  if not isinstance(x, (list, tuple)): return ()
+  subs = [get_shape(xi) for xi in x]
+  if not all_same([sub for sub in subs]): raise ValueError(f"inhomogeneous shape from {x}")
+  return (len(subs),) + (subs[0] if subs else ())
+
 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
 def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
   acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
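Tracing the recursion in get_shape: scalars contribute (), each list level prepends its own length, and all_same rejects ragged siblings. For example:

from tinygrad.helpers import get_shape

# subs for [[1, 2], [3, 4]] is [(2,), (2,)] -> all same, so (len(subs),) + subs[0] == (2, 2)
assert get_shape([[1, 2], [3, 4]]) == (2, 2)
# subs for [[1], [2, 3]] would be [(1,), (2,)] -> not all same, so ValueError is raised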
tinygrad/tensor.py
@@ -1,13 +1,13 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools
+import time, math, itertools, functools, struct
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
 import numpy as np

 from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
-from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv
+from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten
 from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
 from tinygrad.lazy import LazyBuffer
 from tinygrad.multi import MultiLazyBuffer
@@ -50,6 +50,14 @@ def _fromcpu(x: np.ndarray) -> LazyBuffer:
   del ret.srcs
   return ret

+def _frompy(x:Union[List, Tuple], dtype:DType) -> LazyBuffer:
+  ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
+  # fake realize
+  assert dtype.fmt is not None, f"{dtype=} has None fmt"
+  ret.buffer.allocate(memoryview(struct.pack(f"@{ret.size}{dtype.fmt}", *fully_flatten(x))))
+  del ret.srcs
+  return ret
+
 def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
   return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
           for k in range(len(mat[0]))] for dim in range(dims)]
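The "fake realize" writes the flattened python values straight into the buffer via struct. A stdlib-only illustration of the format string _frompy builds, assuming a (2, 2) float32 tensor, where ret.size is 4 and dtype.fmt is 'f':

import struct

flat = [1.0, 2.0, 3.0, 4.0]                  # what fully_flatten yields for [[1.0, 2.0], [3.0, 4.0]]
buf = memoryview(struct.pack("@4f", *flat))  # f"@{ret.size}{dtype.fmt}" -> "@4f", native byte order
assert buf.nbytes == 16                      # 4 elements x 4 bytes per float32

The commit message's half-precision complaint fits here: struct can pack 'e' (float16), but memoryview only gained support for the 'e' format in Python 3.12.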
@@ -110,8 +118,8 @@ class Tensor:
     elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
     elif isinstance(data, (list, tuple)):
       if dtype is None: dtype = dtypes.from_py(data)
-      if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
-      else: data = _fromcpu(np.array(data, dtype.np))
+      if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
+      else: data = _frompy(data, dtype)
     elif isinstance(data, np.ndarray):
       if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
       else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
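One wrinkle in the __init__ change: struct has no bfloat16 format character, so bfloat16 keeps a two-step path, packed as float32 by _frompy and then cast on device. A sketch, assuming a device with bfloat16 support:

from tinygrad import Tensor, dtypes

t = Tensor([1.0, 2.0], dtype=dtypes.bfloat16)  # bytes packed as float32, then .cast(dtypes.bfloat16)
assert t.dtype == dtypes.bfloat16              # dtype check is lazy, no realize needed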