Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
move buffer logic to Buffer [pr] (#9487)
* move buffer logic to Buffer [pr]
* pass shape into as_typed_buffer
* pass shape into as_typed_buffer
* work
* cleaner
* fix tests
tinygrad/device.py
@@ -5,7 +5,7 @@ from typing import Optional, Any, Iterator, Generator
 import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
 from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
   cpu_time_execution, colored, Context, round_up
-from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
+from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
 from tinygrad.renderer import Renderer
 
 # **************** Device ****************
@@ -155,6 +155,14 @@ class Buffer:
       return self.allocator._as_buffer(self._buf)
     assert not force_zero_copy, "force zero copy was passed, but copy is required"
     return self.copyout(memoryview(bytearray(self.nbytes)))
+  def as_typed_buffer(self, shape=None, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
+    assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
+    assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
+    return self.as_buffer(allow_zero_copy, force_zero_copy).cast(self.dtype.base.fmt, shape if shape is not None else (self.size,))
+  def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
+    import numpy as np
+    assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
+    return np.frombuffer(self.as_buffer(), dtype=_to_np_dtype(self.dtype.base))
   def copyin(self, mv:memoryview):
     mv = flat_mv(mv)
     assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
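A minimal usage sketch of the two Buffer helpers added in this hunk (not part of the diff; it assumes a small int32 tensor realized on the CPU device and reaches the Buffer through lazydata.base.buffer, the same path the tensor.py changes below use):

    from tinygrad import Tensor

    # realize a small tensor on the CPU device so a Buffer backs it
    t = Tensor([1, 2, 3, 4], device="CPU").realize()
    buf = t.lazydata.base.buffer

    # cast the raw bytes to the dtype's struct format ('i' for int32), optionally with a shape
    typed = buf.as_typed_buffer(shape=(2, 2))
    print(typed.tolist())   # [[1, 2], [3, 4]]

    # flat numpy array built from the same bytes
    print(buf.numpy())      # [1 2 3 4]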
tinygrad/tensor.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
-from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex, ParamSpec, TypeVar
+from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
@@ -11,7 +11,7 @@ from tinygrad.engine.multi import get_multi_map
 from tinygrad.gradient import compute_gradient
 from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
 from tinygrad.spec import tensor_uop_spec, type_verify
-from tinygrad.device import Device, BufferSpec
+from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
@@ -288,14 +288,8 @@ class Tensor(SimpleMathTrait):
     """
     return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False)
 
-  def _data(self) -> memoryview:
-    if 0 in self.shape: return memoryview(bytearray(0))
-    # NOTE: this realizes on the object from as_buffer being a Python object
-    cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
-    buf = cpu.lazydata.base.realized
-    assert buf is not None, f"{cpu.lazydata.base} was not realized"
-    if self.device != "CPU": buf.options = BufferSpec(nolru=True)
-    return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)
+  def _buffer(self) -> Buffer: return self.cast(self.dtype.base).contiguous().to("CPU").realize().lazydata.base.buffer
+  def _data(self) -> memoryview: return self._buffer().as_buffer()
 
   def data(self) -> memoryview:
     """
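The old _data() above collapses into two one-liners: _buffer() produces a realized, contiguous CPU Buffer, and _data() copies out its raw bytes. A sketch of the equivalent call chain (the example tensor is an assumption; the method chain is taken from the diff):

    from tinygrad import Tensor

    t = Tensor.arange(6).reshape(2, 3)   # int32 by default

    # what _buffer() does: drop any view dtype, force contiguity, move to CPU,
    # realize, then take the Buffer behind the realized result
    cpu_buf = t.cast(t.dtype.base).contiguous().to("CPU").realize().lazydata.base.buffer

    # what _data() does: the raw bytes of that Buffer
    raw = cpu_buf.as_buffer()
    print(len(raw))   # 24 bytes = 6 elements * 4 bytes (int32)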
@@ -306,10 +300,9 @@ class Tensor(SimpleMathTrait):
     print(np.frombuffer(t.data(), dtype=np.int32))
     ```
     """
-    assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
+    if 0 in self.shape: return memoryview(bytearray(0)).cast(self.dtype.base.fmt)
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
-    if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
-    return self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape)
+    return self._buffer().as_typed_buffer(self.shape)
 
   def item(self) -> ConstType:
     """
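After this change, Tensor.data() just hands its shape to Buffer.as_typed_buffer(). A short sketch of what the returned memoryview looks like (the 2x2 int32 tensor is an assumption, not from the diff):

    from tinygrad import Tensor

    t = Tensor([[1, 2], [3, 4]])
    mv = t.data()              # memoryview cast to the dtype's format with the tensor's shape
    print(mv.shape, mv[1, 0])  # (2, 2) 3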
@@ -350,11 +343,11 @@ class Tensor(SimpleMathTrait):
     print(repr(t.numpy()))
     ```
     """
+    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
     import numpy as np
     if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
-    assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
-    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
-    return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype.base)).reshape(self.shape)
+    if 0 in self.shape: return np.empty(self.shape, dtype=_to_np_dtype(self.dtype.base))
+    return self._buffer().numpy().reshape(self.shape)
 
   def clone(self) -> Tensor:
     """