Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
remove numpy from device (#3123)
* remove numpy from device
* fix tests
* np item
* cleanups
* simplify with as_buffer
* no toCPU
* tinygradic
* cast to scalar
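The net effect of the change, sketched below under the assumption of a post-#3123 tinygrad checkout: Buffer no longer carries the numpy-backed fromCPU/toCPU helpers, so callers move raw bytes with copyin/copyout/as_buffer and only build numpy arrays at the edges.

import numpy as np
from tinygrad.device import Buffer, Device
from tinygrad.dtype import dtypes

# before this commit (removed in the hunks below):
#   buf = Buffer.fromCPU(Device.DEFAULT, np.array([5], dtype=np.int32))
#   val = buf.toCPU()[0]

# after this commit: Buffer only moves raw bytes through memoryviews
buf = Buffer(Device.DEFAULT, 1, dtypes.int32)
buf.copyin(np.array([5], dtype=np.int32).data)            # host -> device
val = np.frombuffer(buf.as_buffer(), dtype=np.int32)[0]   # device -> host bytes, numpy only at the edge
assert val == 5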
@@ -21,7 +21,13 @@ repos:
 pass_filenames: false
 - id: docs
 name: docs
-entry: python3 docs/abstractions.py && python3 docs/abstractions2.py
+entry: python3 docs/abstractions.py
 language: system
 always_run: true
 pass_filenames: false
+- id: docs2
+name: docs2
+entry: python3 docs/abstractions2.py
+language: system
+always_run: true
+pass_filenames: false
@@ -146,7 +146,8 @@ assert result.lazydata.base.realized is not None, "the LazyBuffer is realized!"
 assert isinstance(result.lazydata.base.realized, Buffer)
 assert result.lazydata.base.realized.device == "CLANG"
 # getting ahead of ourselves, but we can move the Buffer to CPU
-assert result.lazydata.base.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5"
+out = result.lazydata.base.realized.as_buffer().cast('I')
+assert out[0] == 5, "when put in numpy with toCPU, it's 5"

 # %%
 # == Union[Interpreted, Compiled] (in tinygrad/device.py, code 6/10) ==
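As an aside on the new assertion style: as_buffer() hands back raw bytes as a memoryview, and memoryview.cast('I') reinterprets them as unsigned 32-bit integers. A stdlib-only sketch of that idiom (no tinygrad needed):

# 'I' is the memoryview typecode for unsigned 32-bit int, matching the 4-byte
# value the docs example stores in the realized CLANG Buffer; cast() uses
# native byte order, so this assumes a little-endian host
raw = memoryview(bytearray((5).to_bytes(4, "little")))
assert raw.cast('I')[0] == 5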
@@ -67,8 +67,7 @@ print(fxn.prg)
 fxn.exec([out, a, b])

 # check the data out
-print(val:=out.toCPU().item())
-assert val == 5
+assert out.as_buffer().cast('I')[0] == 5


 print("******** third, the LazyBuffer ***********")
@@ -100,8 +99,7 @@ print_tree(sched[-1].ast)
 run_schedule(sched)

 # check the data out
-print(val:=out.realized.toCPU().item())
-assert val == 5
+assert out.realized.as_buffer().cast('I')[0] == 5


 print("******** fourth, the Tensor ***********")
@@ -90,8 +90,8 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s
 GlobalCounters.reset()
 for si in schedule: lower_schedule_item(si)([si.out.realized] + [x.realized for x in si.inputs], {})

-new_tinygrad_out = schedule[-1].out.realized.toCPU()
-np.testing.assert_allclose(new_torch_out.flatten(), new_tinygrad_out, atol=1e-4, rtol=1e-2)
+new_tinygrad_out = Tensor(schedule[-1].out).numpy()
+np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
 print("semi-thneed self-test passed!")

 if __name__ == "__main__":
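The test above now reaches numpy by wrapping the scheduled output in a Tensor rather than calling the removed toCPU. A minimal sketch of the same pattern on a throwaway tensor (values arbitrary; in the real test, schedule[-1].out is a realized LazyBuffer):

import numpy as np
from tinygrad import Tensor

t = Tensor([1.0, 2.0, 3.0]).realize()
# wrapping an already-realized lazybuffer in a Tensor gives a numpy copy via the public API
out = Tensor(t.lazydata).numpy()
np.testing.assert_allclose(out, [1.0, 2.0, 3.0])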
@@ -377,13 +377,13 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
 prg = to_prg(k)
 real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
 prg.exec(real_bufs)
-np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
+np.testing.assert_allclose(wanna_output, np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), atol=1e-4, rtol=1e-4)

 # Get baseline, which is not optimized at all.
 k = Linearizer(realized_ast)
 prg = Device[Device.DEFAULT].to_program(k)
 prg.exec(real_bufs)
-wanna_output = real_bufs[0].toCPU().copy()
+wanna_output = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np).copy()

 # Check correctness of handcoded optimiztions.
 k = Linearizer(realized_ast)
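The np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np) expression is the drop-in replacement for real_bufs[0].toCPU() throughout this test. In isolation, and assuming a float32 Buffer on the default device, the pattern looks like this:

import numpy as np
from tinygrad.device import Buffer, Device
from tinygrad.dtype import dtypes

src = np.arange(4, dtype=np.float32)
buf = Buffer(Device.DEFAULT, 4, dtypes.float32).copyin(src.data)  # copyin returns self, so construction chains
# as_buffer() copies device memory into a fresh memoryview; frombuffer views those
# bytes with the matching numpy dtype, which is what toCPU() used to return
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), src, atol=1e-4, rtol=1e-4)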
@@ -391,7 +391,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
 prg = Device[Device.DEFAULT].to_program(k)
 real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
 prg.exec(real_bufs)
-np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
+np.testing.assert_allclose(wanna_output, np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), atol=1e-4, rtol=1e-4)
 for x in opts: # Check custom transformations if any.
 check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_program)

@@ -28,10 +28,12 @@ def _test_single_value(vals, op, dts):
 alu = uop(uops, UOps.ALU, output_dtype, loads, op)
 uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
 buf = Buffer(Device.DEFAULT, 1, output_dtype)
-buf2 = [Buffer.fromCPU(Device.DEFAULT, np.array([a], dtype=dtype.np)) for a,dtype in zip(vals, dts)]
+buf2 = [Buffer(Device.DEFAULT, 1, dtype).copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
 prg = _uops_to_prg(uops)
 prg.exec([buf]+buf2)
-return buf.toCPU()[0]
+ret = np.empty(1, output_dtype.np)
+buf.copyout(ret.data)
+return ret[0]

 def _test_single_value_const(vals, op, dts):
 uops = []
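Both replacements in this helper follow the same recipe: inputs are built with Buffer(...).copyin(...) instead of Buffer.fromCPU, and the result is read back with copyout into a preallocated numpy array instead of toCPU()[0]. Sketched on its own (float32 chosen arbitrarily):

import numpy as np
from tinygrad.device import Buffer, Device
from tinygrad.dtype import dtypes

# input buffer: allocate, then copy the host bytes in (what fromCPU used to do in one call)
inp = Buffer(Device.DEFAULT, 1, dtypes.float32).copyin(np.array([3.5], dtype=np.float32).data)

# output readback: copy device bytes into caller-owned numpy memory (what toCPU()[0] used to do)
ret = np.empty(1, dtypes.float32.np)
inp.copyout(ret.data)
assert ret[0] == 3.5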
@@ -43,7 +45,9 @@ def _test_single_value_const(vals, op, dts):
 buf = Buffer(Device.DEFAULT, 1, output_dtype)
 prg = _uops_to_prg(uops)
 prg.exec([buf])
-return buf.toCPU()[0]
+ret = np.empty(1, output_dtype.np)
+buf.copyout(ret.data)
+return ret[0]

 class TestUOps(unittest.TestCase):
 def _equal(self, v1, v2):
@@ -6,7 +6,7 @@ def time_tensor_numpy(out:Tensor):
 times = []
 for _ in range(5):
 st = time.perf_counter()
-out.lazydata.base.realized.toCPU()
+out.lazydata.base.realized.as_buffer(allow_zero_copy=True)
 et = time.perf_counter() - st
 times.append(et)
 return min(times)
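allow_zero_copy=True lets as_buffer hand back the allocator's own memoryview when the backend exposes one, so this benchmark now measures (near) zero-copy access; the default always copies to avoid use-after-free. A rough sketch of timing it on an arbitrary tensor:

import time
from tinygrad import Tensor

out = Tensor.rand(1024, 1024).realize()
st = time.perf_counter()
# may alias device memory on backends with an as_buffer allocator; don't keep it past the Buffer's lifetime
mv = out.lazydata.base.realized.as_buffer(allow_zero_copy=True)
print(f"as_buffer(allow_zero_copy=True) took {time.perf_counter() - st:.6f}s")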
@@ -1,9 +1,8 @@
 from __future__ import annotations
-import numpy as np
 from collections import defaultdict
 from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
 import importlib, inspect, functools, pathlib, time, re, ctypes
-from tinygrad.dtype import DType, dtypes, ImageDType
+from tinygrad.dtype import DType, ImageDType
 from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters
@@ -80,20 +79,20 @@ class Buffer:
 if isinstance(self.dtype, ImageDType): self.allocator.free(self._buf, self.dtype)
 else: self.allocator.free(self._buf, self.size * self.dtype.itemsize)
 def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}>"
+def as_buffer(self, allow_zero_copy=False) -> memoryview:
+# zero copy with as_buffer (disabled by default due to use after free)
+if allow_zero_copy and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
+return self.copyout(memoryview(bytearray(self.size*self.dtype.itemsize)))
 def copyin(self, mv:memoryview):
 mv = flat_mv(mv)
 assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
 self.allocator.copyin(self._buf, mv)
 return self
-@staticmethod
-def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data)
-def toCPU(self) -> np.ndarray:
-# zero copy with as_buffer
-if hasattr(self.allocator, 'as_buffer'):
-return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore
-ret = np.empty(self.size, self.dtype.np)
-if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
-return ret
+def copyout(self, mv:memoryview) -> memoryview:
+mv = flat_mv(mv)
+assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
+self.allocator.copyout(mv, self._buf)
+return mv

 def _internal_buffer_copy(dest:Buffer, src:Buffer):
 if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator): # noqa: E721
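Putting the reworked Buffer surface together: copyin and copyout both take flat memoryviews whose byte length must match size*itemsize, and as_buffer is just copyout into a fresh bytearray unless allow_zero_copy is set and the allocator can expose its memory directly. A small round-trip sketch on the default device:

import numpy as np
from tinygrad.device import Buffer, Device
from tinygrad.dtype import dtypes

buf = Buffer(Device.DEFAULT, 3, dtypes.int32)
buf.copyin(np.array([1, 2, 3], dtype=np.int32).data)      # host -> device, length must be 3*4 bytes

mv = buf.as_buffer()                                       # device -> fresh host copy (allow_zero_copy defaults to False)
assert np.frombuffer(mv, np.int32).tolist() == [1, 2, 3]

dst = memoryview(bytearray(3 * dtypes.int32.itemsize))     # or copy out into caller-owned memory explicitly
buf.copyout(dst)
assert np.frombuffer(dst, np.int32).tolist() == [1, 2, 3]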
@@ -114,11 +113,9 @@ def _internal_buffer_copy(dest:Buffer, src:Buffer):
 elif hasattr(dest.allocator, 'as_buffer'):
 # fast(ish) path, uses readinto in diskbuffers
 src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
-elif hasattr(src.allocator, 'as_buffer'):
-dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
 else:
-# slow path, allocates a CPU buffer
-dest.copyin(src.toCPU().data)
+# may allocate a CPU buffer depending on allow_zero_copy
+dest.copyin(src.as_buffer(allow_zero_copy=True))

 class _BufferCopy(JITRunner):
 # TODO: make wait work
@@ -139,18 +139,20 @@ class Tensor:
 def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)

-# TODO: these are good places to start removing numpy
-def item(self) -> Scalar:
-assert self.numel() == 1, "must have one element for item"
-return cast(Buffer, self.contiguous().realize().lazydata.base.realized).toCPU().item()
-def data(self) -> memoryview: return self.numpy().data
+def _data(self) -> memoryview:
+if 0 in self.shape: return memoryview(bytearray(0))
+t = self if isinstance(self.device, str) else self.to("CPU") # deal with multitensor
+return cast(Buffer, t.cast(t.dtype.scalar()).contiguous().realize().lazydata.base.realized).as_buffer()

+# TODO: this should import numpy and use .data() to construct the array
 def numpy(self) -> np.ndarray:
 assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
 assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
 if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
-t = self if isinstance(self.device, str) else self.to("CPU")
-return t.cast(self.dtype.scalar()).contiguous().realize().lazydata.base.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
+return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
+# TODO: numpy is only used here to get the memoryview type
+def data(self) -> memoryview: return self.numpy().data
+def item(self) -> Scalar:
+assert self.numel() == 1, "must have one element for item"
+return self.numpy().item()

 def to(self, device:Optional[str]) -> Tensor:
 if device is None or device == self.device: return self
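On the Tensor side, numpy(), data() and item() now all funnel through the new _data() helper, which realizes a contiguous scalar-dtype copy and returns its bytes via Buffer.as_buffer; numpy only shows up at the end, in np.frombuffer. A short usage sketch (values arbitrary):

import numpy as np
from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]])
arr = t.numpy()                          # np.frombuffer(self._data(), ...).reshape(self.shape) under the hood
assert arr.tolist() == [[1, 2], [3, 4]]
assert isinstance(t.data(), memoryview)  # per the TODO, still built through numpy just to get a memoryview
assert Tensor(5).item() == 5             # item() is numpy().item() on a one-element tensor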