mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
no more underlying diskbuffer, that's just the device (#4129)
This commit is contained in:
@@ -33,7 +33,7 @@ def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, i
|
||||
|
||||
for out in si.outputs:
|
||||
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
|
||||
if out.size > 0 and not (out.device.startswith("DISK") and si.ast[0].op is BufferOps.STORE) and not hasattr(out, "_buf"): out.allocate()
|
||||
if out.size > 0 and not hasattr(out, "_buf"): out.allocate()
|
||||
|
||||
# run the function (put it in JIT)
|
||||
real_buffers = [x for x in si.outputs+si.inputs if x.size != 0]
|
||||
|
||||
@@ -1,44 +1,32 @@
|
||||
from __future__ import annotations
|
||||
import os, mmap, _posixshmem, io, functools
|
||||
from typing import Dict, List, Any
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from typing import Dict, List, Any, Optional
|
||||
from tinygrad.helpers import prod, OSX
|
||||
from tinygrad.device import Compiled, Allocator, JITRunner, Buffer
|
||||
from tinygrad.ops import UnaryOps, LazyOp, BufferOps
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
|
||||
class UnderlyingDiskBuffer:
  """Holds a file descriptor together with the memory object that goes with it.

  `fd` may be None; in that case finalization has nothing to close.
  """
  def __init__(self, fd, mem):
    self.fd = fd
    self.mem = mem

  def __del__(self):
    # nothing to release when no descriptor was handed in
    if self.fd is None: return
    os.close(self.fd)
|
||||
|
||||
class DiskBuffer:
  """A window onto the memory mapped by a disk device.

  device: object that exposes the mapping as its `mem` attribute once it has been opened.
  size:   length of the window, in elements of `memoryview(device.mem)`.
  offset: start of the window inside the mapping (default 0).

  The earlier (ud, size, dtype, offset) definitions of these methods were shadowed by the
  definitions below, so only the effective ones are kept.
  """
  def __init__(self, device:DiskDevice, size:int, offset=0):
    self.device, self.size, self.offset = device, size, offset

  def __repr__(self): return f"<DiskBuffer size={self.size} offset={self.offset}>"

  def _buf(self) -> memoryview:
    """Return the [offset, offset+size) slice of the device mapping.

    Raises AssertionError when the device has no mapping yet.
    """
    # `mem` only exists on the device after it has been opened; use getattr so a never-opened
    # device reports the assertion message below instead of an AttributeError
    mem = getattr(self.device, "mem", None)
    assert mem is not None, "DiskBuffer wasn't opened"
    return memoryview(mem)[self.offset:self.offset+self.size]
|
||||
|
||||
# Extra mmap flags, OR-ed into the flags of the shared-memory mapping in this file.
# MAP_LOCKED is 0 on OSX and the literal 0x2000 elsewhere. MAP_POPULATE is taken from the
# mmap module when it defines one; otherwise it falls back to 0 on OSX and 0x008000 elsewhere.
MAP_LOCKED, MAP_POPULATE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000)
|
||||
class DiskAllocator(Allocator):
|
||||
def __init__(self, device:DiskDevice):
  """Keep a reference to the device this allocator works against."""
  self.device = device
|
||||
def _alloc(self, size:int, options):
  """Return a DiskBuffer of `size` over this allocator's device.

  Opening and mapping the backing storage is the device's job: `_might_open` is called first
  so the device can open itself on first use and account for this buffer. `options` is not used.

  The allocator used to open the file or shared memory itself from a string path; that code
  treated `self.device` as a str, which no longer matches the DiskDevice stored by `__init__`.
  """
  self.device._might_open(size)
  return DiskBuffer(self.device, size)
|
||||
def _free(self, buf, options):
  """Tell the device that one buffer is gone; `buf` and `options` themselves are not used."""
  self.device._might_close()
|
||||
def as_buffer(self, src:DiskBuffer):
  """Return whatever `src._buf()` yields, without copying it."""
  view = src._buf()
  return view
|
||||
def copyin(self, dest:DiskBuffer, src:memoryview):
  """Overwrite the whole of `dest._buf()` with the contents of `src`."""
  target = dest._buf()
  target[:] = src
