remove numpy from device (#3123)

* remove numpy from device
* fix tests
* np item
* cleanups
* simplify with as_buffer
* no toCPU
* tinygradic
* cast to scalar
George Hotz, 2024-01-14 19:36:05 -08:00 (committed by GitHub)
parent ea5824657d
commit 1f9aee8b6f
9 changed files with 46 additions and 38 deletions


@@ -21,7 +21,13 @@ repos:
       pass_filenames: false
     - id: docs
       name: docs
-      entry: python3 docs/abstractions.py && python3 docs/abstractions2.py
+      entry: python3 docs/abstractions.py
+      language: system
+      always_run: true
+      pass_filenames: false
+    - id: docs2
+      name: docs2
+      entry: python3 docs/abstractions2.py
       language: system
       always_run: true
       pass_filenames: false


@@ -146,7 +146,8 @@ assert result.lazydata.base.realized is not None, "the LazyBuffer is realized!"
assert isinstance(result.lazydata.base.realized, Buffer) assert isinstance(result.lazydata.base.realized, Buffer)
assert result.lazydata.base.realized.device == "CLANG" assert result.lazydata.base.realized.device == "CLANG"
# getting ahead of ourselves, but we can move the Buffer to CPU # getting ahead of ourselves, but we can move the Buffer to CPU
assert result.lazydata.base.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5" out = result.lazydata.base.realized.as_buffer().cast('I')
assert out[0] == 5, "when put in numpy with toCPU, it's 5"
# %% # %%
# == Union[Interpreted, Compiled] (in tinygrad/device.py, code 6/10) == # == Union[Interpreted, Compiled] (in tinygrad/device.py, code 6/10) ==


@@ -67,8 +67,7 @@ print(fxn.prg)
 fxn.exec([out, a, b])

 # check the data out
-print(val:=out.toCPU().item())
-assert val == 5
+assert out.as_buffer().cast('I')[0] == 5

 print("******** third, the LazyBuffer ***********")

@@ -100,8 +99,7 @@ print_tree(sched[-1].ast)
 run_schedule(sched)

 # check the data out
-print(val:=out.realized.toCPU().item())
-assert val == 5
+assert out.realized.as_buffer().cast('I')[0] == 5

 print("******** fourth, the Tensor ***********")


@@ -90,8 +90,8 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s
GlobalCounters.reset() GlobalCounters.reset()
for si in schedule: lower_schedule_item(si)([si.out.realized] + [x.realized for x in si.inputs], {}) for si in schedule: lower_schedule_item(si)([si.out.realized] + [x.realized for x in si.inputs], {})
new_tinygrad_out = schedule[-1].out.realized.toCPU() new_tinygrad_out = Tensor(schedule[-1].out).numpy()
np.testing.assert_allclose(new_torch_out.flatten(), new_tinygrad_out, atol=1e-4, rtol=1e-2) np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
print("semi-thneed self-test passed!") print("semi-thneed self-test passed!")
if __name__ == "__main__": if __name__ == "__main__":
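Two things change in this hunk: the realized output LazyBuffer is wrapped in a Tensor and read with .numpy() instead of .toCPU(), and the torch reference drops .flatten() because numpy() now returns the shaped array rather than a flat one. As a hedged helper (the function name is ours, not from the diff):

import numpy as np
from tinygrad.tensor import Tensor

def lazybuffer_to_numpy(lb) -> np.ndarray:
  # wrap an already-realized LazyBuffer in a Tensor and let Tensor.numpy() do the readback
  return Tensor(lb).numpy()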


@@ -377,13 +377,13 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
     prg = to_prg(k)
     real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
     prg.exec(real_bufs)
-    np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
+    np.testing.assert_allclose(wanna_output, np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), atol=1e-4, rtol=1e-4)

   # Get baseline, which is not optimized at all.
   k = Linearizer(realized_ast)
   prg = Device[Device.DEFAULT].to_program(k)
   prg.exec(real_bufs)
-  wanna_output = real_bufs[0].toCPU().copy()
+  wanna_output = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np).copy()

   # Check correctness of handcoded optimiztions.
   k = Linearizer(realized_ast)

@@ -391,7 +391,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
   prg = Device[Device.DEFAULT].to_program(k)
   real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
   prg.exec(real_bufs)
-  np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
+  np.testing.assert_allclose(wanna_output, np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), atol=1e-4, rtol=1e-4)
   for x in opts: # Check custom transformations if any.
     check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_program)
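The repeated np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np) expression is the numpy-free readback with a numpy view layered on top; a small helper like this (our name, not in the diff) captures the pattern:

import numpy as np
from tinygrad.device import Buffer

def buf_to_np(buf: Buffer) -> np.ndarray:
  # copy the device buffer to host memory, then view the bytes with the matching numpy dtype
  return np.frombuffer(buf.as_buffer(), dtype=buf.dtype.np)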


@@ -28,10 +28,12 @@ def _test_single_value(vals, op, dts):
   alu = uop(uops, UOps.ALU, output_dtype, loads, op)
   uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
   buf = Buffer(Device.DEFAULT, 1, output_dtype)
-  buf2 = [Buffer.fromCPU(Device.DEFAULT, np.array([a], dtype=dtype.np)) for a,dtype in zip(vals, dts)]
+  buf2 = [Buffer(Device.DEFAULT, 1, dtype).copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
   prg = _uops_to_prg(uops)
   prg.exec([buf]+buf2)
-  return buf.toCPU()[0]
+  ret = np.empty(1, output_dtype.np)
+  buf.copyout(ret.data)
+  return ret[0]

 def _test_single_value_const(vals, op, dts):
   uops = []

@@ -43,7 +45,9 @@ def _test_single_value_const(vals, op, dts):
   buf = Buffer(Device.DEFAULT, 1, output_dtype)
   prg = _uops_to_prg(uops)
   prg.exec([buf])
-  return buf.toCPU()[0]
+  ret = np.empty(1, output_dtype.np)
+  buf.copyout(ret.data)
+  return ret[0]

 class TestUOps(unittest.TestCase):
   def _equal(self, v1, v2):
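The test now builds input buffers with copyin and reads results with copyout into a caller-allocated array, replacing Buffer.fromCPU and toCPU. A round-trip sketch of the same pair (values are arbitrary):

import numpy as np
from tinygrad.device import Buffer, Device
from tinygrad.dtype import dtypes

src = np.array([3.5], dtype=np.float32)
buf = Buffer(Device.DEFAULT, 1, dtypes.float32).copyin(src.data)  # was Buffer.fromCPU(...)
out = np.empty(1, dtypes.float32.np)
buf.copyout(out.data)                                             # was buf.toCPU()
assert out[0] == 3.5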


@@ -6,7 +6,7 @@ def time_tensor_numpy(out:Tensor):
   times = []
   for _ in range(5):
     st = time.perf_counter()
-    out.lazydata.base.realized.toCPU()
+    out.lazydata.base.realized.as_buffer(allow_zero_copy=True)
     et = time.perf_counter() - st
     times.append(et)
   return min(times)
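allow_zero_copy=True lets allocators that implement as_buffer hand back a view of the underlying memory instead of copying, which is what this micro-benchmark wants to time. A rough standalone version (sizes made up):

import time
from tinygrad.tensor import Tensor

t = Tensor.rand(1024, 1024).realize()
st = time.perf_counter()
t.lazydata.base.realized.as_buffer(allow_zero_copy=True)  # device -> host readback (or zero-copy view)
print(f"readback: {time.perf_counter() - st:.6f}s")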


@@ -1,9 +1,8 @@
 from __future__ import annotations
-import numpy as np
 from collections import defaultdict
 from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
 import importlib, inspect, functools, pathlib, time, re, ctypes
-from tinygrad.dtype import DType, dtypes, ImageDType
+from tinygrad.dtype import DType, ImageDType
 from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters

@@ -80,20 +79,20 @@ class Buffer:
     if isinstance(self.dtype, ImageDType): self.allocator.free(self._buf, self.dtype)
     else: self.allocator.free(self._buf, self.size * self.dtype.itemsize)
   def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}>"
+  def as_buffer(self, allow_zero_copy=False) -> memoryview:
+    # zero copy with as_buffer (disabled by default due to use after free)
+    if allow_zero_copy and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
+    return self.copyout(memoryview(bytearray(self.size*self.dtype.itemsize)))
   def copyin(self, mv:memoryview):
     mv = flat_mv(mv)
     assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
     self.allocator.copyin(self._buf, mv)
     return self
-  @staticmethod
-  def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data)
-  def toCPU(self) -> np.ndarray:
-    # zero copy with as_buffer
-    if hasattr(self.allocator, 'as_buffer'):
-      return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf}))  # type: ignore
-    ret = np.empty(self.size, self.dtype.np)
-    if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
-    return ret
+  def copyout(self, mv:memoryview) -> memoryview:
+    mv = flat_mv(mv)
+    assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
+    self.allocator.copyout(mv, self._buf)
+    return mv

 def _internal_buffer_copy(dest:Buffer, src:Buffer):
   if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):  # noqa: E721

@@ -114,11 +113,9 @@ def _internal_buffer_copy(dest:Buffer, src:Buffer):
   elif hasattr(dest.allocator, 'as_buffer'):
     # fast(ish) path, uses readinto in diskbuffers
     src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
-  elif hasattr(src.allocator, 'as_buffer'):
-    dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
   else:
-    # slow path, allocates a CPU buffer
-    dest.copyin(src.toCPU().data)
+    # may allocate a CPU buffer depending on allow_zero_copy
+    dest.copyin(src.as_buffer(allow_zero_copy=True))

 class _BufferCopy(JITRunner):
   # TODO: make wait work
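Net effect: Buffer's host interface is now memoryview-in/memoryview-out. copyin and copyout move bytes through buffers the caller provides, and as_buffer is a convenience that allocates the host bytearray for you (or, with allow_zero_copy=True, returns the allocator's own view, which must not outlive the Buffer). A numpy-free sketch of the new surface (values are arbitrary):

from array import array
from tinygrad.device import Buffer, Device
from tinygrad.dtype import dtypes

buf = Buffer(Device.DEFAULT, 4, dtypes.float32)
buf.copyin(memoryview(array('f', [1.0, 2.0, 3.0, 4.0])))  # host -> device, 16 bytes
host = buf.as_buffer()                                    # device -> fresh host copy
print(list(host.cast('f')))                               # [1.0, 2.0, 3.0, 4.0]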


@@ -139,18 +139,20 @@ class Tensor:
   def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)

   # TODO: these are good places to start removing numpy
-  def item(self) -> Scalar:
-    assert self.numel() == 1, "must have one element for item"
-    return cast(Buffer, self.contiguous().realize().lazydata.base.realized).toCPU().item()
-  def data(self) -> memoryview: return self.numpy().data
-  # TODO: this should import numpy and use .data() to construct the array
+  def _data(self) -> memoryview:
+    if 0 in self.shape: return memoryview(bytearray(0))
+    t = self if isinstance(self.device, str) else self.to("CPU")  # deal with multitensor
+    return cast(Buffer, t.cast(t.dtype.scalar()).contiguous().realize().lazydata.base.realized).as_buffer()
   def numpy(self) -> np.ndarray:
     assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
     assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
-    if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
-    t = self if isinstance(self.device, str) else self.to("CPU")
-    return t.cast(self.dtype.scalar()).contiguous().realize().lazydata.base.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
+    return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
+  # TODO: numpy is only used here to get the memoryview type
+  def data(self) -> memoryview: return self.numpy().data
+  def item(self) -> Scalar:
+    assert self.numel() == 1, "must have one element for item"
+    return self.numpy().item()

   def to(self, device:Optional[str]) -> Tensor:
     if device is None or device == self.device: return self
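On the Tensor side, _data() is now the primitive that returns raw bytes; numpy() is a frombuffer-plus-reshape over it, and data() and item() are defined in terms of numpy() for now (the TODOs mark them as the next numpy removals). Roughly, from user code (a sketch against the post-commit API):

from tinygrad.tensor import Tensor

a = (Tensor([1.0, 2.0, 3.0]) + 1).realize()
print(a.numpy())       # array([2., 3., 4.], dtype=float32), built with np.frombuffer over _data()
print(a.sum().item())  # 9.0, a one-element readback
print(len(a.data()))   # 3, a memoryview over the same values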