remove force_wait, refactor to graph (#2405)

* remove force_wait

* refactor

* get rid of stupid ASTRunner

* fix del in diskbuffer

* BufferOps.FROM_UNDERLYING

* put offset in the rawbuffer

* fix bugs

* use exec
George Hotz
2023-11-23 12:46:07 -08:00
committed by GitHub
parent c5d585ea35
commit 0505c5ea50
12 changed files with 143 additions and 172 deletions

extra/dist/world.py (vendored)
View File

@@ -12,8 +12,8 @@ from tinygrad.tensor import Tensor, Function
import extra.hip_wrapper as hip
import numpy as np
# fake the function signature of ASTRunner so we can put it in the cache
def __send_rb(args, variables=None, jit=False, force_wait=False):
# match the function signature of JITRunner so we can put it in the cache
def __send_rb(args, variables=None, wait=False, jit=False):
x, target_rank, y = args[:3]
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
hip.hipSetDevice(x._device)
@@ -24,7 +24,7 @@ def __send_rb(args, variables=None, jit=False, force_wait=False):
dist.OOB.send(None, target_rank)
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}")
def __recv_rb(args, variables=None, jit=False, force_wait=False):
def __recv_rb(args, variables=None, wait=False, jit=False):
x, target_rank, y = args[:3]
dist.OOB.recv(target_rank)
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
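The renamed comment above is the whole point of this hunk: anything stored in the JIT cache only has to be callable the way JITRunner.__call__ is, i.e. (args/rawbufs, variables, wait=False, jit=False). A tiny duck-typing sketch of that convention (standalone, hypothetical fake_send, no HIP or dist machinery):

# Sketch: the JIT replays cache entries as ji.prg(ji.rawbufs, var_vals, jit=True),
# so a plain function with a matching signature can stand in for a JITRunner.
def fake_send(args, variables=None, wait=False, jit=False):
    x, target_rank = args[0], args[1]
    print(f"would send {x!r} to rank {target_rank} (jit={jit})")

fake_send(["rawbuf", 1], variables={}, jit=True)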

View File

@@ -1,26 +0,0 @@
import unittest
from tinygrad.helpers import prod
from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters
from tinygrad.jit import CacheCollector
class TestCopy(unittest.TestCase):
def test_add1(self):
pts = []
for i in range(16384, 16384*256, 16384):
t = Tensor.randn(i).realize()
CacheCollector.start()
t.assign(t+1).realize()
ji = CacheCollector.finish()[0]
GlobalCounters.reset()
def run(): return ji.prg(ji.rawbufs, force_wait=True)
ct = min([run() for _ in range(10)])
mb = prod(t.shape)*t.dtype.itemsize*2*1e-6
print(f"{mb*1e3:.2f} kB, {ct*1e3:.2f} ms, {mb/ct:.2f} MB/s")
pts.append((mb, mb/ct))
from matplotlib import pyplot as plt
plt.plot([x[0] for x in pts], [x[1] for x in pts])
plt.show()
if __name__ == '__main__':
unittest.main()

View File

@@ -21,7 +21,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
if isinstance(device, Compiled):
prg = device.to_program(lin)
else:
prg = get_interpreted_fxn(device.fxn_for_op, device.from_underlying, lin.ast)
prg = get_interpreted_fxn(device.fxn_for_op, lin.ast)
except:
print(lin.ast)
traceback.print_exc()
@@ -29,7 +29,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
return "COMPILE_ERROR"
try:
prg.exec(rawbufs, var_vals, force_wait=True)
prg.exec(rawbufs, var_vals)
except:
print(lin.ast)
traceback.print_exc()

View File

@@ -349,13 +349,13 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
k.apply_opt(opt)
prg = to_prg(k)
real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(real_bufs, force_wait=True)
prg.exec(real_bufs)
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
# Get baseline, which is not optimized at all.
k = Linearizer(realized_ast)
prg = Device[Device.DEFAULT].to_program(k)
prg.exec(real_bufs, force_wait=True)
prg.exec(real_bufs)
wanna_output = real_bufs[0].toCPU().copy()
# Check correctness of handcoded optimizations.
@@ -363,7 +363,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
k.hand_coded_optimizations()
prg = Device[Device.DEFAULT].to_program(k)
real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(real_bufs, force_wait=True)
prg.exec(real_bufs)
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
for x in opts: # Check custom transformations if any.
check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_program)

View File

@@ -26,7 +26,7 @@ def _test_single_value(vals, op, dtype):
buf = Device[Device.DEFAULT].buffer(1, dtype)
buf2 = [Device[Device.DEFAULT].buffer.fromCPU(np.array([a], dtype=dtype.np)) for a in vals]
prg = _uops_to_prg(uops)
prg([buf]+buf2)
prg.exec([buf]+buf2)
return buf.toCPU()[0]
def _test_single_value_const(vals, op, dtype):
@@ -37,7 +37,7 @@ def _test_single_value_const(vals, op, dtype):
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Device[Device.DEFAULT].buffer(1, dtype)
prg = _uops_to_prg(uops)
prg([buf])
prg.exec([buf])
return buf.toCPU()[0]
class TestUOps(unittest.TestCase):

View File

@@ -1,32 +1,46 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.ops import RawBuffer, Device, ASTRunner, BatchExecutor, JitItem
import functools, itertools, operator
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv
from tinygrad.ops import RawBuffer, Device, JITRunner
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from tinygrad.shape.symbolic import Variable, NumNode, Node
from weakref import ref, WeakKeyDictionary
from dataclasses import dataclass
@dataclass(frozen=True)
class JitItem:
prg: JITRunner
rawbufs: List[Optional[RawBuffer]]
def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[Node, Node]:
return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0))
def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[RawBuffer]) -> Dict[Tuple[int, int], int]:
input_replace: Dict[Tuple[int, int], int] = {}
for j,ji in enumerate(jit_cache):
for i,a in enumerate(ji.rawbufs):
if a in input_rawbuffers:
input_replace[(j,i)] = input_rawbuffers.index(a)
assert len(set(input_replace.values())) == len(input_rawbuffers), "some input tensors not found"
return input_replace
class GraphException(Exception): pass
ReturnType = TypeVar('ReturnType')
class TinyJit(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType]):
self.fxn = fxn
self.reset()
def reset(self):
self.jit_fxn: Optional[BatchExecutor] = None
self.jit_cache: List[JitItem] = []
self.input_replace: Dict[Tuple[int, int], int] = {}
self.cnt: int = 0
self.ret: Optional[ReturnType] = None
self.expected_vals: Optional[Tuple[Variable, ...]] = None
self.expected_name_sts_dtype: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...]] = None
@property
def jit_cache(self) -> List[JitItem]: return self.jit_fxn.jit_cache if self.jit_fxn else []
@property
def input_replace(self) -> Dict[Tuple[int, int], int]: return self.jit_fxn.input_replace if self.jit_fxn else {}
# add support for instance methods
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
@@ -44,23 +58,35 @@ class TinyJit(Generic[ReturnType]):
expected_vals = tuple(var_vals.keys())
if self.cnt >= 2:
# jit exec
assert self.expected_vals == expected_vals, "mismatch of var_vals"
assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}"
assert self.jit_fxn, "didn't get jitted?"
self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2)
for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, jit=True)
elif self.cnt == 1:
# jit capture
self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype
CacheCollector.start(var_vals)
self.ret = self.fxn(*args, **kwargs)
jit_cache = CacheCollector.finish()
assert len(jit_cache) != 0, "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_rawbuffers)} inputs")
self.jit_cache = CacheCollector.finish()
assert len(self.jit_cache) != 0, "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
self.jit_fxn = Device[Device.DEFAULT].batch_executor(jit_cache, input_rawbuffers, var_vals)
# if your Device supports it, condense the items into a graph executor
if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2:
try:
self.jit_cache = [JitItem(make_graph(self.jit_cache, input_rawbuffers, var_vals), cast(List[Optional[RawBuffer]], input_rawbuffers))]
except GraphException as e:
if DEBUG >= 1: print(f"graph create failed {e}")
self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
elif self.cnt == 0:
# jit ignore
self.ret = self.fxn(*args, **kwargs)
# clear jit inputs
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
self.cnt += 1
return cast(ReturnType, self.ret)
@@ -77,7 +103,7 @@ class PlaceHolder:
class _CacheCollector:
def __init__(self):
self.cache: Optional[List[Tuple[ASTRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None
self.cache: Optional[List[Tuple[JITRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None
def start(self, var_vals:Optional[Dict[Variable, int]]=None):
self.cache = []
@@ -88,7 +114,7 @@ class _CacheCollector:
if self.cache is None: return
for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) # NOTE: this is making an assumption that 0 is special
self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(prg, ASTRunner) and isinstance(x, RawBuffer) else x for x in rawbufs]))
self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, RawBuffer) else x for x in rawbufs]))
def finish(self) -> List[JitItem]:
if self.cache is None: return []
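To make the new bookkeeping in this file concrete: get_input_replace records which (kernel index, rawbuf index) slots of the captured cache correspond to which call-time input, so TinyJit can patch fresh input buffers in before replay and clear them afterwards. A standalone sketch that mirrors the helper from the diff (FakeJitItem is a simplified stand-in, no tinygrad imports):

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@dataclass
class FakeJitItem:                      # stand-in for JitItem
    prg: object
    rawbufs: List[Optional[str]]

def get_input_replace(jit_cache, input_rawbuffers) -> Dict[Tuple[int, int], int]:
    # same logic as the helper above: map (kernel j, argument i) -> input index
    input_replace: Dict[Tuple[int, int], int] = {}
    for j, ji in enumerate(jit_cache):
        for i, a in enumerate(ji.rawbufs):
            if a in input_rawbuffers:
                input_replace[(j, i)] = input_rawbuffers.index(a)
    return input_replace

cache = [FakeJitItem("k0", ["out0", "inA"]), FakeJitItem("k1", ["out1", "out0", "inB"])]
print(get_input_replace(cache, ["inA", "inB"]))   # {(0, 1): 0, (1, 2): 1}
# on each jitted call TinyJit writes the new inputs into exactly these slots,
# then sets them back to None so inputs aren't kept alive between calls.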

View File

@@ -1,10 +1,10 @@
from __future__ import annotations
import importlib, inspect, functools, pathlib, time, re
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, Set
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
from tinygrad.runtime.lib import RawBuffer
from tinygrad.shape.symbolic import Variable, sym_infer, NumNode, sint
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from dataclasses import dataclass
# these are the llops your accelerator must implement, along with toCpu
@@ -15,7 +15,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto()
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
class BufferOps(Enum): MEM = auto(); CONST = auto(); FROM_UNDERLYING = auto() # noqa: E702
# Ops below this line are not allowed in ASTs
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
@@ -144,91 +144,52 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
GlobalCounters.global_mem += mem_estimate
if et is not None: GlobalCounters.time_sum_s += et
# **************** batch executor ****************
# **************** shared AST runner ****************
@dataclass(frozen=True)
class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]
class BatchExecutor:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
self.jit_cache: List[JitItem] = jit_cache
self.input_replace: Dict[Tuple[int, int], int] = {}
self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
self.op_estimate += ji.prg.op_estimate
self.mem_estimate += ji.prg.mem_estimate
for i,a in enumerate(ji.rawbufs):
if a in input_rawbuffers:
self.input_replace[(j,i)] = input_rawbuffers.index(a)
assert len(set(self.input_replace.values())) == len(input_rawbuffers), "some input tensors not found"
self.clear_jit_inputs()
def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False):
for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True)
self.clear_jit_inputs()
def clear_jit_inputs(self):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
class ASTRunner:
def __init__(self, ast:Optional[LazyOp]):
if ast is None:
self.op_estimate, self.mem_estimate, self.vars = 0, 0, set()
else:
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
from tinygrad.lazy import vars_from_ast
self.vars = vars_from_ast(ast)
assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
def exec(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
class JITRunner:
def __init__(self):
self.op_estimate, self.mem_estimate = 0, 0
def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
var_vals = var_vals if var_vals is not None else {}
from tinygrad.jit import CacheCollector
et = self(rawbufs, var_vals, force_wait=force_wait)
CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
et = self(rawbufs, var_vals)
CacheCollector.add(self, rawbufs, var_vals)
return et
def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
raise NotImplementedError("override this")
# **************** for Interpreted Buffers ****************
class InterpretedASTRunner(ASTRunner):
class InterpretedASTRunner(JITRunner):
def __init__(self, ast:LazyOp, fxn:Callable):
super().__init__()
self.fxn = fxn
super().__init__(ast)
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> float:
var_vals = {k:var_vals[k] for k in sorted(self.vars)} if var_vals is not None else {}
def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
st = time.perf_counter()
ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
et = time.perf_counter() - st
update_stats(f"<interpreted {ret.size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
if rawbufs[0] is not None:
assert rawbufs[0].dtype == ret.dtype
rawbufs[0].size = ret.size # NOTE: for symbolic this can change
rawbufs[0]._buf = ret._buf
else: rawbufs[0] = ret
assert getattr(rawbufs[0], 'dtype', ret.dtype) == ret.dtype
rawbufs[0].dtype, rawbufs[0].size, rawbufs[0]._buf, rawbufs[0].offset = ret.dtype, ret.size, ret._buf, ret.offset
return et
class Interpreted:
def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable]=None):
self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
self.synchronize = lambda: None
self.batch_executor = BatchExecutor
self.codegen = None
def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable]):
self.buffer, self.fxn_for_op = buffer, fxn_for_op
self.synchronize, self.codegen, self.graph = lambda: None, None, None
self.method_cache: Dict[LazyOp, InterpretedASTRunner] = {}
def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, self.from_underlying, ast)
if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, ast)
rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
if rawbufs[0] is None: rawbufs[0] = self.buffer.__new__(self.buffer)
self.method_cache[ast].exec(rawbufs, var_vals)
output.realized = rawbufs[0]
def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> InterpretedASTRunner:
def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
@@ -262,19 +223,26 @@ def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[
return ret
ret = _interpret_ast(ast)
src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f" return {ret}"])
src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(fxn_for_op[BufferOps.FROM_UNDERLYING], BufferOps.FROM_UNDERLYING)}({ret})" if BufferOps.FROM_UNDERLYING in fxn_for_op else f" return {ret}"])
if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
return InterpretedASTRunner(ast, tglob['run'])
# **************** for Compiled Buffers ****************
class CompiledASTRunner(ASTRunner):
class CompiledASTRunner(JITRunner):
def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
super().__init__()
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = \
name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
super().__init__(ast)
self.vars: Set[Variable] = set()
if ast:
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
from tinygrad.lazy import vars_from_ast
self.vars = vars_from_ast(ast)
assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
def build(self, compiler, runtime):
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
@@ -286,26 +254,25 @@ class CompiledASTRunner(ASTRunner):
local_size = ([sym_infer(sz, var_vals) for sz in self.local_size] + [1]*(3-len(self.local_size))) if self.local_size is not None else self.local_size
return global_size, local_size
def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# filter the var_vals
var_vals = {k:var_vals[k] for k in sorted(self.vars)} if var_vals is not None else {}
var_vals = {k:var_vals[k] for k in sorted(self.vars)}
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
# TODO: this is copied from get_program
from tinygrad.features.search import optimize_local_size
local_size = self.local_size = optimize_local_size(self.clprg, global_size, cast(List[RawBuffer], rawbufs))
local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
lra = self.runtime_args.copy()
if global_size: lra['global_size'] = global_size
if local_size and 'local_size' not in lra: lra['local_size'] = local_size
et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2)
et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=wait or DEBUG>=2)
update_stats(self.display_name if self.display_name is not None else self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
return et
class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
self.batch_executor = BatchExecutor if getenv("JIT") == 2 else batch_executor
def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, graph=None):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.graph = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, graph
self.method_cache: Dict[LazyOp, CompiledASTRunner] = {}
def to_program(self, k:Linearizer) -> CompiledASTRunner:
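The core of this file's refactor is the slimmer JITRunner contract: exec is the un-jitted entry point that runs once (and, in the real code, registers with the CacheCollector), while __call__(rawbufs, var_vals, wait, jit) does the actual work and optionally returns a wall time. A minimal standalone sketch of that contract with a hypothetical TimedRunner subclass (no tinygrad imports, no update_stats or CacheCollector):

import time
from typing import Dict, List, Optional

class JITRunnerSketch:                   # simplified stand-in for JITRunner
    def __init__(self): self.op_estimate, self.mem_estimate = 0, 0
    def exec(self, rawbufs: List, var_vals: Optional[Dict] = None) -> Optional[float]:
        var_vals = var_vals if var_vals is not None else {}
        return self(rawbufs, var_vals)   # the real exec also does CacheCollector.add(self, rawbufs, var_vals)
    def __call__(self, rawbufs: List, var_vals: Dict, wait=False, jit=False) -> Optional[float]:
        raise NotImplementedError("override this")

class TimedRunner(JITRunnerSketch):      # hypothetical example runner
    def __call__(self, rawbufs, var_vals, wait=False, jit=False):
        st = time.perf_counter()
        rawbufs[0] = sum(rawbufs[1:])    # pretend "kernel": write the output slot
        return time.perf_counter() - st if wait else None

bufs = [None, 2, 3]
print(TimedRunner().exec(bufs), bufs[0])   # None 5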

View File

@@ -10,6 +10,7 @@ class RawBuffer: # pylint: disable=abstract-method
def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
self.size: int = size
self.dtype: DType = dtype
self.offset: int = 0 # TODO: this is very unsupported, only in disk
self._buf = buf if buf is not None else (allocator(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
self._memsz: int = size*dtype.itemsize
self._allocator = allocator if allocator and hasattr(allocator, 'free') else None

View File

@@ -5,6 +5,12 @@ from tinygrad.helpers import dtypes, DType
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
from tinygrad.runtime.lib import RawBuffer
class RawNumpyBuffer(RawBuffer):
def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None, allocator=lambda size, dtype: np.empty([size], dtype.np)): super().__init__(size, dtype, buf, allocator)
@classmethod
def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
def toCPU(self): return self._buf
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
@@ -32,7 +38,7 @@ def einsum_mulacc(einsum, get_strides, expand):
return mulacc
numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), BufferOps.FROM_UNDERLYING: RawNumpyBuffer.fromCPU,
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
@@ -44,9 +50,4 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
TernaryOps.WHERE: np.where,
}}
class RawNumpyBuffer(RawBuffer):
def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None, allocator=lambda size, dtype: np.empty([size], dtype.np)): super().__init__(size, dtype, buf, allocator)
@classmethod
def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
def toCPU(self): return self._buf
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, RawNumpyBuffer.fromCPU)
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op)
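With from_underlying removed, the op table itself carries BufferOps.FROM_UNDERLYING, and the generated run() source wraps its final numpy value through that entry when it exists. A rough standalone illustration of the same exec/compile trick (simplified stand-ins, not tinygrad's actual codegen):

from enum import Enum, auto
import numpy as np

class BufferOps(Enum): MEM = auto(); CONST = auto(); FROM_UNDERLYING = auto()  # noqa: E702

# stand-in for RawNumpyBuffer.fromCPU: wrap the raw ndarray back into a "RawBuffer"
fxn_for_op = {BufferOps.FROM_UNDERLYING: lambda x: ("RawNumpyBuffer", x)}

tglob = {"wrap": fxn_for_op[BufferOps.FROM_UNDERLYING], "np": np}
src = "def run(inputs, var_vals):\n  return wrap(np.add(inputs[0], inputs[1]))"
exec(compile(src, "<ast>", "exec"), tglob)      # same exec/compile pattern as the diff
print(tglob["run"]([np.arange(3), np.ones(3)], {}))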

View File

@@ -9,10 +9,13 @@ from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps
from tinygrad.shape.view import strides_for_shape
MAP_LOCKED, MAP_POPULATE = 0x2000, 0x008000
class UnderlyingDiskBuffer:
def __init__(self, fd, mem): self.fd, self.mem = fd, mem
def __del__(self):
if self.fd: self.fd.close()
class RawDiskBuffer(RawBufferMapped):
def __init__(self, size, dtype:DType, device:Optional[str]=None, buf=None, shape=None, offset=0): # pylint: disable=super-init-not-called
self.shape = (size, ) if shape is None else shape
self.offset = offset # this is an offset in bytes
def __init__(self, size, dtype:DType, buf=None, device:Optional[str]=None, offset:int=0): # pylint: disable=super-init-not-called
assert device is not None or buf is not None, "disk tensor needs a path or a buf"
if device is not None:
if str(device).startswith("shm:"):
@@ -26,28 +29,23 @@ class RawDiskBuffer(RawBufferMapped):
shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | MAP_LOCKED | MAP_POPULATE)
shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX
os.close(fd)
buf = [None, shm, 1]
buf = UnderlyingDiskBuffer(None, shm)
else:
f = open(device, "a+b")
if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
buf = [f, mmap.mmap(f.fileno(), size * dtype.itemsize), 1]
else:
buf[2] += 1
buf = UnderlyingDiskBuffer(f, mmap.mmap(f.fileno(), size * dtype.itemsize))
# NOTE: we don't call super since disk tensors don't use RAM
self.size, self.dtype, self._buf = size, dtype, buf
def __del__(self):
self._buf[2] -= 1
if self._buf[2] == 0 and self._buf[0] is not None: self._buf[0].close()
def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
self.size, self.dtype, self._buf, self.offset = size, dtype, buf, offset
def cast(self, arg:Tuple[DType, bool]):
return RawDiskBuffer(self.size, arg[0], self._buf, offset=self.offset)
def as_strided(self, arg):
assert strides_for_shape(arg[0]) == arg[1], "disk tensors don't support strides"
return RawDiskBuffer(prod(arg[0]), self.dtype, buf=self._buf, offset=self.offset+arg[2]*self.dtype.itemsize, shape=arg[0])
def _buffer(self): return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]
return RawDiskBuffer(prod(arg[0]), self.dtype, self._buf, offset=self.offset+arg[2]*self.dtype.itemsize)
def _buffer(self): return memoryview(self._buf.mem)[self.offset:self.offset+self.size*self.dtype.itemsize]
def readinto(self, buf:memoryview):
if self._buf[0] is not None:
self._buf[0].seek(self.offset)
self._buf[0].readinto(buf)
if self._buf.fd is not None:
self._buf.fd.seek(self.offset)
self._buf.fd.readinto(buf)
else:
buf.cast('B')[:] = self._buffer()
def transfer(self, cls, shape, dtype, **kwargs):
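The disk change replaces the manual [file, mmap, refcount] list with an UnderlyingDiskBuffer whose __del__ closes the file once the last view is garbage collected; cast/as_strided views share that one object and only carry their own byte offset. A standalone sketch of that ownership model (View is a simplified stand-in for RawDiskBuffer):

import mmap, tempfile

class UnderlyingDiskBuffer:
    def __init__(self, fd, mem): self.fd, self.mem = fd, mem
    def __del__(self):
        if self.fd: self.fd.close()

f = tempfile.NamedTemporaryFile(delete=False); f.write(b"\x00" * 16); f.flush()
underlying = UnderlyingDiskBuffer(f, mmap.mmap(f.fileno(), 16))

class View:                                   # stand-in for RawDiskBuffer
    def __init__(self, buf, size, offset=0): self._buf, self.size, self.offset = buf, size, offset
    def _buffer(self): return memoryview(self._buf.mem)[self.offset:self.offset + self.size]

# both views share `underlying`; the mmap/file stay open until both are collected
whole, tail = View(underlying, 16), View(underlying, 8, offset=8)
print(len(whole._buffer()), len(tail._buffer()))   # 16 8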

View File

@@ -1,13 +1,14 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
import os, subprocess, pathlib, ctypes, tempfile
import Metal, libdispatch
from typing import List, Any, Tuple, Dict, Set, cast
from typing import List, Any, Tuple, Dict, Set, cast, Optional
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner, update_stats
from tinygrad.ops import Compiled, CompiledASTRunner, update_stats
from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
from tinygrad.shape.symbolic import Variable, Node
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, GraphException
class MetalAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs):
@@ -23,7 +24,6 @@ class _METAL:
def __init__(self):
self.mtl_buffers_in_flight: List[Any] = []
self.device = Metal.MTLCreateSystemDefaultDevice()
self.supports_icb = (self.device.supportsFamily_(Metal.MTLGPUFamilyMac2) or self.device.supportsFamily_(Metal.MTLGPUFamilyApple3) or self.device.supportsFamily_(Metal.MTLGPUFamilyCommon2)) and self.device.argumentBuffersSupport() is Metal.MTLArgumentBuffersTier2
self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
self.allocator = MetalAllocator(self.device.dedicatedMemorySize() or self.device.sharedMemorySize())
# TODO: is there a better way to do this?
@@ -84,9 +84,11 @@ class MetalProgram:
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
METAL.mtl_buffers_in_flight.append(command_buffer)
class MetalBatchExecutor(BatchExecutor):
class MetalGraph:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
# create metal batch exec
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
@@ -95,7 +97,7 @@ class MetalBatchExecutor(BatchExecutor):
icb_descriptor.setInheritPipelineState_(False)
icb_descriptor.setMaxKernelBufferBindCount_(31)
self.icb = METAL.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0))
assert self.icb is not None, "create indirect command buffer failed, does your system support this?"
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
self.int_buf = RawMetalBuffer(len(var_vals), dtypes.int32)
self.input_has_variable_dims: Set[int] = set()
@@ -127,7 +129,7 @@ class MetalBatchExecutor(BatchExecutor):
self.command_buffer: Any = None
self.int_buf_view = self.int_buf.buffer_view() # TODO: this is metal syncing when it doesn't need to
def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False):
def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# NOTE: you at least can't update the ints if this is running
if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers]
@@ -151,7 +153,7 @@ class MetalBatchExecutor(BatchExecutor):
else:
METAL.mtl_buffers_in_flight.append(command_buffer)
et = None
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=True, num_kernels=len(self.jit_cache))
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
return et
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if METAL.supports_icb else BatchExecutor)
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, graph=MetalGraph)
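The batch_executor plumbing becomes a plain graph factory: a backend passes it via Compiled(..., graph=...), TinyJit tries to condense the captured cache into one item with it, and a GraphException (rather than an assert) triggers the per-kernel fallback. A simplified standalone sketch of that fallback contract (SUPPORTS_ICB and make_graph are hypothetical stand-ins):

class GraphException(Exception): pass

SUPPORTS_ICB = False                      # hypothetical capability flag

def make_graph(jit_cache, input_rawbuffers, var_vals):
    if not SUPPORTS_ICB:
        raise GraphException("create indirect command buffer failed")
    return lambda inputs, var_vals, wait=False, jit=False: None   # one batched runner

jit_cache = ["kernel0", "kernel1"]
try:
    jit_cache = [make_graph(jit_cache, [], {})]   # whole cache condensed into one entry
except GraphException as e:
    print(f"graph create failed {e}")             # keep the original per-kernel cache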

View File

@@ -10,6 +10,14 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if geten
type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
inverse_type_map = {v:k for k,v in type_map.items()}
class RawTorchBuffer(RawBuffer):
def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None, allocator=lambda size, dtype: torch.empty([size], device=device, dtype=inverse_type_map[dtype])): super().__init__(size, dtype, buf, allocator)
@classmethod
def fromCPU(cls, x):
buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
return cls(prod(x.shape), type_map[buf.dtype], buf)
def toCPU(self): return self._buf.cpu().numpy()
def output_type(x, y): return x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype
def match_types(x, y, disallow_bool=False):
up = output_type(x, y)
@@ -26,6 +34,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
# TODO: torch.tensor should work here
#BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).requires_grad_(False).to(device),
BufferOps.FROM_UNDERLYING: lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x),
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
@@ -39,11 +48,4 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
}}
class RawTorchBuffer(RawBuffer):
def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None, allocator=lambda size, dtype: torch.empty([size], device=device, dtype=inverse_type_map[dtype])): super().__init__(size, dtype, buf, allocator)
@classmethod
def fromCPU(cls, x):
buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
return cls(prod(x.shape), type_map[buf.dtype], buf)
def toCPU(self): return self._buf.cpu().numpy()
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op)