mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
61 lines
3.0 KiB
Python
61 lines
3.0 KiB
Python
from __future__ import annotations
|
|
from typing import Any, Optional
|
|
from dataclasses import dataclass
|
|
from tinygrad.helpers import GlobalCounters, flat_mv
|
|
from tinygrad.dtype import DType, ImageDType
|
|
|
|
@dataclass(frozen=True)
class BufferOptions:
  """Immutable, hashable bundle of allocation options attached to a Buffer.

  Buffer hands this object (or None) unchanged to its allocator's alloc/free calls, so what each
  flag actually does is decided by the device allocator, not by this class.
  """
  # image dtype for image-backed buffers; Buffer fills this in itself when constructed with an ImageDType
  image: Optional[ImageDType] = None
  # the remaining flags all default to off; their effect is presumably allocator-specific — TODO confirm per backend
  uncached: bool = False
  host: bool = False
  nolru: bool = False
|
|
|
|
class Buffer:
  """A typed region of device memory: `size` elements of `dtype` living on `device`.

  Allocation is lazy: the descriptor can exist without memory behind it, and the presence of the
  `_buf` attribute is what every method below uses to decide whether it has been allocated.
  """
  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
               initial_value:Optional[bytes]=None):
    """Describe (and optionally allocate) a buffer.

    Args:
      device: device name; resolved through tinygrad.device.Device when the buffer is allocated.
      size: number of elements (the byte size is `nbytes`).
      dtype: element type. An ImageDType overrides `options` with BufferOptions(image=dtype).
      opaque: an existing backing handle to adopt instead of allocating fresh memory.
      options: allocator options forwarded to alloc/free.
      initial_value: bytes copied in right after allocation; must be exactly `nbytes` long.
    """
    assert isinstance(dtype, DType)
    if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
    self.device, self.size, self.dtype, self.options = device, size, dtype, options
    # Allocate at most once, so `opaque` and `initial_value` can be combined (the adopted handle is
    # then overwritten with initial_value) instead of a second allocate() tripping its assert.
    if opaque is not None or initial_value is not None: self.allocate(opaque)
    if initial_value is not None: self.copyin(memoryview(initial_value))

  def ensure_allocated(self) -> Buffer:
    """Allocate the buffer if it has no backing memory yet; returns self either way."""
    return self.allocate() if not hasattr(self, '_buf') else self

  def allocate(self, opaque=None) -> Buffer:
    """Attach backing memory: adopt `opaque` if given, otherwise ask the device allocator for `nbytes`.

    Allocations on non-DISK devices are added to GlobalCounters.mem_used. Returns self.
    Raises AssertionError if the buffer is already allocated.
    """
    assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
    # imported here rather than at module level — presumably to avoid an import cycle with tinygrad.device
    from tinygrad.device import Device
    self.allocator = Device[self.device].allocator
    self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
    if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
    return self

  def __reduce__(self):
    """Pickle support: the buffer is rebuilt by calling the constructor again.

    An allocated NPY buffer passes its handle as `opaque`. Any other allocated buffer is copied out to
    host bytes and restored through `initial_value`. Unallocated buffers round-trip unallocated.
    """
    # only hand over the handle when one exists; an unallocated NPY buffer has no _buf to read
    if self.device == "NPY" and hasattr(self, '_buf'):
      return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options)
    buf = None
    if hasattr(self, '_buf'):
      buf = bytearray(self.nbytes)
      self.copyout(memoryview(buf))
    return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf)

  @property
  def nbytes(self) -> int:
    """Size of the buffer in bytes: element count times the dtype's itemsize."""
    return self.size*self.dtype.itemsize

  def __del__(self):
    """Free the backing memory (if any) and undo the GlobalCounters.mem_used accounting for non-DISK devices."""
    if not hasattr(self, '_buf'): return
    if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
    self.allocator.free(self._buf, self.nbytes, self.options)

  def __repr__(self):
    # the leading space in the options suffix keeps it from running into the dtype text
    return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
           (">" if self.options is None else f" {self.options=}>")

  def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
    """Return the buffer contents as a host memoryview.

    If zero copy is allowed or forced and the allocator exposes `as_buffer`, its view is returned
    directly; otherwise the contents are copied into a fresh bytearray.
    Raises AssertionError when force_zero_copy is set but no zero-copy path exists.
    """
    # zero copy with as_buffer (disabled by default due to use after free)
    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
    assert not force_zero_copy, "force zero copy was passed, but copy is required"
    return self.copyout(memoryview(bytearray(self.nbytes)))

  def copyin(self, mv:memoryview) -> Buffer:
    """Copy host data from `mv` (flattened via flat_mv) into the buffer; its byte length must equal nbytes. Returns self."""
    mv = flat_mv(mv)
    assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
    self.allocator.copyin(self._buf, mv)
    return self

  def copyout(self, mv:memoryview) -> memoryview:
    """Copy the buffer contents into host memory `mv` (flattened via flat_mv) and return that flattened view."""
    mv = flat_mv(mv)
    assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
    self.allocator.copyout(mv, self._buf)
    return mv