|
||||
def copyout(self, dest:memoryview, src:DiskBuffer):
|
||||
if OSX and src.ud.fd is not None:
|
||||
if OSX and hasattr(self.device, 'fd'):
|
||||
# OSX doesn't seem great at mmap, this is faster
|
||||
with io.FileIO(src.ud.fd, "a+b", closefd=False) as fo:
|
||||
with io.FileIO(self.device.fd, "a+b", closefd=False) as fo:
|
||||
fo.seek(src.offset)
|
||||
fo.readinto(dest)
|
||||
else:
|
||||
@@ -66,11 +54,36 @@ class DiskRunner(JITRunner):
|
||||
self.new_offset = view.offset * top_src.arg.dtype.itemsize
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False, jit=False):
  """Point the output buffer at a shifted position of the input buffer.

  rawbufs must be exactly [output, input]. The output's existing `_buf` has its `offset`
  rewritten to the input's offset plus `self.new_offset`; nothing is copied. `var_vals`,
  `wait` and `jit` are accepted but not used here.

  This replaces the former path that built a new DiskBuffer from `src.ud`: DiskBuffer no
  longer carries a `ud` attribute or a dtype argument, so that call could not succeed.
  """
  assert len(rawbufs) == 2
  # TODO: this is a terrible hack that should be moved to lazy.py
  rawbufs[0]._buf.offset = rawbufs[1]._buf.offset+self.new_offset
|
||||
|
||||
class DiskDevice(Compiled):
|
||||
def __init__(self, device:str):
  """Start unopened (no size, zero buffers counted) and register a DiskAllocator bound to this device."""
  self.count = 0
  self.size: Optional[int] = None
  # the allocator receives the device object itself rather than a path string
  super().__init__(device, DiskAllocator(self), None, None)
|
||||
def _might_open(self, size):
  """Count one more user of the device, opening and mapping the backing storage on first use.

  While `self.size` is None the device is considered unopened: the name that follows the
  "disk:" prefix of `self.dname` is opened (POSIX shared memory when it starts with "shm:",
  a regular file otherwise) and `size` bytes of it are mapped into `self.mem`. For a regular
  file the descriptor is kept in `self.fd`, and the file is grown to `size` if it is shorter.
  When the device is already open only `self.count` changes.

  Raises AssertionError when the device is already open with a size smaller than `size`.
  """
  # validate before counting, so a rejected request does not leave a reference nobody releases
  assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"
  if self.size is None:
    filename = self.dname[len("disk:"):]
    if filename.startswith("shm:"):
      fd = _posixshmem.shm_open("/"+filename[4:].lstrip("/"), os.O_RDWR, 0o600)
      # the descriptor is not kept for shared memory; release it even when mmap raises
      try: mem = mmap.mmap(fd, size, mmap.MAP_SHARED | MAP_POPULATE | MAP_LOCKED)
      finally: os.close(fd)
    else:
      # O_DIRECT is not requested on OSX, and any OSError falls back to a plain open
      try: fd = os.open(filename, os.O_RDWR|os.O_CREAT|(0 if OSX else os.O_DIRECT))
      except OSError: fd = os.open(filename, os.O_RDWR|os.O_CREAT)
      try:
        if os.fstat(fd).st_size < size: os.ftruncate(fd, size)
        mem = mmap.mmap(fd, size)
        if (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None: mem.madvise(hp) # type: ignore
      except BaseException:
        # do not leak the descriptor when sizing or mapping fails
        os.close(fd)
        raise
      self.fd = fd
    # publish state only once the mapping exists, so a failed open leaves the device unopened
    self.mem, self.size = mem, size
  self.count += 1
|
||||
def _might_close(self):
  """Drop one user of the device; when none are left, close the descriptor and mark the device unopened."""
  self.count -= 1
  if self.count == 0:
    # only the regular-file path of _might_open stores a descriptor on the device; for "shm:"
    # names there is no `fd` attribute, and closing unconditionally raised AttributeError
    if hasattr(self, "fd"): os.close(self.fd)
    self.size = None
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def get_runner(self, *ast:LazyOp):
|
||||
assert len(ast) == 1, "DiskRunner doesn't support multioutput kernels."
|
||||
|
||||
Reference in New Issue
Block a user