mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 09:05:40 -05:00
disk runner (prereq for interpreted removal) (#3421)
* disk runner * simpler diskrunner
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import os, mmap, _posixshmem, io
|
||||
from typing import Callable, Dict, Tuple
|
||||
import os, mmap, _posixshmem, io, functools
|
||||
from typing import Dict, List, Any
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.helpers import prod, OSX
|
||||
from tinygrad.device import Interpreted, Allocator
|
||||
from tinygrad.ops import Op, MovementOps, UnaryOps
|
||||
from tinygrad.device import Compiled, Allocator, JITRunner, Buffer
|
||||
from tinygrad.ops import UnaryOps, LazyOp, BufferOps
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
|
||||
class UnderlyingDiskBuffer:
|
||||
@@ -15,17 +15,8 @@ class DiskBuffer:
|
||||
def __init__(self, ud:UnderlyingDiskBuffer, size:int, dtype:DType=dtypes.uint8, offset=0):
|
||||
self.ud, self.size, self.dtype, self.offset = ud, size, dtype, offset
|
||||
def __repr__(self): return f"<DiskBuffer size={self.size} dtype={self.dtype} offset={self.offset}>"
|
||||
def cast(self, arg:Tuple[DType, bool]):
|
||||
# TODO: support shape changing bitcast
|
||||
#assert arg[1], "DiskTensor only supports bitcast"
|
||||
return DiskBuffer(self.ud, self.size, arg[0], offset=self.offset)
|
||||
def as_strided(self, arg):
|
||||
assert strides_for_shape(arg[0]) == arg[1], "disk tensors don't support strides"
|
||||
return DiskBuffer(self.ud, prod(arg[0]), self.dtype, offset=self.offset+arg[2]*self.dtype.itemsize)
|
||||
def _buf(self) -> memoryview: return memoryview(self.ud.mem)[self.offset:self.offset+self.size*self.dtype.itemsize]
|
||||
|
||||
disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.CAST: DiskBuffer.cast, MovementOps.AS_STRIDED: DiskBuffer.as_strided }
|
||||
|
||||
MAP_LOCKED, MAP_POPULATE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000)
|
||||
class DiskAllocator(Allocator):
|
||||
def __init__(self, device:str): self.device = device
|
||||
@@ -53,5 +44,34 @@ class DiskAllocator(Allocator):
|
||||
else:
|
||||
dest[:] = src._buf()
|
||||
|
||||
class DiskDevice(Interpreted):
|
||||
def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), disk_fxn_for_op)
|
||||
class DiskRunner(JITRunner):
|
||||
skip_allocation = True
|
||||
def __init__(self, ast:LazyOp):
|
||||
# two ASTs are allowed here.
|
||||
assert ast.op == BufferOps.STORE, "output of AST must be store"
|
||||
assert ast.arg.st.contiguous, "shapetracker must be contiguous"
|
||||
# TODO: there shouldn't actually be casts here, bitcasts should fold into the load
|
||||
if ast.src[0].op == UnaryOps.CAST:
|
||||
top_src = ast.src[0].src[0]
|
||||
# TODO: assert that this is bitcast
|
||||
self.new_dtype = ast.src[0].arg[0]
|
||||
else:
|
||||
top_src = ast.src[0]
|
||||
self.new_dtype = top_src.arg.dtype
|
||||
assert top_src.op == BufferOps.LOAD, "top of AST must be load"
|
||||
assert len(top_src.arg.st.views) == 1, "shapetracker must have 1 view"
|
||||
view = top_src.arg.st.views[0]
|
||||
assert view.mask is None, "view cannot have a mask"
|
||||
assert strides_for_shape(view.shape) == view.strides, "disk tensors don't support strides"
|
||||
self.new_size = prod(view.shape)
|
||||
self.new_offset = view.offset
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False, jit=False):
|
||||
assert len(rawbufs) == 2
|
||||
src = rawbufs[1]._buf
|
||||
# TODO: src.dtype.itemsize or self.new_dtype.itemsize?
|
||||
rawbufs[0]._buf = DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset*src.dtype.itemsize)
|
||||
|
||||
class DiskDevice(Compiled):
|
||||
def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), None, None)
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def get_runner(self, ast:LazyOp): return DiskRunner(ast)
|
||||
|
||||
Reference in New Issue
Block a user