From 5de660ca0dc947922f06817e9b80d7797c389cfa Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 15 Feb 2024 18:14:05 +0100 Subject: [PATCH] disk runner (prereq for interpreted removal) (#3421) * disk runner * simpler diskrunner --- tinygrad/realize.py | 4 +-- tinygrad/runtime/ops_disk.py | 50 +++++++++++++++++++++++++----------- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/tinygrad/realize.py b/tinygrad/realize.py index cc48e2e66f..beb15a479e 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -2,7 +2,7 @@ import sys from collections import defaultdict from typing import List, Dict, Optional, cast, Set, DefaultDict from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps -from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, InterpretedASTRunner, Compiled, BufferOptions +from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, Compiled, BufferOptions from tinygrad.features.graph import print_tree, realized_lazybuffer, log_lazybuffer from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG, flatten, prod, dedup, all_int from tinygrad.shape.symbolic import Variable @@ -67,7 +67,7 @@ def run_schedule(schedule:List[ScheduleItem]): if si.out.size > 0: options = BufferOptions(host=True, signal=True) if si.ast.op is LoadOps.SYNC else None si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \ - Buffer(si.out.device, si.out.size, si.out.dtype, "PLACEHOLDER" if isinstance(prg, InterpretedASTRunner) else None, options=options) + Buffer(si.out.device, si.out.size, si.out.dtype, "PLACEHOLDER" if getattr(prg, "skip_allocation", False) else None, options=options) del si.out.srcs # run the function (put it in JIT) diff --git a/tinygrad/runtime/ops_disk.py 
b/tinygrad/runtime/ops_disk.py index 1c8627964f..281a2f2716 100644 --- a/tinygrad/runtime/ops_disk.py +++ b/tinygrad/runtime/ops_disk.py @@ -1,9 +1,9 @@ -import os, mmap, _posixshmem, io -from typing import Callable, Dict, Tuple +import os, mmap, _posixshmem, io, functools +from typing import Dict, List, Any from tinygrad.dtype import DType, dtypes from tinygrad.helpers import prod, OSX -from tinygrad.device import Interpreted, Allocator -from tinygrad.ops import Op, MovementOps, UnaryOps +from tinygrad.device import Compiled, Allocator, JITRunner, Buffer +from tinygrad.ops import UnaryOps, LazyOp, BufferOps from tinygrad.shape.view import strides_for_shape class UnderlyingDiskBuffer: @@ -15,17 +15,8 @@ class DiskBuffer: def __init__(self, ud:UnderlyingDiskBuffer, size:int, dtype:DType=dtypes.uint8, offset=0): self.ud, self.size, self.dtype, self.offset = ud, size, dtype, offset def __repr__(self): return f"<DiskBuffer size={self.size} dtype={self.dtype} offset={self.offset}>" - def cast(self, arg:Tuple[DType, bool]): # TODO: support shape changing bitcast #assert arg[1], "DiskTensor only supports bitcast" return DiskBuffer(self.ud, self.size, arg[0], offset=self.offset) - def as_strided(self, arg): assert strides_for_shape(arg[0]) == arg[1], "disk tensors don't support strides" return DiskBuffer(self.ud, prod(arg[0]), self.dtype, offset=self.offset+arg[2]*self.dtype.itemsize) def _buf(self) -> memoryview: return memoryview(self.ud.mem)[self.offset:self.offset+self.size*self.dtype.itemsize] -disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.CAST: DiskBuffer.cast, MovementOps.AS_STRIDED: DiskBuffer.as_strided } - MAP_LOCKED, MAP_POPULATE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000) class DiskAllocator(Allocator): def __init__(self, device:str): self.device = device @@ -53,5 +44,34 @@ class DiskAllocator(Allocator): else: dest[:] = src._buf() -class DiskDevice(Interpreted): - def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), disk_fxn_for_op) \ No
newline at end of file +class DiskRunner(JITRunner): + skip_allocation = True + def __init__(self, ast:LazyOp): + # two ASTs are allowed here. + assert ast.op == BufferOps.STORE, "output of AST must be store" + assert ast.arg.st.contiguous, "shapetracker must be contiguous" + # TODO: there shouldn't actually be casts here, bitcasts should fold into the load + if ast.src[0].op == UnaryOps.CAST: + top_src = ast.src[0].src[0] + # TODO: assert that this is bitcast + self.new_dtype = ast.src[0].arg[0] + else: + top_src = ast.src[0] + self.new_dtype = top_src.arg.dtype + assert top_src.op == BufferOps.LOAD, "top of AST must be load" + assert len(top_src.arg.st.views) == 1, "shapetracker must have 1 view" + view = top_src.arg.st.views[0] + assert view.mask is None, "view cannot have a mask" + assert strides_for_shape(view.shape) == view.strides, "disk tensors don't support strides" + self.new_size = prod(view.shape) + self.new_offset = view.offset + def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False, jit=False): + assert len(rawbufs) == 2 + src = rawbufs[1]._buf + # TODO: src.dtype.itemsize or self.new_dtype.itemsize? + rawbufs[0]._buf = DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset*src.dtype.itemsize) + +class DiskDevice(Compiled): + def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), None, None) + @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none + def get_runner(self, ast:LazyOp): return DiskRunner(ast)