From 5eee974b2a9cd135655b763a002e187ed0e6e427 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Fri, 14 Jun 2024 11:36:05 -0400
Subject: [PATCH] construct Tensor from python list/tuple directly (#4947)

* construct Tensor from python list/tuple directly

no numpy. annoying that half memoryview is a 3.12 feature...

* simpler, and test

* flat already

* simpler

* cute

* 10% faster

* 5%
---
 test/test_schedule.py     |  2 +-
 test/unit/test_helpers.py | 15 ++++++++++++++-
 tinygrad/dtype.py         | 10 +++++-----
 tinygrad/helpers.py       |  6 ++++++
 tinygrad/tensor.py        | 16 ++++++++++++----
 5 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/test/test_schedule.py b/test/test_schedule.py
index 35de8c127b..a01d0e816f 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -445,7 +445,7 @@ class TestSchedule(unittest.TestCase):
 
   def test_double_from(self):
     x = Tensor([1,2,3,4])
-    out = x.to('npy')
+    out = x.to('python')
     check_schedule(out, 0, filter_loadops=False)
 
   def test_pow_const_tensor_simplified(self):
diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py
index d529c3b640..8e9a60f34a 100644
--- a/test/unit/test_helpers.py
+++ b/test/unit/test_helpers.py
@@ -1,6 +1,7 @@
 import unittest
 from PIL import Image
-from tinygrad.helpers import Context, ContextVar, merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction
+from tinygrad.helpers import Context, ContextVar
+from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
 from tinygrad.shape.symbolic import Variable, NumNode
 
 VARIABLE = ContextVar("VARIABLE", 0)
@@ -240,5 +241,17 @@ class TestGetContraction(unittest.TestCase):
     r = get_contraction((1,1,1,1), (1,1,1,1))
     self.assertEqual(r, [[], [], [], [0,1,2,3]])
 
+class TestGetShape(unittest.TestCase):
+  def test_get_shape(self):
+    assert get_shape(2) == ()
+    assert get_shape([]) == (0,)
+    assert get_shape([[]]) == (1, 0)
+    assert get_shape([[1, 2]]) == (1, 2)
+    assert get_shape([[1, 2], (3, 4)]) == (2, 2)
+
+  def test_inhomogeneous_shape(self):
+    with self.assertRaises(ValueError): get_shape([[], [1]])
+    with self.assertRaises(ValueError): get_shape([[1, [2]], [1]])
+
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 20deb07102..8287cb9278 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -50,11 +50,11 @@ class dtypes:
   def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
   @staticmethod
   def from_py(x) -> DType:
-    if isinstance(x, (list, tuple)): return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
-    if isinstance(x, float): return dtypes.default_float
-    # NOTE: isinstance(True, int) is True in python, so check bool before int
-    if isinstance(x, bool): return dtypes.bool
-    if isinstance(x, int): return dtypes.default_int
+    if x.__class__ is float: return dtypes.default_float
+    if x.__class__ is int: return dtypes.default_int
+    if x.__class__ is bool: return dtypes.bool
+    # checking list/tuple last is faster because scalars are far more common than lists/tuples
+    if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
     raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
   def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
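The rewritten dtypes.from_py replaces isinstance with exact x.__class__ is checks. Since True.__class__ is bool rather than int, the old "check bool before int" ordering note is no longer needed; the branches are now ordered purely by how common each input type is. Below is a minimal sketch of the expected inference behavior, assuming the stock default_int/default_float settings; it is illustrative only, not part of the patch:

from tinygrad.dtype import dtypes

assert dtypes.from_py(True) == dtypes.bool               # exact class check: bool never matches the int branch
assert dtypes.from_py(3) == dtypes.default_int
assert dtypes.from_py(3.0) == dtypes.default_float
assert dtypes.from_py([1, 2, 3]) == dtypes.default_int
assert dtypes.from_py([1, 2.0]) == dtypes.default_float  # max() promotes a mixed list to the higher-priority dtype
assert dtypes.from_py([]) == dtypes.default_float        # empty containers default to float
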
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 0a7a199e13..6c5693ebcc 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -55,6 +55,12 @@ def get_child(obj, key):
     else: obj = getattr(obj, k)
   return obj
 
+def get_shape(x) -> Tuple[int, ...]:
+  if not isinstance(x, (list, tuple)): return ()
+  subs = [get_shape(xi) for xi in x]
+  if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
+  return (len(subs),) + (subs[0] if subs else ())
+
 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
 def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
   acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 9e69aef254..621bb24e9a 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1,13 +1,13 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools
+import time, math, itertools, functools, struct
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
 import numpy as np
 
 from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
-from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv
+from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten
 from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
 from tinygrad.lazy import LazyBuffer
 from tinygrad.multi import MultiLazyBuffer
@@ -50,6 +50,14 @@ def _fromcpu(x: np.ndarray) -> LazyBuffer:
   del ret.srcs
   return ret
 
+def _frompy(x:Union[List, Tuple], dtype:DType) -> LazyBuffer:
+  ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
+  # fake realize
+  assert dtype.fmt is not None, f"{dtype=} has None fmt"
+  ret.buffer.allocate(memoryview(struct.pack(f"@{ret.size}{dtype.fmt}", *fully_flatten(x))))
+  del ret.srcs
+  return ret
+
 def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
   return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
     for k in range(len(mat[0]))] for dim in range(dims)]
@@ -110,8 +118,8 @@ class Tensor:
     elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
     elif isinstance(data, (list, tuple)):
       if dtype is None: dtype = dtypes.from_py(data)
-      if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
-      else: data = _fromcpu(np.array(data, dtype.np))
+      if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
+      else: data = _frompy(data, dtype)
     elif isinstance(data, np.ndarray):
       if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
       else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
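The new _frompy path builds the LazyBuffer straight from Python data: get_shape sizes it, fully_flatten linearizes the nested list, and a single struct.pack call with the dtype's format character (dtype.fmt) fills the buffer, skipping the numpy round-trip entirely. (The commit note about half appears to refer to memoryview only handling the float16 'e' format from CPython 3.12 on, while struct has packed 'e' for much longer.) Here is a rough standalone sketch of that packing step; pack_nested and its inner flatten are hypothetical stand-ins for fully_flatten plus dtype.fmt, written only for illustration:

import struct

def pack_nested(x, fmt: str) -> memoryview:
  # flatten nested lists/tuples depth-first, then pack everything with one native-order struct call
  def flatten(v): return [i for sub in v for i in flatten(sub)] if isinstance(v, (list, tuple)) else [v]
  vals = flatten(x)
  return memoryview(struct.pack(f"@{len(vals)}{fmt}", *vals))

mv = pack_nested([[1, 2], [3, 4]], "i")  # "i" is the int32 format character, what dtypes.int32.fmt would supply
assert mv.nbytes == 4 * 4                # four int32 values, 16 bytes

With this in place, Tensor([[1, 2], [3, 4]]) realizes its data as struct-packed bytes on the PYTHON device, and the bfloat16 case packs float32 first and casts afterwards, since struct has no bfloat16 format.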