remove numpy from device (#3123)

* remove numpy from device

* fix tests

* np item

* cleanups

* simplify with as_buffer

* no toCPU

* tinygradic

* cast to scalar
Author: George Hotz
Date: 2024-01-14 19:36:05 -08:00
Committed by: GitHub
Parent: ea5824657d
Commit: 1f9aee8b6f
9 changed files with 46 additions and 38 deletions

View File

@@ -21,7 +21,13 @@ repos:
pass_filenames: false
- id: docs
name: docs
entry: python3 docs/abstractions.py && python3 docs/abstractions2.py
entry: python3 docs/abstractions.py
language: system
always_run: true
pass_filenames: false
- id: docs2
name: docs2
entry: python3 docs/abstractions2.py
language: system
always_run: true
pass_filenames: false

View File

@@ -146,7 +146,8 @@ assert result.lazydata.base.realized is not None, "the LazyBuffer is realized!"
assert isinstance(result.lazydata.base.realized, Buffer)
assert result.lazydata.base.realized.device == "CLANG"
# getting ahead of ourselves, but we can move the Buffer to CPU
assert result.lazydata.base.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5"
out = result.lazydata.base.realized.as_buffer().cast('I')
assert out[0] == 5, "when put in numpy with toCPU, it's 5"
# %%
# == Union[Interpreted, Compiled] (in tinygrad/device.py, code 6/10) ==
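
as_buffer() returns the realized Buffer's raw bytes as a memoryview, and memoryview.cast('I') reinterprets those bytes as unsigned 32-bit integers, which is how the doc now reads the scalar back without numpy. A minimal pure-Python sketch of the same mechanics (no tinygrad needed, values are illustrative):

import struct

raw = memoryview(bytearray(struct.pack('I', 5)))  # 4 little-endian bytes holding 5, like the copied-out Buffer
out = raw.cast('I')                               # typed view: each element is an unsigned 32-bit int
assert out[0] == 5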

View File

@@ -67,8 +67,7 @@ print(fxn.prg)
fxn.exec([out, a, b])
# check the data out
print(val:=out.toCPU().item())
assert val == 5
assert out.as_buffer().cast('I')[0] == 5
print("******** third, the LazyBuffer ***********")
@@ -100,8 +99,7 @@ print_tree(sched[-1].ast)
run_schedule(sched)
# check the data out
print(val:=out.realized.toCPU().item())
assert val == 5
assert out.realized.as_buffer().cast('I')[0] == 5
print("******** fourth, the Tensor ***********")

View File

@@ -90,8 +90,8 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s
GlobalCounters.reset()
for si in schedule: lower_schedule_item(si)([si.out.realized] + [x.realized for x in si.inputs], {})
new_tinygrad_out = schedule[-1].out.realized.toCPU()
np.testing.assert_allclose(new_torch_out.flatten(), new_tinygrad_out, atol=1e-4, rtol=1e-2)
new_tinygrad_out = Tensor(schedule[-1].out).numpy()
np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
print("semi-thneed self-test passed!")
if __name__ == "__main__":
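
Rather than pulling bytes out with the removed Buffer.toCPU(), the test wraps the realized LazyBuffer in a Tensor and lets Tensor.numpy() do the conversion (which now goes through as_buffer() and np.frombuffer, per the tensor.py change below). A hedged sketch of the pattern; `lb` is a placeholder for any realized LazyBuffer such as schedule[-1].out:

import numpy as np
from tinygrad.tensor import Tensor

def lazybuffer_to_numpy(lb) -> np.ndarray:
    # Tensor accepts a LazyBuffer directly; numpy() realizes it and copies the data out
    return Tensor(lb).numpy()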

View File

@@ -377,13 +377,13 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
prg = to_prg(k)
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
prg.exec(real_bufs)
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(wanna_output, np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), atol=1e-4, rtol=1e-4)
# Get baseline, which is not optimized at all.
k = Linearizer(realized_ast)
prg = Device[Device.DEFAULT].to_program(k)
prg.exec(real_bufs)
wanna_output = real_bufs[0].toCPU().copy()
wanna_output = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np).copy()
# Check correctness of handcoded optimiztions.
k = Linearizer(realized_ast)
@@ -391,7 +391,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
prg = Device[Device.DEFAULT].to_program(k)
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
prg.exec(real_bufs)
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(wanna_output, np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), atol=1e-4, rtol=1e-4)
for x in opts: # Check custom transformations if any.
check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_program)
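
np.frombuffer only wraps the bytes returned by as_buffer(); it does not copy them, which is why the baseline result above is snapshotted with .copy() before the buffer gets overwritten. A small numpy-only illustration:

import numpy as np

raw = bytearray(np.array([1.0, 2.0, 3.0], dtype=np.float32).tobytes())
view = np.frombuffer(raw, dtype=np.float32)   # a view over raw, no copy
baseline = view.copy()                        # detached snapshot, like wanna_output above
raw[0:4] = b"\x00\x00\x00\x00"                # clobber the first element's bytes
assert view[0] == 0.0 and baseline[0] == 1.0  # the view changed, the copy did not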

View File

@@ -28,10 +28,12 @@ def _test_single_value(vals, op, dts):
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype)
buf2 = [Buffer.fromCPU(Device.DEFAULT, np.array([a], dtype=dtype.np)) for a,dtype in zip(vals, dts)]
buf2 = [Buffer(Device.DEFAULT, 1, dtype).copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
prg = _uops_to_prg(uops)
prg.exec([buf]+buf2)
return buf.toCPU()[0]
ret = np.empty(1, output_dtype.np)
buf.copyout(ret.data)
return ret[0]
def _test_single_value_const(vals, op, dts):
uops = []
@@ -43,7 +45,9 @@ def _test_single_value_const(vals, op, dts):
buf = Buffer(Device.DEFAULT, 1, output_dtype)
prg = _uops_to_prg(uops)
prg.exec([buf])
return buf.toCPU()[0]
ret = np.empty(1, output_dtype.np)
buf.copyout(ret.data)
return ret[0]
class TestUOps(unittest.TestCase):
def _equal(self, v1, v2):
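
With toCPU() gone, the test reads results back by allocating a numpy array of the right dtype and letting copyout() fill its bytes. A hedged helper capturing that pattern (it assumes the Buffer's dtype has a numpy counterpart in dtype.np):

import numpy as np

def read_back(buf) -> np.ndarray:
    ret = np.empty(buf.size, buf.dtype.np)  # host array matching the Buffer's element count and dtype
    buf.copyout(ret.data)                   # device -> host copy into the numpy-owned memory
    return ret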

View File

@@ -6,7 +6,7 @@ def time_tensor_numpy(out:Tensor):
times = []
for _ in range(5):
st = time.perf_counter()
out.lazydata.base.realized.toCPU()
out.lazydata.base.realized.as_buffer(allow_zero_copy=True)
et = time.perf_counter() - st
times.append(et)
return min(times)
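
allow_zero_copy=True lets the benchmark time just the device readback: when the allocator exposes as_buffer, the returned memoryview aliases existing storage instead of allocating a fresh bytearray and copying into it. A pure-Python illustration of the cost that option skips:

import time

data = bytearray(16 << 20)             # 16 MiB of backing storage

st = time.perf_counter()
view = memoryview(data)                # zero-copy path: just a view
t_view = time.perf_counter() - st

st = time.perf_counter()
copied = memoryview(bytearray(data))   # default path: allocate a bytearray and copy into it
t_copy = time.perf_counter() - st

print(f"view {t_view*1e6:.1f}us vs copy {t_copy*1e6:.1f}us")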

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
import numpy as np
from collections import defaultdict
from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
import importlib, inspect, functools, pathlib, time, re, ctypes
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.dtype import DType, ImageDType
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters
@@ -80,20 +79,20 @@ class Buffer:
if isinstance(self.dtype, ImageDType): self.allocator.free(self._buf, self.dtype)
else: self.allocator.free(self._buf, self.size * self.dtype.itemsize)
def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}>"
def as_buffer(self, allow_zero_copy=False) -> memoryview:
# zero copy with as_buffer (disabled by default due to use after free)
if allow_zero_copy and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
return self.copyout(memoryview(bytearray(self.size*self.dtype.itemsize)))
def copyin(self, mv:memoryview):
mv = flat_mv(mv)
assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
self.allocator.copyin(self._buf, mv)
return self
@staticmethod
def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data)
def toCPU(self) -> np.ndarray:
# zero copy with as_buffer
if hasattr(self.allocator, 'as_buffer'):
return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore
ret = np.empty(self.size, self.dtype.np)
if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
return ret
def copyout(self, mv:memoryview) -> memoryview:
mv = flat_mv(mv)
assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
self.allocator.copyout(mv, self._buf)
return mv
def _internal_buffer_copy(dest:Buffer, src:Buffer):
if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator): # noqa: E721
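
Together, copyin/copyout/as_buffer give Buffer a numpy-free, byte-level interface: copyin uploads a flat memoryview and returns the Buffer for chaining, copyout downloads into caller-owned memory, and as_buffer wraps copyout behind a fresh bytearray unless zero copy is explicitly requested. A usage sketch, assuming the default device is available:

import numpy as np
from tinygrad.device import Buffer, Device
from tinygrad.dtype import dtypes

buf = Buffer(Device.DEFAULT, 4, dtypes.int32).copyin(np.arange(4, dtype=np.int32).data)  # host -> device

mv = buf.as_buffer()                        # safe default: bytes live in a new bytearray
assert np.frombuffer(mv, np.int32).tolist() == [0, 1, 2, 3]

view = buf.as_buffer(allow_zero_copy=True)  # opt-in: may alias device-backed memory, only valid while buf is alive
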
@@ -114,11 +113,9 @@ def _internal_buffer_copy(dest:Buffer, src:Buffer):
elif hasattr(dest.allocator, 'as_buffer'):
# fast(ish) path, uses readinto in diskbuffers
src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
elif hasattr(src.allocator, 'as_buffer'):
dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
else:
# slow path, allocates a CPU buffer
dest.copyin(src.toCPU().data)
# may allocate a CPU buffer depending on allow_zero_copy
dest.copyin(src.as_buffer(allow_zero_copy=True))
class _BufferCopy(JITRunner):
# TODO: make wait work
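
In the fallback branch, zero copy is safe even though it is off by default: both Buffers stay alive for the whole call, and dest.copyin consumes the view immediately, so nothing can dangle. A hedged restatement of that path as a standalone helper:

def fallback_copy(dest, src):
    # src.as_buffer(allow_zero_copy=True) avoids an intermediate bytearray when the
    # allocator can hand out a view; dest.copyin uses it before src can be freed
    dest.copyin(src.as_buffer(allow_zero_copy=True))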

View File

@@ -139,18 +139,20 @@ class Tensor:
def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
# TODO: these are good places to start removing numpy
def item(self) -> Scalar:
assert self.numel() == 1, "must have one element for item"
return cast(Buffer, self.contiguous().realize().lazydata.base.realized).toCPU().item()
def data(self) -> memoryview: return self.numpy().data
def _data(self) -> memoryview:
if 0 in self.shape: return memoryview(bytearray(0))
t = self if isinstance(self.device, str) else self.to("CPU") # deal with multitensor
return cast(Buffer, t.cast(t.dtype.scalar()).contiguous().realize().lazydata.base.realized).as_buffer()
# TODO: this should import numpy and use .data() to construct the array
def numpy(self) -> np.ndarray:
assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
t = self if isinstance(self.device, str) else self.to("CPU")
return t.cast(self.dtype.scalar()).contiguous().realize().lazydata.base.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
# TODO: numpy is only used here to get the memoryview type
def data(self) -> memoryview: return self.numpy().data
def item(self) -> Scalar:
assert self.numel() == 1, "must have one element for item"
return self.numpy().item()
def to(self, device:Optional[str]) -> Tensor:
if device is None or device == self.device: return self
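
Tensor now funnels every host-side read through _data(): numpy() reinterprets those bytes with np.frombuffer and reshapes, data() exposes numpy's memoryview, and item() reads out the single element. A hedged usage sketch, assuming the default device realizes the tensor:

from tinygrad.tensor import Tensor

t = Tensor([1.0, 2.0, 3.0])
print(t.numpy())            # np.frombuffer over _data(), reshaped to t.shape -> [1. 2. 3.]
print(bytes(t.data())[:4])  # raw little-endian float32 bytes of the first element
print(t.sum().item())       # single-element tensor -> Python scalar, 6.0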