move buffer logic to Buffer [pr] (#9487)

* move buffer logic to Buffer [pr]

* pass shape into as_typed_buffer

* pass shape into as_typed_buffer

* work

* cleaner

* fix tests
Author: George Hotz
Date: 2025-03-18 11:21:21 +08:00
Committer: GitHub
Parent: 3be228182f
Commit: d20494e6d7

2 changed files with 18 additions and 17 deletions

tinygrad/device.py

@@ -5,7 +5,7 @@ from typing import Optional, Any, Iterator, Generator
 import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
 from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
   cpu_time_execution, colored, Context, round_up
-from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
+from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
 from tinygrad.renderer import Renderer
 # **************** Device ****************
@@ -155,6 +155,14 @@ class Buffer:
       return self.allocator._as_buffer(self._buf)
     assert not force_zero_copy, "force zero copy was passed, but copy is required"
     return self.copyout(memoryview(bytearray(self.nbytes)))
+  def as_typed_buffer(self, shape=None, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
+    assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
+    assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
+    return self.as_buffer(allow_zero_copy, force_zero_copy).cast(self.dtype.base.fmt, shape if shape is not None else (self.size,))
+  def numpy(self) -> 'np.ndarray':  # type: ignore [name-defined] # noqa: F821
+    import numpy as np
+    assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
+    return np.frombuffer(self.as_buffer(), dtype=_to_np_dtype(self.dtype.base))
   def copyin(self, mv:memoryview):
     mv = flat_mv(mv)
     assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
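The two added methods move typed reads onto Buffer itself: as_typed_buffer casts the raw bytes to the dtype's struct format (shaped if a shape is passed, flat otherwise), and numpy wraps the same bytes in a flat ndarray. A minimal sketch of calling them directly; reaching the realized Buffer via t.lazydata.base.buffer mirrors the _buffer() helper added to tensor.py below, and is illustration only, not part of the commit:

```python
# Sketch (assumes this commit's tinygrad): exercise Buffer.as_typed_buffer / Buffer.numpy.
from tinygrad import Tensor
from tinygrad.dtype import dtypes

t = Tensor([[1, 2], [3, 4]], dtype=dtypes.int32, device="CPU").realize()
buf = t.lazydata.base.buffer            # the realized Buffer behind the Tensor

mv = buf.as_typed_buffer(shape=(2, 2))  # memoryview cast to fmt "i", shaped (2, 2)
print(mv.tolist())                      # [[1, 2], [3, 4]]

arr = buf.numpy()                       # always flat; reshaping is the caller's job
print(arr.reshape(2, 2))
```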

tinygrad/tensor.py

@@ -2,7 +2,7 @@
 from __future__ import annotations
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
-from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex, ParamSpec, TypeVar
+from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
@@ -11,7 +11,7 @@ from tinygrad.engine.multi import get_multi_map
 from tinygrad.gradient import compute_gradient
 from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
 from tinygrad.spec import tensor_uop_spec, type_verify
-from tinygrad.device import Device, BufferSpec
+from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
@@ -288,14 +288,8 @@ class Tensor(SimpleMathTrait):
     """
     return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False)
-  def _data(self) -> memoryview:
-    if 0 in self.shape: return memoryview(bytearray(0))
-    # NOTE: this realizes on the object from as_buffer being a Python object
-    cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
-    buf = cpu.lazydata.base.realized
-    assert buf is not None, f"{cpu.lazydata.base} was not realized"
-    if self.device != "CPU": buf.options = BufferSpec(nolru=True)
-    return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)
+  def _buffer(self) -> Buffer: return self.cast(self.dtype.base).contiguous().to("CPU").realize().lazydata.base.buffer
+  def _data(self) -> memoryview: return self._buffer().as_buffer()
   def data(self) -> memoryview:
     """
@@ -306,10 +300,9 @@
     print(np.frombuffer(t.data(), dtype=np.int32))
     ```
     """
     assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
+    if 0 in self.shape: return memoryview(bytearray(0)).cast(self.dtype.base.fmt)
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
-    if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
-    return self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape)
+    return self._buffer().as_typed_buffer(self.shape)
   def item(self) -> ConstType:
     """
@@ -350,11 +343,11 @@
     print(repr(t.numpy()))
     ```
     """
+    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
     import numpy as np
     if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
-    assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
-    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
-    return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype.base)).reshape(self.shape)
+    if 0 in self.shape: return np.empty(self.shape, dtype=_to_np_dtype(self.dtype.base))
+    return self._buffer().numpy().reshape(self.shape)
   def clone(self) -> Tensor:
     """