no numpy (#6751)
.github/workflows/test.yml (vendored, 9 additions)
@@ -140,6 +140,15 @@ jobs:
         python -c "from tinygrad.tensor import Tensor; print(Tensor([1,2,3,4,5]))"
         pip install mypy
         mypy -c "from tinygrad.tensor import Tensor; print(Tensor([1,2,3,4,5]))"
+    - name: Run beautiful_mnist without numpy
+      run: |
+        mkdir $HOME/test_no_numpy_dir
+        cd $HOME/test_no_numpy_dir
+        python -m venv venv
+        source venv/bin/activate
+        pip install $GITHUB_WORKSPACE
+        cp $GITHUB_WORKSPACE/examples/beautiful_mnist.py .
+        PYTHONPATH=$GITHUB_WORKSPACE BS=2 STEPS=10 python beautiful_mnist.py
     - name: Test DEBUG
       run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
     - name: Repo line count <9800 lines
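The new CI step builds a fresh virtualenv, installs tinygrad into it, and runs examples/beautiful_mnist.py there. Since the venv never installs numpy (and, per the setup.py change below, tinygrad no longer pulls it in), the step only passes if the pure-Python data path works end to end. A minimal probe along the same lines, which one could run inside that venv (the assertion message is mine, not part of the workflow):

```python
# sketch of what the CI step implicitly asserts: tinygrad imports and builds
# tensors in an environment where numpy is not installed at all
import importlib.util
assert importlib.util.find_spec("numpy") is None, "expected a numpy-free venv"
from tinygrad import Tensor
print(Tensor([1, 2, 3, 4, 5]).tolist())  # pure-Python readback, no .numpy()
```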
setup.py (2 changes)
@@ -21,7 +21,7 @@ setup(name='tinygrad',
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License"
       ],
-      install_requires=["numpy"],
+      install_requires=[],
       python_requires='>=3.8',
       extras_require={
        'llvm': ["llvmlite"],
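With install_requires emptied, `pip install tinygrad` no longer drags numpy in; numpy becomes a purely optional runtime dependency. Downstream code that wants `.numpy()` can guard for it, roughly like this (the `HAS_NUMPY` flag is my own naming, not tinygrad API):

```python
# hedged sketch: degrade gracefully when numpy is absent
try:
  import numpy as np  # optional after this commit
  HAS_NUMPY = True
except ImportError:
  HAS_NUMPY = False

from tinygrad import Tensor
t = Tensor.arange(4)
print(t.numpy() if HAS_NUMPY else t.tolist())  # ndarray when available, list otherwise
```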
test/test_dtype.py

@@ -578,7 +578,7 @@ class TestAutoCastType(unittest.TestCase):
   def tearDown(self):
     dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float

-  @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
+  @given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
   def test_int_to_float_unary_func(self, dtype):
     for func in [
       lambda t: t.exp(),
test/test_schedule.py

@@ -477,7 +477,7 @@ class TestSchedule(unittest.TestCase):

   def test_double_from(self):
     x = Tensor([1,2,3,4])
-    out = x.to('npy')
+    out = x.to('python')
     check_schedule(out, 0, filter_sink=False)

   def test_pow_const_tensor_simplified(self):
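The test now targets the pure-Python backend instead of the old numpy-backed 'npy' device; the point of test_double_from is that a plain device move schedules zero compute kernels. A rough standalone version of the same move (the device string is assumed from the rename in this diff):

```python
# sketch: move a small tensor to the pure-Python backend and read it back
from tinygrad import Tensor

t = Tensor([1, 2, 3, 4]).to("PYTHON")  # 'python'/'PYTHON' per the rename above
print(t.tolist())                      # [1, 2, 3, 4]
```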
test/unit/test_helpers.py

@@ -3,6 +3,7 @@ from PIL import Image
 from tinygrad.helpers import Context, ContextVar
 from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
 from tinygrad.shape.symbolic import Variable, NumNode
+import numpy as np

 VARIABLE = ContextVar("VARIABLE", 0)

@@ -188,6 +189,12 @@ class TestFullyFlatten(unittest.TestCase):
     self.assertEqual(fully_flatten([[1, 2, [3, 4]], [5, 6], 7]), [1, 2, 3, 4, 5, 6, 7])
     self.assertEqual(fully_flatten([[1, "ab"], [True, None], [3.14, [5, "b"]]]), [1, "ab", True, None, 3.14, 5, "b"])

+  def test_fully_flatten_numpy(self):
+    self.assertEqual(fully_flatten([np.array([1, 3]), np.array([1, 2])]), [1, 3, 1, 2])
+    self.assertEqual(fully_flatten((np.array([1, 3]), np.array([1, 2]))), [1, 3, 1, 2])
+    self.assertEqual(fully_flatten([np.array([[1], [3]]), np.array([[1], [2]])]), [1, 3, 1, 2])
+    self.assertEqual(fully_flatten([[1, "ab"], [True, None], np.array([[3.14], [6.28]])]), [1, "ab", True, None, 3.14, 6.28])
+
 class TestMemoryview(unittest.TestCase):
   def test_from_mv_to_mv(self):
     base = memoryview(bytearray(b"\x11\x22\x33"*40))
tinygrad/helpers.py

@@ -32,7 +32,14 @@ def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
 def ansilen(s:str): return len(ansistrip(s))
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
 def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
-def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
+def fully_flatten(l):
+  if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
+    flattened = []
+    if hasattr(l, "shape") and l.shape == (): flattened.append(l[()])
+    else:
+      for i in range(len(l)): flattened.extend(fully_flatten(l[i]))
+    return flattened
+  return [l]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
 def round_up(num, amt:int): return (num+amt-1)//amt * amt
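The rewritten fully_flatten trades the isinstance check on tuple/list for duck typing: anything with `__len__` and `__getitem__` (except str, whose indexing just yields more strings) is recursed into, and a 0-d array-like with shape () is unwrapped via `l[()]`. That is what lets numpy arrays flatten without helpers.py ever importing numpy. A self-contained illustration, with a hypothetical Scalar0D standing in for a 0-d ndarray:

```python
# Scalar0D mimics a 0-d numpy array: it has a __len__ attribute (which raises,
# as numpy's does) and supports empty-tuple indexing, like np.array(42)[()]
class Scalar0D:
  shape = ()
  def __len__(self): raise TypeError("len() of unsized object")
  def __getitem__(self, idx): return 42

def fully_flatten(l):  # same logic as the diff above, condensed
  if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
    if hasattr(l, "shape") and l.shape == (): return [l[()]]
    return [x for i in range(len(l)) for x in fully_flatten(l[i])]
  return [l]

print(fully_flatten([[1, [2, 3]], "ab", Scalar0D()]))  # [1, 2, 3, 'ab', 42]
```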
@@ -62,10 +69,12 @@ def get_child(obj, key):
   return obj

 def get_shape(x) -> Tuple[int, ...]:
-  if not isinstance(x, (list, tuple)): return ()
+  if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str): return ()
+  if (aapi := (hasattr(x, "shape") and x.shape == ())): return ()
   subs = [get_shape(xi) for xi in x]
   if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
-  return (len(subs),) + (subs[0] if subs else ())
+  slen = 1 if aapi else len(subs)
+  return (slen,) + (subs[0] if subs else ())

 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
 def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
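get_shape gets the same duck-typed treatment, so nested Python lists (and now array-likes) are measured without numpy, and ragged nesting is rejected. For instance:

```python
# get_shape walks nested sequences and returns the extent of each axis
from tinygrad.helpers import get_shape

print(get_shape([[1, 2, 3], [4, 5, 6]]))  # (2, 3)
print(get_shape(7))                       # () -- a scalar has no __len__
# get_shape([[1, 2], [3]]) raises ValueError: inhomogeneous shape
```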
tinygrad/nn/state.py

@@ -16,7 +16,7 @@ def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
   """
   t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
   json_len = t[0:8].bitcast(dtypes.int64).item()
-  return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
+  return t, json_len, json.loads(t[8:8+json_len].data().tobytes())

 def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
   """
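Here the numpy round-trip simply disappears: Tensor.data() hands back a memoryview over the realized bytes, and memoryview.tobytes() feeds json.loads directly. The same readback in miniature (the dtype and values are my own choice):

```python
# Tensor.data() yields a memoryview, so raw bytes come out without numpy
from tinygrad import Tensor, dtypes

t = Tensor([104, 105], dtype=dtypes.uint8)
print(t.data().tobytes())  # b'hi'
```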
@@ -4,7 +4,6 @@ import time, math, itertools, functools, struct, sys, inspect, pathlib, string,
|
||||
from contextlib import ContextDecorator
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
|
||||
@@ -44,10 +43,14 @@ def _metaop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
   if isinstance(device, str): return LazyBuffer.metaop(op, shape, dtype, device, arg, src)
   return MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, d, arg, src) for d in device], None)

-def _from_np_dtype(npdtype:np.dtype) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
-def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
+def _from_np_dtype(npdtype:'np.dtype') -> DType:  # type: ignore [name-defined] # noqa: F821
+  import numpy as np
+  return dtypes.fields()[np.dtype(npdtype).name]
+def _to_np_dtype(dtype:DType) -> Optional[type]:
+  import numpy as np
+  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None

-def _fromnp(x: np.ndarray) -> LazyBuffer:
+def _fromnp(x: 'np.ndarray') -> LazyBuffer:  # type: ignore [name-defined] # noqa: F821
   ret = LazyBuffer.metaop(MetaOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
   # fake realize
   ret.buffer.allocate(x)
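The conversion helpers keep numpy types in their signatures by quoting them: a string annotation is never evaluated at import time, and the `import numpy as np` inside the body only runs if the function is actually called. The same pattern in isolation (the function name is illustrative, not tinygrad API):

```python
# deferred numpy: the annotation is a string, the import lives in the body
def _to_list(x: 'np.ndarray') -> list:  # noqa: F821 -- np only exists at call time
  import numpy as np  # cached in sys.modules after the first call, so cheap
  return np.asarray(x).tolist()
```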
@@ -62,7 +65,7 @@ def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
   truncate_function = truncate[dtype]
   data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
   # fake realize
-  ret.buffer.allocate(memoryview(data))
+  ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
   del ret.srcs
   return ret
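_frompy is the numpy-free ingestion path: fully_flatten linearizes the nested lists and struct.pack writes them straight into native-endian bytes (the PYTHON backend gets a mutable bytearray, presumably because it writes into its buffers; that reading of the condition is mine). The core move, reduced to a few lines:

```python
# packing a nested Python list into raw int32 bytes, numpy-free
import struct

rows = [[1, 2], [3, 4]]
flat = [x for row in rows for x in row]
buf = struct.pack(f"@{len(flat)}i", *flat)  # '@' = native order, 'i' = int32
print(memoryview(buf).cast("i").tolist())   # [1, 2, 3, 4]
```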
@@ -106,7 +109,7 @@ class Tensor:
   training: ClassVar[bool] = False
   no_grad: ClassVar[bool] = False

-  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable, pathlib.Path],
+  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, 'np.ndarray', bytes, MultiLazyBuffer, Variable, pathlib.Path],  # type: ignore [name-defined] # noqa: F821
               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
     if dtype is not None: dtype = to_dtype(dtype)
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
@@ -132,12 +135,14 @@ class Tensor:
     if dtype is None:
       if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
       else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
-      if dtype == dtypes.bfloat16: data = Tensor(_fromnp(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
-      else: data = _fromnp(np.array(data).astype(_to_np_dtype(dtype)))
+      if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
+      else: data = _frompy(data, dtype)
     elif data is None: data = _metaop(MetaOps.EMPTY, (0,), dtype or dtypes.default_float, device)
-    elif isinstance(data, np.ndarray):
+    elif str(type(data)) == "<class 'numpy.ndarray'>":
+      import numpy as np
+      assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
       if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
-      else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
+      else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)  # type: ignore [name-defined]
     elif isinstance(data, pathlib.Path):
       dtype = dtype or dtypes.uint8
       data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
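The elif is the subtle part of this hunk: `isinstance(data, np.ndarray)` would force numpy to be imported just to ask the question, so the constructor compares the type's string repr instead and only imports numpy once it knows an ndarray is actually present. Stand-alone:

```python
# recognizing an ndarray without importing numpy: if numpy was never loaded,
# nothing in the process can have this exact type repr
def looks_like_ndarray(x) -> bool:  # helper name is mine, not tinygrad API
  return str(type(x)) == "<class 'numpy.ndarray'>"

print(looks_like_ndarray([1, 2, 3]))  # False -- and numpy stayed unimported
```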
@@ -295,7 +300,7 @@ class Tensor:
     """
     return self.data().tolist()

-  def numpy(self) -> np.ndarray:
+  def numpy(self) -> 'np.ndarray':  # type: ignore [name-defined] # noqa: F821
     """
     Returns the value of this tensor as a `numpy.ndarray`.
@@ -304,6 +309,7 @@ class Tensor:
     print(repr(t.numpy()))
     ```
     """
+    import numpy as np
     if self.dtype == dtypes.bfloat16: return self.float().numpy()
     assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
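Taken together, numpy is now only touched when the caller asks for it: .numpy() imports it on demand, while construction, compute, and readback all run on the pure-Python path. So after this commit, something like the following works in an environment with no numpy installed at all:

```python
# end-to-end without numpy: build, compute, read back as plain Python lists
from tinygrad import Tensor

t = (Tensor([[1, 2], [3, 4]]) + 1) * 2
print(t.tolist())  # [[4, 6], [8, 10]] -- via .data(), never via .numpy()
```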