disk runner (prereq for interpreted removal) (#3421)

* disk runner

* simpler diskrunner
George Hotz authored 2024-02-15 18:14:05 +01:00, committed by GitHub
commit 5de660ca0d, parent e1a57fe58a
2 changed files with 37 additions and 17 deletions


@@ -1,9 +1,9 @@
-import os, mmap, _posixshmem, io
-from typing import Callable, Dict, Tuple
+import os, mmap, _posixshmem, io, functools
+from typing import Dict, List, Any
 from tinygrad.dtype import DType, dtypes
 from tinygrad.helpers import prod, OSX
-from tinygrad.device import Interpreted, Allocator
-from tinygrad.ops import Op, MovementOps, UnaryOps
+from tinygrad.device import Compiled, Allocator, JITRunner, Buffer
+from tinygrad.ops import UnaryOps, LazyOp, BufferOps
 from tinygrad.shape.view import strides_for_shape
 class UnderlyingDiskBuffer:
@@ -15,17 +15,8 @@ class DiskBuffer:
   def __init__(self, ud:UnderlyingDiskBuffer, size:int, dtype:DType=dtypes.uint8, offset=0):
     self.ud, self.size, self.dtype, self.offset = ud, size, dtype, offset
   def __repr__(self): return f"<DiskBuffer size={self.size} dtype={self.dtype} offset={self.offset}>"
-  def cast(self, arg:Tuple[DType, bool]):
-    # TODO: support shape changing bitcast
-    #assert arg[1], "DiskTensor only supports bitcast"
-    return DiskBuffer(self.ud, self.size, arg[0], offset=self.offset)
-  def as_strided(self, arg):
-    assert strides_for_shape(arg[0]) == arg[1], "disk tensors don't support strides"
-    return DiskBuffer(self.ud, prod(arg[0]), self.dtype, offset=self.offset+arg[2]*self.dtype.itemsize)
   def _buf(self) -> memoryview: return memoryview(self.ud.mem)[self.offset:self.offset+self.size*self.dtype.itemsize]
-disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.CAST: DiskBuffer.cast, MovementOps.AS_STRIDED: DiskBuffer.as_strided }
 MAP_LOCKED, MAP_POPULATE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000)
 class DiskAllocator(Allocator):
   def __init__(self, device:str): self.device = device
@@ -53,5 +44,34 @@ class DiskAllocator(Allocator):
     else:
       dest[:] = src._buf()
-class DiskDevice(Interpreted):
-  def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), disk_fxn_for_op)
+class DiskRunner(JITRunner):
+  skip_allocation = True
+  def __init__(self, ast:LazyOp):
+    # two ASTs are allowed here.
+    assert ast.op == BufferOps.STORE, "output of AST must be store"
+    assert ast.arg.st.contiguous, "shapetracker must be contiguous"
+    # TODO: there shouldn't actually be casts here, bitcasts should fold into the load
+    if ast.src[0].op == UnaryOps.CAST:
+      top_src = ast.src[0].src[0]
+      # TODO: assert that this is bitcast
+      self.new_dtype = ast.src[0].arg[0]
+    else:
+      top_src = ast.src[0]
+      self.new_dtype = top_src.arg.dtype
+    assert top_src.op == BufferOps.LOAD, "top of AST must be load"
+    assert len(top_src.arg.st.views) == 1, "shapetracker must have 1 view"
+    view = top_src.arg.st.views[0]
+    assert view.mask is None, "view cannot have a mask"
+    assert strides_for_shape(view.shape) == view.strides, "disk tensors don't support strides"
+    self.new_size = prod(view.shape)
+    self.new_offset = view.offset
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False, jit=False):
+    assert len(rawbufs) == 2
+    src = rawbufs[1]._buf
+    # TODO: src.dtype.itemsize or self.new_dtype.itemsize?
+    rawbufs[0]._buf = DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset*src.dtype.itemsize)
+class DiskDevice(Compiled):
+  def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), None, None)
+  @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
+  def get_runner(self, ast:LazyOp): return DiskRunner(ast)
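
With this change the disk backend no longer dispatches per-op through the Interpreted table (disk_fxn_for_op); DiskDevice.get_runner hands the whole STORE(LOAD) or STORE(CAST(LOAD)) AST to DiskRunner, which only computes a new size, dtype, and offset and re-points the output Buffer at a DiskBuffer view of the same mmap. A minimal sketch of how this is exercised from the Tensor API (the file path is illustrative, and the exact point at which DiskRunner fires depends on the scheduler at this commit):

# minimal sketch, assuming a tinygrad checkout at this commit
from tinygrad.tensor import Tensor
from tinygrad.device import Device
from tinygrad.dtype import dtypes

# /tmp/scratch.bin is an illustrative path; DiskAllocator creates and mmaps it
t = Tensor.empty(16, dtype=dtypes.uint8, device="disk:/tmp/scratch.bin")

# a contiguous slice of a disk tensor is the plain STORE(LOAD) case:
# DiskRunner computes the new offset/size and rebinds the output Buffer's
# _buf to a DiskBuffer over the same mmap, so no bytes move on the disk device
chunk = t[4:12]

# data is only copied when it leaves the disk device
out = chunk.to(Device.DEFAULT).realize()

The STORE(CAST(LOAD)) branch covers the bitcast case (e.g. reading raw bytes from a weight file as another dtype), which is why new_dtype is taken from the CAST's arg rather than from the load.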