remove force_wait, refactor to graph (#2405)
* remove force_wait
* refactor
* get rid of stupid ASTRunner
* fix del in diskbuffer
* BufferOps.FROM_UNDERLYING
* put offset in the rawbuffer
* fix bugs
* use exec
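The heart of the change: ASTRunner and BatchExecutor are folded into a single JITRunner base class with one shared call convention, and per-device batch executors become optional graph factories. A condensed sketch of the new interface, paraphrased from the tinygrad/ops.py hunk below (not the verbatim class):

    class JITRunner:
      def __init__(self):
        self.op_estimate, self.mem_estimate = 0, 0
      def exec(self, rawbufs, var_vals=None):
        # run the kernel and also register the call with the JIT CacheCollector
        ...
      def __call__(self, rawbufs, var_vals, wait=False, jit=False):
        # just run the kernel; wait=True blocks and returns elapsed seconds
        raise NotImplementedError("override this")

The old force_wait kwarg becomes plain wait, moved before jit in every signature.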
extra/dist/world.py (vendored) | 6
@@ -12,8 +12,8 @@ from tinygrad.tensor import Tensor, Function
 import extra.hip_wrapper as hip
 import numpy as np

-# fake the function signature of ASTRunner so we can put it in the cache
-def __send_rb(args, variables=None, jit=False, force_wait=False):
+# match the function signature of JITRunner so we can put it in the cache
+def __send_rb(args, variables=None, wait=False, jit=False):
   x, target_rank, y = args[:3]
   if RawHIPBuffer and x.__class__ is RawHIPBuffer:
     hip.hipSetDevice(x._device)
@@ -24,7 +24,7 @@ def __send_rb(args, variables=None, jit=False, force_wait=False):
   dist.OOB.send(None, target_rank)
   if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}")

-def __recv_rb(args, variables=None, jit=False, force_wait=False):
+def __recv_rb(args, variables=None, wait=False, jit=False):
   x, target_rank, y = args[:3]
   dist.OOB.recv(target_rank)
   if RawHIPBuffer and x.__class__ is RawHIPBuffer:
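The signature match matters because TinyJit replays every cached item with one fixed calling convention; per the tinygrad/jit.py hunk later in this diff, replay is:

    # from TinyJit.__call__ below: every cached prg, including these fakes, gets this call
    for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, jit=True)

so __send_rb/__recv_rb must accept (args, variables, wait, jit) even though they ignore most of it.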
test/external/external_copy_benchmark.py (vendored) | 26
@@ -1,26 +0,0 @@
-import unittest
-from tinygrad.helpers import prod
-from tinygrad.tensor import Tensor
-from tinygrad.helpers import GlobalCounters
-from tinygrad.jit import CacheCollector
-
-class TestCopy(unittest.TestCase):
-  def test_add1(self):
-    pts = []
-    for i in range(16384, 16384*256, 16384):
-      t = Tensor.randn(i).realize()
-      CacheCollector.start()
-      t.assign(t+1).realize()
-      ji = CacheCollector.finish()[0]
-      GlobalCounters.reset()
-      def run(): return ji.prg(ji.rawbufs, force_wait=True)
-      ct = min([run() for _ in range(10)])
-      mb = prod(t.shape)*t.dtype.itemsize*2*1e-6
-      print(f"{mb*1e3:.2f} kB, {ct*1e3:.2f} ms, {mb/ct:.2f} MB/s")
-      pts.append((mb, mb/ct))
-    from matplotlib import pyplot as plt
-    plt.plot([x[0] for x in pts], [x[1] for x in pts])
-    plt.show()
-
-if __name__ == '__main__':
-  unittest.main()
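The deleted benchmark depended on the removed force_wait kwarg. A hypothetical port of its timing loop to the new interface (untested; waiting moved into __call__):

    def run(): return ji.prg(ji.rawbufs, {}, wait=True)  # __call__(rawbufs, var_vals, wait=...)
    ct = min([run() for _ in range(10)])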
test/external/fuzz_linearizer.py (vendored) | 4
@@ -21,7 +21,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
     if isinstance(device, Compiled):
       prg = device.to_program(lin)
     else:
-      prg = get_interpreted_fxn(device.fxn_for_op, device.from_underlying, lin.ast)
+      prg = get_interpreted_fxn(device.fxn_for_op, lin.ast)
   except:
     print(lin.ast)
     traceback.print_exc()
@@ -29,7 +29,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
     return "COMPILE_ERROR"

   try:
-    prg.exec(rawbufs, var_vals, force_wait=True)
+    prg.exec(rawbufs, var_vals)
   except:
     print(lin.ast)
     traceback.print_exc()
test/test_linearizer.py

@@ -349,13 +349,13 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
       k.apply_opt(opt)
     prg = to_prg(k)
     real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
-    prg.exec(real_bufs, force_wait=True)
+    prg.exec(real_bufs)
     np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)

   # Get baseline, which is not optimized at all.
   k = Linearizer(realized_ast)
   prg = Device[Device.DEFAULT].to_program(k)
-  prg.exec(real_bufs, force_wait=True)
+  prg.exec(real_bufs)
   wanna_output = real_bufs[0].toCPU().copy()

   # Check correctness of handcoded optimiztions.
@@ -363,7 +363,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
   k.hand_coded_optimizations()
   prg = Device[Device.DEFAULT].to_program(k)
   real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
-  prg.exec(real_bufs, force_wait=True)
+  prg.exec(real_bufs)
   np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
   for x in opts: # Check custom transformations if any.
     check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_program)
test/test_uops.py

@@ -26,7 +26,7 @@ def _test_single_value(vals, op, dtype):
   buf = Device[Device.DEFAULT].buffer(1, dtype)
   buf2 = [Device[Device.DEFAULT].buffer.fromCPU(np.array([a], dtype=dtype.np)) for a in vals]
   prg = _uops_to_prg(uops)
-  prg([buf]+buf2)
+  prg.exec([buf]+buf2)
   return buf.toCPU()[0]

 def _test_single_value_const(vals, op, dtype):
@@ -37,7 +37,7 @@ def _test_single_value_const(vals, op, dtype):
   uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
   buf = Device[Device.DEFAULT].buffer(1, dtype)
   prg = _uops_to_prg(uops)
-  prg([buf])
+  prg.exec([buf])
   return buf.toCPU()[0]

 class TestUOps(unittest.TestCase):
tinygrad/jit.py

@@ -1,32 +1,46 @@
 from __future__ import annotations
 from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
-import functools, itertools
-from tinygrad.helpers import DEBUG, DType, merge_dicts
-from tinygrad.ops import RawBuffer, Device, ASTRunner, BatchExecutor, JitItem
+import functools, itertools, operator
+from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv
+from tinygrad.ops import RawBuffer, Device, JITRunner
 from tinygrad.tensor import Tensor
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable
+from tinygrad.shape.symbolic import Variable, NumNode, Node
 from weakref import ref, WeakKeyDictionary
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class JitItem:
+  prg: JITRunner
+  rawbufs: List[Optional[RawBuffer]]
+
+def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[Node, Node]:
+  return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0))
+def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[RawBuffer]) -> Dict[Tuple[int, int], int]:
+  input_replace: Dict[Tuple[int, int], int] = {}
+  for j,ji in enumerate(jit_cache):
+    for i,a in enumerate(ji.rawbufs):
+      if a in input_rawbuffers:
+        input_replace[(j,i)] = input_rawbuffers.index(a)
+  assert len(set(input_replace.values())) == len(input_rawbuffers), "some input tensors not found"
+  return input_replace
+
+class GraphException(Exception): pass

 ReturnType = TypeVar('ReturnType')

 class TinyJit(Generic[ReturnType]):
   def __init__(self, fxn:Callable[..., ReturnType]):
     self.fxn = fxn
     self.reset()

   def reset(self):
-    self.jit_fxn: Optional[BatchExecutor] = None
+    self.jit_cache: List[JitItem] = []
+    self.input_replace: Dict[Tuple[int, int], int] = {}
     self.cnt: int = 0
     self.ret: Optional[ReturnType] = None
     self.expected_vals: Optional[Tuple[Variable, ...]] = None
     self.expected_name_sts_dtype: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...]] = None

-  @property
-  def jit_cache(self) -> List[JitItem]: return self.jit_fxn.jit_cache if self.jit_fxn else []
-  @property
-  def input_replace(self) -> Dict[Tuple[int, int], int]: return self.jit_fxn.input_replace if self.jit_fxn else {}
-
   # add support for instance methods
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
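get_input_replace records, for each cached kernel, which argument slots must be repointed at the caller's fresh inputs before every replay. A toy illustration with plain objects standing in for RawBuffers (not real device code):

    a, b, out = object(), object(), object()
    cache = [JitItem(prg=None, rawbufs=[out, a, b]),   # kernel 0 reads a and b
             JitItem(prg=None, rawbufs=[out, b, out])] # kernel 1 reads b
    print(get_input_replace(cache, [a, b]))
    # {(0, 1): 0, (0, 2): 1, (1, 1): 1} -- slot i of kernel j takes input_rawbuffers[idx]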
@@ -44,23 +58,35 @@ class TinyJit(Generic[ReturnType]):
     expected_vals = tuple(var_vals.keys())

     if self.cnt >= 2:
       # jit exec
       assert self.expected_vals == expected_vals, "mismatch of var_vals"
       assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}"
-      assert self.jit_fxn, "didn't get jitted?"
-      self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2)
+      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+      for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, jit=True)
     elif self.cnt == 1:
       # jit capture
       self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype

       CacheCollector.start(var_vals)
       self.ret = self.fxn(*args, **kwargs)
-      jit_cache = CacheCollector.finish()
-      assert len(jit_cache) != 0, "didn't JIT anything!"
-      if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_rawbuffers)} inputs")
+      self.jit_cache = CacheCollector.finish()
+      assert len(self.jit_cache) != 0, "didn't JIT anything!"
+      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")

-      self.jit_fxn = Device[Device.DEFAULT].batch_executor(jit_cache, input_rawbuffers, var_vals)
+      # if your Device supports it, condense the items into a graph executor
+      if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2:
+        try:
+          self.jit_cache = [JitItem(make_graph(self.jit_cache, input_rawbuffers, var_vals), cast(List[Optional[RawBuffer]], input_rawbuffers))]
+        except GraphException as e:
+          if DEBUG >= 1: print(f"graph create failed {e}")
+
+      self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
     elif self.cnt == 0:
       # jit ignore
       self.ret = self.fxn(*args, **kwargs)

+    # clear jit inputs
+    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+
     self.cnt += 1
     return cast(ReturnType, self.ret)
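TinyJit's lifecycle after this refactor: call 0 runs the function eagerly, call 1 captures the kernels into jit_cache (and, when the device offers one, condenses them through a graph executor), and calls 2+ only patch inputs and replay. A minimal usage sketch, assuming fixed input shapes across calls:

    from tinygrad.tensor import Tensor
    from tinygrad.jit import TinyJit

    @TinyJit
    def step(x:Tensor) -> Tensor:
      return (x + 1).realize()

    for _ in range(4):
      out = step(Tensor.randn(8).realize())  # run 0: eager, run 1: capture, runs 2+: replay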
@@ -77,7 +103,7 @@ class PlaceHolder:

 class _CacheCollector:
   def __init__(self):
-    self.cache: Optional[List[Tuple[ASTRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None
+    self.cache: Optional[List[Tuple[JITRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None

   def start(self, var_vals:Optional[Dict[Variable, int]]=None):
     self.cache = []
@@ -88,7 +114,7 @@ class _CacheCollector:
     if self.cache is None: return
     for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
     self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) # NOTE: this is making an assumption that 0 is special
-    self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(prg, ASTRunner) and isinstance(x, RawBuffer) else x for x in rawbufs]))
+    self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, RawBuffer) else x for x in rawbufs]))

   def finish(self) -> List[JitItem]:
     if self.cache is None: return []
tinygrad/ops.py | 115
@@ -1,10 +1,10 @@
 from __future__ import annotations
 import importlib, inspect, functools, pathlib, time, re
 from enum import Enum, auto
-from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
+from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, Set
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
 from tinygrad.runtime.lib import RawBuffer
-from tinygrad.shape.symbolic import Variable, sym_infer, NumNode, sint
+from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from dataclasses import dataclass

 # these are the llops your accelerator must implement, along with toCpu
@@ -15,7 +15,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto()
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
-class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
+class BufferOps(Enum): MEM = auto(); CONST = auto(); FROM_UNDERLYING = auto() # noqa: E702
 # Ops below this line are not allowed in ASTs
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
 class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
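BufferOps.FROM_UNDERLYING subsumes the old Interpreted(from_underlying=...) argument: wrapping a backend's native result (a numpy array, a torch tensor) back into a RawBuffer is now just another op the backend registers in fxn_for_op. From the CPU backend later in this diff:

    numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
      BufferOps.FROM_UNDERLYING: RawNumpyBuffer.fromCPU,  # np.ndarray -> RawNumpyBuffer
      # ... the rest of the op table is unchanged
    }}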
@@ -144,91 +144,52 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
   GlobalCounters.global_mem += mem_estimate
   if et is not None: GlobalCounters.time_sum_s += et

-# **************** batch executor ****************
+# **************** shared AST runner ****************

-@dataclass(frozen=True)
-class JitItem:
-  prg: ASTRunner
-  rawbufs: List[Optional[RawBuffer]]
-
-class BatchExecutor:
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
-    self.jit_cache: List[JitItem] = jit_cache
-    self.input_replace: Dict[Tuple[int, int], int] = {}
-    self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
-    for j,ji in enumerate(jit_cache):
-      if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
-        self.op_estimate += ji.prg.op_estimate
-        self.mem_estimate += ji.prg.mem_estimate
-      for i,a in enumerate(ji.rawbufs):
-        if a in input_rawbuffers:
-          self.input_replace[(j,i)] = input_rawbuffers.index(a)
-    assert len(set(self.input_replace.values())) == len(input_rawbuffers), "some input tensors not found"
-    self.clear_jit_inputs()
-
-  def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False):
-    for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
-    for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True)
-    self.clear_jit_inputs()
-
-  def clear_jit_inputs(self):
-    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
-
-class ASTRunner:
-  def __init__(self, ast:Optional[LazyOp]):
-    if ast is None:
-      self.op_estimate, self.mem_estimate, self.vars = 0, 0, set()
-    else:
-      info = get_lazyop_info(ast)
-      self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
-      from tinygrad.lazy import vars_from_ast
-      self.vars = vars_from_ast(ast)
-      assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
-
-  def exec(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
+class JITRunner:
+  def __init__(self):
+    self.op_estimate, self.mem_estimate = 0, 0
+  def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
+    var_vals = var_vals if var_vals is not None else {}
     from tinygrad.jit import CacheCollector
-    et = self(rawbufs, var_vals, force_wait=force_wait)
-    CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
+    et = self(rawbufs, var_vals)
+    CacheCollector.add(self, rawbufs, var_vals)
     return et

-  def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
+  def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     raise NotImplementedError("override this")

 # **************** for Interpreted Buffers ****************

-class InterpretedASTRunner(ASTRunner):
+class InterpretedASTRunner(JITRunner):
   def __init__(self, ast:LazyOp, fxn:Callable):
+    super().__init__()
     self.fxn = fxn
-    super().__init__(ast)
+    info = get_lazyop_info(ast)
+    self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate

-  def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> float:
-    var_vals = {k:var_vals[k] for k in sorted(self.vars)} if var_vals is not None else {}
+  def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
     st = time.perf_counter()
     ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
    et = time.perf_counter() - st
     update_stats(f"<interpreted {ret.size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
-    if rawbufs[0] is not None:
-      assert rawbufs[0].dtype == ret.dtype
-      rawbufs[0].size = ret.size # NOTE: for symbolic this can change
-      rawbufs[0]._buf = ret._buf
-    else: rawbufs[0] = ret
+    assert getattr(rawbufs[0], 'dtype', ret.dtype) == ret.dtype
+    rawbufs[0].dtype, rawbufs[0].size, rawbufs[0]._buf, rawbufs[0].offset = ret.dtype, ret.size, ret._buf, ret.offset
     return et

 class Interpreted:
-  def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable]=None):
-    self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
-    self.synchronize = lambda: None
-    self.batch_executor = BatchExecutor
-    self.codegen = None
+  def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable]):
+    self.buffer, self.fxn_for_op = buffer, fxn_for_op
+    self.synchronize, self.codegen, self.graph = lambda: None, None, None
     self.method_cache: Dict[LazyOp, InterpretedASTRunner] = {}

   def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
-    if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, self.from_underlying, ast)
+    if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, ast)
     rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
     if rawbufs[0] is None: rawbufs[0] = self.buffer.__new__(self.buffer)
     self.method_cache[ast].exec(rawbufs, var_vals)
     output.realized = rawbufs[0]

-def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> InterpretedASTRunner:
+def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
   if DEBUG >= 3:
     from tinygrad.graph import print_tree
     print_tree(ast)
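Note how the two entry points replace force_wait: exec() is the realize-path wrapper that also registers the call with CacheCollector, while __call__() only launches, with wait=True as the explicit way to block and get a time back. A sketch of the two call sites, per the JITRunner code above:

    prg.exec(rawbufs, var_vals)        # normal path; lands in the JIT cache while capturing
    et = prg(rawbufs, {}, wait=True)   # direct launch; blocks and returns elapsed seconds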
@@ -262,19 +223,26 @@ def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[
     return ret

   ret = _interpret_ast(ast)
-  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f" return {ret}"])
+  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(fxn_for_op[BufferOps.FROM_UNDERLYING], BufferOps.FROM_UNDERLYING)}({ret})" if BufferOps.FROM_UNDERLYING in fxn_for_op else f" return {ret}"])
   if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
   exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
   return InterpretedASTRunner(ast, tglob['run'])

 # **************** for Compiled Buffers ****************

-class CompiledASTRunner(ASTRunner):
+class CompiledASTRunner(JITRunner):
   def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
+    super().__init__()
     if DEBUG >= 4: print(prg)
     self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = \
       name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
-    super().__init__(ast)
+    self.vars: Set[Variable] = set()
+    if ast:
+      info = get_lazyop_info(ast)
+      self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
+      from tinygrad.lazy import vars_from_ast
+      self.vars = vars_from_ast(ast)
+      assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"

   def build(self, compiler, runtime):
     self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
@@ -286,26 +254,25 @@ class CompiledASTRunner(ASTRunner):
     local_size = ([sym_infer(sz, var_vals) for sz in self.local_size] + [1]*(3-len(self.local_size))) if self.local_size is not None else self.local_size
     return global_size, local_size

-  def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
+  def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     # filter the var_vals
-    var_vals = {k:var_vals[k] for k in sorted(self.vars)} if var_vals is not None else {}
+    var_vals = {k:var_vals[k] for k in sorted(self.vars)}
     global_size, local_size = self.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
       # TODO: this is copied from get_program
       from tinygrad.features.search import optimize_local_size
-      local_size = self.local_size = optimize_local_size(self.clprg, global_size, cast(List[RawBuffer], rawbufs))
+      local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
       global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
     lra = self.runtime_args.copy()
     if global_size: lra['global_size'] = global_size
     if local_size and 'local_size' not in lra: lra['local_size'] = local_size
-    et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2)
+    et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=wait or DEBUG>=2)
     update_stats(self.display_name if self.display_name is not None else self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
     return et

 class Compiled:
-  def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
-    self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
-    self.batch_executor = BatchExecutor if getenv("JIT") == 2 else batch_executor
+  def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, graph=None):
+    self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.graph = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, graph
     self.method_cache: Dict[LazyOp, CompiledASTRunner] = {}

   def to_program(self, k:Linearizer) -> CompiledASTRunner:
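Compiled now takes a graph factory instead of a batch_executor class. Whatever is passed must construct from (jit_cache, input_rawbuffers, var_vals), raising GraphException when batching is impossible, and then be callable like a JITRunner. A sketch of the implied protocol, inferred from MetalGraph below (no such abstract base exists in the code):

    class MyGraph:
      def __init__(self, jit_cache:List[JitItem], input_rawbuffers:List[RawBuffer], var_vals:Dict[Variable, int]):
        ...  # raise GraphException if these kernels can't be batched on this device
      def __call__(self, input_rawbuffers:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
        ...  # replay every kernel; return elapsed seconds when measured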
tinygrad/runtime/lib.py

@@ -10,6 +10,7 @@ class RawBuffer: # pylint: disable=abstract-method
   def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
     self.size: int = size
     self.dtype: DType = dtype
+    self.offset: int = 0 # TODO: this is very unsupported, only in disk
     self._buf = buf if buf is not None else (allocator(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
     self._memsz: int = size*dtype.itemsize
     self._allocator = allocator if allocator and hasattr(allocator, 'free') else None
tinygrad/runtime/ops_cpu.py

@@ -5,6 +5,12 @@ from tinygrad.helpers import dtypes, DType
 from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
 from tinygrad.runtime.lib import RawBuffer

+class RawNumpyBuffer(RawBuffer):
+  def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None, allocator=lambda size, dtype: np.empty([size], dtype.np)): super().__init__(size, dtype, buf, allocator)
+  @classmethod
+  def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
+  def toCPU(self): return self._buf
+
 def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
   assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
   return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
@@ -32,7 +38,7 @@ def einsum_mulacc(einsum, get_strides, expand):
   return mulacc

 numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
+  BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), BufferOps.FROM_UNDERLYING: RawNumpyBuffer.fromCPU,
   UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
   UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
   BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
@@ -44,9 +50,4 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   TernaryOps.WHERE: np.where,
 }}

-class RawNumpyBuffer(RawBuffer):
-  def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None, allocator=lambda size, dtype: np.empty([size], dtype.np)): super().__init__(size, dtype, buf, allocator)
-  @classmethod
-  def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
-  def toCPU(self): return self._buf
-CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, RawNumpyBuffer.fromCPU)
+CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op)
tinygrad/runtime/ops_disk.py

@@ -9,10 +9,13 @@ from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps
 from tinygrad.shape.view import strides_for_shape
 MAP_LOCKED, MAP_POPULATE = 0x2000, 0x008000

+class UnderlyingDiskBuffer:
+  def __init__(self, fd, mem): self.fd, self.mem = fd, mem
+  def __del__(self):
+    if self.fd: self.fd.close()
+
 class RawDiskBuffer(RawBufferMapped):
-  def __init__(self, size, dtype:DType, device:Optional[str]=None, buf=None, shape=None, offset=0): # pylint: disable=super-init-not-called
-    self.shape = (size, ) if shape is None else shape
-    self.offset = offset # this is an offset in bytes
+  def __init__(self, size, dtype:DType, buf=None, device:Optional[str]=None, offset:int=0): # pylint: disable=super-init-not-called
     assert device is not None or buf is not None, "disk tensor needs a path or a buf"
     if device is not None:
       if str(device).startswith("shm:"):
@@ -26,28 +29,23 @@ class RawDiskBuffer(RawBufferMapped):
         shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | MAP_LOCKED | MAP_POPULATE)
         shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX
         os.close(fd)
-        buf = [None, shm, 1]
+        buf = UnderlyingDiskBuffer(None, shm)
       else:
         f = open(device, "a+b")
         if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
-        buf = [f, mmap.mmap(f.fileno(), size * dtype.itemsize), 1]
-    else:
-      buf[2] += 1
+        buf = UnderlyingDiskBuffer(f, mmap.mmap(f.fileno(), size * dtype.itemsize))
     # NOTE: we don't call super since disk tensors don't use RAM
-    self.size, self.dtype, self._buf = size, dtype, buf
-  def __del__(self):
-    self._buf[2] -= 1
-    if self._buf[2] == 0 and self._buf[0] is not None: self._buf[0].close()
-  def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
+    self.size, self.dtype, self._buf, self.offset = size, dtype, buf, offset
+  def cast(self, arg:Tuple[DType, bool]):
+    return RawDiskBuffer(self.size, arg[0], self._buf, offset=self.offset)
   def as_strided(self, arg):
     assert strides_for_shape(arg[0]) == arg[1], "disk tensors don't support strides"
-    return RawDiskBuffer(prod(arg[0]), self.dtype, buf=self._buf, offset=self.offset+arg[2]*self.dtype.itemsize, shape=arg[0])
-
-  def _buffer(self): return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]
+    return RawDiskBuffer(prod(arg[0]), self.dtype, self._buf, offset=self.offset+arg[2]*self.dtype.itemsize)
+  def _buffer(self): return memoryview(self._buf.mem)[self.offset:self.offset+self.size*self.dtype.itemsize]
   def readinto(self, buf:memoryview):
-    if self._buf[0] is not None:
-      self._buf[0].seek(self.offset)
-      self._buf[0].readinto(buf)
+    if self._buf.fd is not None:
+      self._buf.fd.seek(self.offset)
+      self._buf.fd.readinto(buf)
     else:
       buf.cast('B')[:] = self._buffer()
   def transfer(self, cls, shape, dtype, **kwargs):
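With the refcounted [fd, mem, count] list replaced by UnderlyingDiskBuffer (which closes its fd in __del__, fixing the del bug from the commit message) and offset promoted onto the buffer, every disk view reduces to byte arithmetic: new_offset = offset + start_element * itemsize. A worked example with hypothetical values:

    # shrinking a float32 disk tensor to elements [2:7):
    # as_strided(((5,), (1,), 2)) -> offset 0 + 2*4 = 8 bytes, size 5
    # _buffer() then returns memoryview(self._buf.mem)[8:8+5*4]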
tinygrad/runtime/ops_metal.py

@@ -1,13 +1,14 @@
 # pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
 import os, subprocess, pathlib, ctypes, tempfile
 import Metal, libdispatch
-from typing import List, Any, Tuple, Dict, Set, cast
+from typing import List, Any, Tuple, Dict, Set, cast, Optional
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
-from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner, update_stats
+from tinygrad.ops import Compiled, CompiledASTRunner, update_stats
 from tinygrad.renderer.metal import MetalRenderer
 from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
 from tinygrad.shape.symbolic import Variable, Node
+from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, GraphException

 class MetalAllocator(LRUAllocator):
   def _do_alloc(self, size, dtype, device, **kwargs):
@@ -23,7 +24,6 @@ class _METAL:
   def __init__(self):
     self.mtl_buffers_in_flight: List[Any] = []
     self.device = Metal.MTLCreateSystemDefaultDevice()
-    self.supports_icb = (self.device.supportsFamily_(Metal.MTLGPUFamilyMac2) or self.device.supportsFamily_(Metal.MTLGPUFamilyApple3) or self.device.supportsFamily_(Metal.MTLGPUFamilyCommon2)) and self.device.argumentBuffersSupport() is Metal.MTLArgumentBuffersTier2
     self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
     self.allocator = MetalAllocator(self.device.dedicatedMemorySize() or self.device.sharedMemorySize())
     # TODO: is there a better way to do this?
@@ -84,9 +84,11 @@ class MetalProgram:
       return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
     METAL.mtl_buffers_in_flight.append(command_buffer)

-class MetalBatchExecutor(BatchExecutor):
+class MetalGraph:
   def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
-    super().__init__(jit_cache, input_rawbuffers, var_vals)
+    self.jit_cache = jit_cache
+    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
+    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)

     # create metal batch exec
     icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
@@ -95,7 +97,7 @@ class MetalBatchExecutor(BatchExecutor):
     icb_descriptor.setInheritPipelineState_(False)
     icb_descriptor.setMaxKernelBufferBindCount_(31)
     self.icb = METAL.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0))
-    assert self.icb is not None, "create indirect command buffer failed, does your system support this?"
+    if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")

     self.int_buf = RawMetalBuffer(len(var_vals), dtypes.int32)
     self.input_has_variable_dims: Set[int] = set()
@@ -127,7 +129,7 @@ class MetalBatchExecutor(BatchExecutor):
     self.command_buffer: Any = None
     self.int_buf_view = self.int_buf.buffer_view() # TODO: this is metal syncing when it doesn't need to

-  def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False):
+  def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     # NOTE: you at least can't update the ints if this is running
     if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
     all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers]
@@ -151,7 +153,7 @@ class MetalBatchExecutor(BatchExecutor):
     else:
       METAL.mtl_buffers_in_flight.append(command_buffer)
       et = None
-    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=True, num_kernels=len(self.jit_cache))
+    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
     return et

-MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if METAL.supports_icb else BatchExecutor)
+MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, graph=MetalGraph)
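Where the old executor asserted, MetalGraph now raises GraphException when the device can't build an indirect command buffer, and TinyJit degrades gracefully, roughly (paraphrasing the jit.py hunk above):

    try:
      jit_cache = [JitItem(MetalGraph(jit_cache, input_rawbuffers, var_vals), input_rawbuffers)]
    except GraphException:
      pass  # keep the per-kernel cache and replay kernels one at a time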
tinygrad/runtime/ops_torch.py

@@ -10,6 +10,14 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if geten
 type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
 inverse_type_map = {v:k for k,v in type_map.items()}

+class RawTorchBuffer(RawBuffer):
+  def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None, allocator=lambda size, dtype: torch.empty([size], device=device, dtype=inverse_type_map[dtype])): super().__init__(size, dtype, buf, allocator)
+  @classmethod
+  def fromCPU(cls, x):
+    buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
+    return cls(prod(x.shape), type_map[buf.dtype], buf)
+  def toCPU(self): return self._buf.cpu().numpy()
+
 def output_type(x, y): return x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype
 def match_types(x, y, disallow_bool=False):
   up = output_type(x, y)
@@ -26,6 +34,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   # TODO: torch.tensor should work here
   #BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
   BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).requires_grad_(False).to(device),
+  BufferOps.FROM_UNDERLYING: lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x),
   UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
   UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
@@ -39,11 +48,4 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
 }}

-class RawTorchBuffer(RawBuffer):
-  def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None, allocator=lambda size, dtype: torch.empty([size], device=device, dtype=inverse_type_map[dtype])): super().__init__(size, dtype, buf, allocator)
-  @classmethod
-  def fromCPU(cls, x):
-    buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
-    return cls(prod(x.shape), type_map[buf.dtype], buf)
-  def toCPU(self): return self._buf.cpu().numpy()
-TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))
+TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op)