construct Tensor from python list/tuple directly (#4947)

* construct Tensor from python list/tuple directly

no numpy. annoying that half memoryview is 3.12 feature...

* simpler, and test

* flat already

* simpler

* cute

* 10% faster

* 5%
This commit is contained in:
chenyu
2024-06-14 11:36:05 -04:00
committed by GitHub
parent 90332eb529
commit 5eee974b2a
5 changed files with 38 additions and 11 deletions

View File

@@ -445,7 +445,7 @@ class TestSchedule(unittest.TestCase):
def test_double_from(self):
x = Tensor([1,2,3,4])
out = x.to('npy')
out = x.to('python')
check_schedule(out, 0, filter_loadops=False)
def test_pow_const_tensor_simplified(self):

View File

@@ -1,6 +1,7 @@
import unittest
from PIL import Image
from tinygrad.helpers import Context, ContextVar, merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction
from tinygrad.helpers import Context, ContextVar
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
from tinygrad.shape.symbolic import Variable, NumNode
VARIABLE = ContextVar("VARIABLE", 0)
@@ -240,5 +241,17 @@ class TestGetContraction(unittest.TestCase):
r = get_contraction((1,1,1,1), (1,1,1,1))
self.assertEqual(r, [[], [], [], [0,1,2,3]])
class TestGetShape(unittest.TestCase):
def test_get_shape(self):
assert get_shape(2) == ()
assert get_shape([]) == (0,)
assert get_shape([[]]) == (1, 0)
assert get_shape([[1, 2]]) == (1, 2)
assert get_shape([[1, 2], (3, 4)]) == (2, 2)
def test_inhomogeneous_shape(self):
with self.assertRaises(ValueError): get_shape([[], [1]])
with self.assertRaises(ValueError): get_shape([[1, [2]], [1]])
if __name__ == '__main__':
unittest.main()

View File

@@ -50,11 +50,11 @@ class dtypes:
def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod
def from_py(x) -> DType:
if isinstance(x, (list, tuple)): return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
if isinstance(x, float): return dtypes.default_float
# NOTE: isinstance(True, int) is True in python, so check bool before int
if isinstance(x, bool): return dtypes.bool
if isinstance(x, int): return dtypes.default_int
if x.__class__ is float: return dtypes.default_float
if x.__class__ is int: return dtypes.default_int
if x.__class__ is bool: return dtypes.bool
# put this in the last is faster because there are more items than lists/tuples to check
if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
@staticmethod
def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)

View File

@@ -55,6 +55,12 @@ def get_child(obj, key):
else: obj = getattr(obj, k)
return obj
def get_shape(x) -> Tuple[int, ...]:
if not isinstance(x, (list, tuple)): return ()
subs = [get_shape(xi) for xi in x]
if not all_same([sub for sub in subs]): raise ValueError(f"inhomogeneous shape from {x}")
return (len(subs),) + (subs[0] if subs else ())
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))

View File

@@ -1,13 +1,13 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools
import time, math, itertools, functools, struct
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
from collections import defaultdict
import numpy as np
from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten
from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
from tinygrad.lazy import LazyBuffer
from tinygrad.multi import MultiLazyBuffer
@@ -50,6 +50,14 @@ def _fromcpu(x: np.ndarray) -> LazyBuffer:
del ret.srcs
return ret
def _frompy(x:Union[List, Tuple], dtype:DType) -> LazyBuffer:
ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
# fake realize
assert dtype.fmt is not None, f"{dtype=} has None fmt"
ret.buffer.allocate(memoryview(struct.pack(f"@{ret.size}{dtype.fmt}", *fully_flatten(x))))
del ret.srcs
return ret
def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
for k in range(len(mat[0]))] for dim in range(dims)]
@@ -110,8 +118,8 @@ class Tensor:
elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
elif isinstance(data, (list, tuple)):
if dtype is None: dtype = dtypes.from_py(data)
if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
else: data = _fromcpu(np.array(data, dtype.np))
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
else: data = _frompy(data, dtype)
elif isinstance(data, np.ndarray):
if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)