diff --git a/extra/dist/world.py b/extra/dist/world.py index 70c1728499..c4fb9e08f7 100644 --- a/extra/dist/world.py +++ b/extra/dist/world.py @@ -12,8 +12,8 @@ from tinygrad.tensor import Tensor, Function import extra.hip_wrapper as hip import numpy as np -# fake the function signature of ASTRunner so we can put it in the cache -def __send_rb(args, variables=None, jit=False, force_wait=False): +# match the function signature of JITRunner so we can put it in the cache +def __send_rb(args, variables=None, wait=False, jit=False): x, target_rank, y = args[:3] if RawHIPBuffer and x.__class__ is RawHIPBuffer: hip.hipSetDevice(x._device) @@ -24,7 +24,7 @@ def __send_rb(args, variables=None, jit=False, force_wait=False): dist.OOB.send(None, target_rank) if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}") -def __recv_rb(args, variables=None, jit=False, force_wait=False): +def __recv_rb(args, variables=None, wait=False, jit=False): x, target_rank, y = args[:3] dist.OOB.recv(target_rank) if RawHIPBuffer and x.__class__ is RawHIPBuffer: diff --git a/test/external/external_copy_benchmark.py b/test/external/external_copy_benchmark.py deleted file mode 100644 index 351be42a66..0000000000 --- a/test/external/external_copy_benchmark.py +++ /dev/null @@ -1,26 +0,0 @@ -import unittest -from tinygrad.helpers import prod -from tinygrad.tensor import Tensor -from tinygrad.helpers import GlobalCounters -from tinygrad.jit import CacheCollector - -class TestCopy(unittest.TestCase): - def test_add1(self): - pts = [] - for i in range(16384, 16384*256, 16384): - t = Tensor.randn(i).realize() - CacheCollector.start() - t.assign(t+1).realize() - ji = CacheCollector.finish()[0] - GlobalCounters.reset() - def run(): return ji.prg(ji.rawbufs, force_wait=True) - ct = min([run() for _ in range(10)]) - mb = prod(t.shape)*t.dtype.itemsize*2*1e-6 - print(f"{mb*1e3:.2f} kB, {ct*1e3:.2f} ms, {mb/ct:.2f} MB/s") - pts.append((mb, mb/ct)) - from matplotlib import pyplot as plt - plt.plot([x[0] for x in pts], [x[1] for x in pts]) - plt.show() - -if __name__ == '__main__': - unittest.main() \ No newline at end of file diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 66f91d6827..f03fa24edb 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -21,7 +21,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None): if isinstance(device, Compiled): prg = device.to_program(lin) else: - prg = get_interpreted_fxn(device.fxn_for_op, device.from_underlying, lin.ast) + prg = get_interpreted_fxn(device.fxn_for_op, lin.ast) except: print(lin.ast) traceback.print_exc() @@ -29,7 +29,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None): return "COMPILE_ERROR" try: - prg.exec(rawbufs, var_vals, force_wait=True) + prg.exec(rawbufs, var_vals) except: print(lin.ast) traceback.print_exc() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 890c80e271..86060734bd 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -349,13 +349,13 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False): k.apply_opt(opt) prg = to_prg(k) real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled - prg.exec(real_bufs, force_wait=True) + prg.exec(real_bufs) np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4) # Get baseline, which is not optimized at all. k = Linearizer(realized_ast) prg = Device[Device.DEFAULT].to_program(k) - prg.exec(real_bufs, force_wait=True) + prg.exec(real_bufs) wanna_output = real_bufs[0].toCPU().copy() # Check correctness of handcoded optimiztions. @@ -363,7 +363,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False): k.hand_coded_optimizations() prg = Device[Device.DEFAULT].to_program(k) real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled - prg.exec(real_bufs, force_wait=True) + prg.exec(real_bufs) np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4) for x in opts: # Check custom transformations if any. check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_program) diff --git a/test/test_uops.py b/test/test_uops.py index bbe2975656..8749e89123 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -26,7 +26,7 @@ def _test_single_value(vals, op, dtype): buf = Device[Device.DEFAULT].buffer(1, dtype) buf2 = [Device[Device.DEFAULT].buffer.fromCPU(np.array([a], dtype=dtype.np)) for a in vals] prg = _uops_to_prg(uops) - prg([buf]+buf2) + prg.exec([buf]+buf2) return buf.toCPU()[0] def _test_single_value_const(vals, op, dtype): @@ -37,7 +37,7 @@ def _test_single_value_const(vals, op, dtype): uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) buf = Device[Device.DEFAULT].buffer(1, dtype) prg = _uops_to_prg(uops) - prg([buf]) + prg.exec([buf]) return buf.toCPU()[0] class TestUOps(unittest.TestCase): diff --git a/tinygrad/jit.py b/tinygrad/jit.py index 93ea6400ad..d6a60092c4 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -1,32 +1,46 @@ from __future__ import annotations from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic -import functools, itertools -from tinygrad.helpers import DEBUG, DType, merge_dicts -from tinygrad.ops import RawBuffer, Device, ASTRunner, BatchExecutor, JitItem +import functools, itertools, operator +from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv +from tinygrad.ops import RawBuffer, Device, JITRunner from tinygrad.tensor import Tensor from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.symbolic import Variable +from tinygrad.shape.symbolic import Variable, NumNode, Node from weakref import ref, WeakKeyDictionary +from dataclasses import dataclass + +@dataclass(frozen=True) +class JitItem: + prg: JITRunner + rawbufs: List[Optional[RawBuffer]] + +def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[Node, Node]: + return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0)) +def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[RawBuffer]) -> Dict[Tuple[int, int], int]: + input_replace: Dict[Tuple[int, int], int] = {} + for j,ji in enumerate(jit_cache): + for i,a in enumerate(ji.rawbufs): + if a in input_rawbuffers: + input_replace[(j,i)] = input_rawbuffers.index(a) + assert len(set(input_replace.values())) == len(input_rawbuffers), "some input tensors not found" + return input_replace + +class GraphException(Exception): pass ReturnType = TypeVar('ReturnType') - class TinyJit(Generic[ReturnType]): def __init__(self, fxn:Callable[..., ReturnType]): self.fxn = fxn self.reset() def reset(self): - self.jit_fxn: Optional[BatchExecutor] = None + self.jit_cache: List[JitItem] = [] + self.input_replace: Dict[Tuple[int, int], int] = {} self.cnt: int = 0 self.ret: Optional[ReturnType] = None self.expected_vals: Optional[Tuple[Variable, ...]] = None self.expected_name_sts_dtype: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...]] = None - @property - def jit_cache(self) -> List[JitItem]: return self.jit_fxn.jit_cache if self.jit_fxn else [] - @property - def input_replace(self) -> Dict[Tuple[int, int], int]: return self.jit_fxn.input_replace if self.jit_fxn else {} - # add support for instance methods def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) @@ -44,23 +58,35 @@ class TinyJit(Generic[ReturnType]): expected_vals = tuple(var_vals.keys()) if self.cnt >= 2: + # jit exec assert self.expected_vals == expected_vals, "mismatch of var_vals" assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}" - assert self.jit_fxn, "didn't get jitted?" - self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2) + for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx] + for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, jit=True) elif self.cnt == 1: + # jit capture self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype - CacheCollector.start(var_vals) self.ret = self.fxn(*args, **kwargs) - jit_cache = CacheCollector.finish() - assert len(jit_cache) != 0, "didn't JIT anything!" - if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_rawbuffers)} inputs") + self.jit_cache = CacheCollector.finish() + assert len(self.jit_cache) != 0, "didn't JIT anything!" + if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs") - self.jit_fxn = Device[Device.DEFAULT].batch_executor(jit_cache, input_rawbuffers, var_vals) + # if your Device supports it, condense the items into a graph executor + if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2: + try: + self.jit_cache = [JitItem(make_graph(self.jit_cache, input_rawbuffers, var_vals), cast(List[Optional[RawBuffer]], input_rawbuffers))] + except GraphException as e: + if DEBUG >= 1: print(f"graph create failed {e}") + + self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers) elif self.cnt == 0: + # jit ignore self.ret = self.fxn(*args, **kwargs) + # clear jit inputs + for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None + self.cnt += 1 return cast(ReturnType, self.ret) @@ -77,7 +103,7 @@ class PlaceHolder: class _CacheCollector: def __init__(self): - self.cache: Optional[List[Tuple[ASTRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None + self.cache: Optional[List[Tuple[JITRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None def start(self, var_vals:Optional[Dict[Variable, int]]=None): self.cache = [] @@ -88,7 +114,7 @@ class _CacheCollector: if self.cache is None: return for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}" self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) # NOTE: this is making an assumption that 0 is special - self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(prg, ASTRunner) and isinstance(x, RawBuffer) else x for x in rawbufs])) + self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, RawBuffer) else x for x in rawbufs])) def finish(self) -> List[JitItem]: if self.cache is None: return [] diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 66dafe7da9..9c9c846e66 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,10 +1,10 @@ from __future__ import annotations import importlib, inspect, functools, pathlib, time, re from enum import Enum, auto -from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast +from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, Set from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int from tinygrad.runtime.lib import RawBuffer -from tinygrad.shape.symbolic import Variable, sym_infer, NumNode, sint +from tinygrad.shape.symbolic import Variable, sym_infer, sint from dataclasses import dataclass # these are the llops your accelerator must implement, along with toCpu @@ -15,7 +15,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto() class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702 -class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702 +class BufferOps(Enum): MEM = auto(); CONST = auto(); FROM_UNDERLYING = auto() # noqa: E702 # Ops below this line are not allowed in ASTs class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702 class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702 @@ -144,91 +144,52 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option GlobalCounters.global_mem += mem_estimate if et is not None: GlobalCounters.time_sum_s += et -# **************** batch executor **************** +# **************** shared AST runner **************** -@dataclass(frozen=True) -class JitItem: - prg: ASTRunner - rawbufs: List[Optional[RawBuffer]] - -class BatchExecutor: - def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]): - self.jit_cache: List[JitItem] = jit_cache - self.input_replace: Dict[Tuple[int, int], int] = {} - self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0) - for j,ji in enumerate(jit_cache): - if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored - self.op_estimate += ji.prg.op_estimate - self.mem_estimate += ji.prg.mem_estimate - for i,a in enumerate(ji.rawbufs): - if a in input_rawbuffers: - self.input_replace[(j,i)] = input_rawbuffers.index(a) - assert len(set(self.input_replace.values())) == len(input_rawbuffers), "some input tensors not found" - self.clear_jit_inputs() - - def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False): - for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx] - for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True) - self.clear_jit_inputs() - - def clear_jit_inputs(self): - for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None - -class ASTRunner: - def __init__(self, ast:Optional[LazyOp]): - if ast is None: - self.op_estimate, self.mem_estimate, self.vars = 0, 0, set() - else: - info = get_lazyop_info(ast) - self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate - from tinygrad.lazy import vars_from_ast - self.vars = vars_from_ast(ast) - assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}" - - def exec(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]: +class JITRunner: + def __init__(self): + self.op_estimate, self.mem_estimate = 0, 0 + def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]: + var_vals = var_vals if var_vals is not None else {} from tinygrad.jit import CacheCollector - et = self(rawbufs, var_vals, force_wait=force_wait) - CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {}) + et = self(rawbufs, var_vals) + CacheCollector.add(self, rawbufs, var_vals) return et - - def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]: + def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: raise NotImplementedError("override this") # **************** for Interpreted Buffers **************** -class InterpretedASTRunner(ASTRunner): +class InterpretedASTRunner(JITRunner): def __init__(self, ast:LazyOp, fxn:Callable): + super().__init__() self.fxn = fxn - super().__init__(ast) + info = get_lazyop_info(ast) + self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate - def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> float: - var_vals = {k:var_vals[k] for k in sorted(self.vars)} if var_vals is not None else {} + def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float: st = time.perf_counter() ret: RawBuffer = self.fxn(rawbufs[1:], var_vals) et = time.perf_counter() - st update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit) - if rawbufs[0] is not None: - assert rawbufs[0].dtype == ret.dtype - rawbufs[0].size = ret.size # NOTE: for symbolic this can change - rawbufs[0]._buf = ret._buf - else: rawbufs[0] = ret + assert getattr(rawbufs[0], 'dtype', ret.dtype) == ret.dtype + rawbufs[0].dtype, rawbufs[0].size, rawbufs[0]._buf, rawbufs[0].offset = ret.dtype, ret.size, ret._buf, ret.offset return et class Interpreted: - def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable]=None): - self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying - self.synchronize = lambda: None - self.batch_executor = BatchExecutor - self.codegen = None + def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable]): + self.buffer, self.fxn_for_op = buffer, fxn_for_op + self.synchronize, self.codegen, self.graph = lambda: None, None, None self.method_cache: Dict[LazyOp, InterpretedASTRunner] = {} def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs): - if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, self.from_underlying, ast) + if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, ast) rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs] + if rawbufs[0] is None: rawbufs[0] = self.buffer.__new__(self.buffer) self.method_cache[ast].exec(rawbufs, var_vals) output.realized = rawbufs[0] -def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> InterpretedASTRunner: +def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner: if DEBUG >= 3: from tinygrad.graph import print_tree print_tree(ast) @@ -262,19 +223,26 @@ def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[ return ret ret = _interpret_ast(ast) - src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f" return {ret}"]) + src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(fxn_for_op[BufferOps.FROM_UNDERLYING], BufferOps.FROM_UNDERLYING)}({ret})" if BufferOps.FROM_UNDERLYING in fxn_for_op else f" return {ret}"]) if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src)) exec(compile(src, "", "exec"), tglob) # pylint: disable=exec-used return InterpretedASTRunner(ast, tglob['run']) # **************** for Compiled Buffers **************** -class CompiledASTRunner(ASTRunner): +class CompiledASTRunner(JITRunner): def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None): + super().__init__() if DEBUG >= 4: print(prg) self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = \ name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {} - super().__init__(ast) + self.vars: Set[Variable] = set() + if ast: + info = get_lazyop_info(ast) + self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate + from tinygrad.lazy import vars_from_ast + self.vars = vars_from_ast(ast) + assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}" def build(self, compiler, runtime): self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg) @@ -286,26 +254,25 @@ class CompiledASTRunner(ASTRunner): local_size = ([sym_infer(sz, var_vals) for sz in self.local_size] + [1]*(3-len(self.local_size))) if self.local_size is not None else self.local_size return global_size, local_size - def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]: + def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: # filter the var_vals - var_vals = {k:var_vals[k] for k in sorted(self.vars)} if var_vals is not None else {} + var_vals = {k:var_vals[k] for k in sorted(self.vars)} global_size, local_size = self.launch_dims(var_vals) if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type] # TODO: this is copied from get_program from tinygrad.features.search import optimize_local_size - local_size = self.local_size = optimize_local_size(self.clprg, global_size, cast(List[RawBuffer], rawbufs)) + local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs) global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)] lra = self.runtime_args.copy() if global_size: lra['global_size'] = global_size if local_size and 'local_size' not in lra: lra['local_size'] = local_size - et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2) + et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=wait or DEBUG>=2) update_stats(self.display_name if self.display_name is not None else self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra) return et class Compiled: - def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor): - self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize - self.batch_executor = BatchExecutor if getenv("JIT") == 2 else batch_executor + def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, graph=None): + self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.graph = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, graph self.method_cache: Dict[LazyOp, CompiledASTRunner] = {} def to_program(self, k:Linearizer) -> CompiledASTRunner: diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index 4a6ebcdc3d..699f4530bd 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -10,6 +10,7 @@ class RawBuffer: # pylint: disable=abstract-method def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs): self.size: int = size self.dtype: DType = dtype + self.offset: int = 0 # TODO: this is very unsupported, only in disk self._buf = buf if buf is not None else (allocator(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator. self._memsz: int = size*dtype.itemsize self._allocator = allocator if allocator and hasattr(allocator, 'free') else None diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index 213bf18101..57c9b9263a 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -5,6 +5,12 @@ from tinygrad.helpers import dtypes, DType from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted from tinygrad.runtime.lib import RawBuffer +class RawNumpyBuffer(RawBuffer): + def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None, allocator=lambda size, dtype: np.empty([size], dtype.np)): super().__init__(size, dtype, buf, allocator) + @classmethod + def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x) + def toCPU(self): return self._buf + def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]: assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions" return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b) @@ -32,7 +38,7 @@ def einsum_mulacc(einsum, get_strides, expand): return mulacc numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{ - BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), + BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), BufferOps.FROM_UNDERLYING: RawNumpyBuffer.fromCPU, UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x), BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x Optional[float]: # NOTE: you at least can't update the ints if this is running if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted() all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers] @@ -151,7 +153,7 @@ class MetalBatchExecutor(BatchExecutor): else: METAL.mtl_buffers_in_flight.append(command_buffer) et = None - update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=True, num_kernels=len(self.jit_cache)) + update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) return et -MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if METAL.supports_icb else BatchExecutor) +MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, graph=MetalGraph) diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py index 910fb83d58..81dfb14a26 100644 --- a/tinygrad/runtime/ops_torch.py +++ b/tinygrad/runtime/ops_torch.py @@ -10,6 +10,14 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if geten type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16} inverse_type_map = {v:k for k,v in type_map.items()} +class RawTorchBuffer(RawBuffer): + def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None, allocator=lambda size, dtype: torch.empty([size], device=device, dtype=inverse_type_map[dtype])): super().__init__(size, dtype, buf, allocator) + @classmethod + def fromCPU(cls, x): + buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device) + return cls(prod(x.shape), type_map[buf.dtype], buf) + def toCPU(self): return self._buf.cpu().numpy() + def output_type(x, y): return x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype def match_types(x, y, disallow_bool=False): up = output_type(x, y) @@ -26,6 +34,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{ # TODO: torch.tensor should work here #BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]), BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).requires_grad_(False).to(device), + BufferOps.FROM_UNDERLYING: lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x), UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin, UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x), BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device) - return cls(prod(x.shape), type_map[buf.dtype], buf) - def toCPU(self): return self._buf.cpu().numpy() -TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x)) +TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op)