diff --git a/test/helpers.py b/test/helpers.py index a0a6318315..5d91036af2 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -66,3 +66,6 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None): def not_support_multi_device(): # REMOTE doesn't support multi device anywhere, GPU, CUDA and METAL don't support multi device if in CI return Device.DEFAULT == "REMOTE" or (CI and Device.DEFAULT in ("GPU", "CUDA", "METAL")) + +# NOTE: This will open REMOTE if it's the default device +REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev']) diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 39a2973578..41b32de81f 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -5,9 +5,9 @@ from tinygrad.device import LRUAllocator, is_dtype_supported from tinygrad.dtype import ImageDType from tinygrad.engine.realize import lower_schedule from tinygrad.helpers import prod, unwrap +from test.helpers import REAL_DEV IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU") -REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev']) @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageCopy(unittest.TestCase): diff --git a/test/test_jit.py b/test/test_jit.py index 9282b0a0a2..1b30ffc3b7 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3,7 +3,7 @@ import unittest, functools import numpy as np from hypothesis import given, settings, strategies as strat -from test.helpers import assert_jit_cache_len, not_support_multi_device +from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV from tinygrad.tensor import Tensor from tinygrad.engine.jit import TinyJit from tinygrad.device import Device @@ -22,7 +22,7 @@ def _simple_test(add, extract=lambda x: x, N=10): class TestJit(unittest.TestCase): @settings(deadline=2e4) - @unittest.skipUnless(Device.DEFAULT in ["LLVM", "CPU"], f"no support on {Device.DEFAULT}") + @unittest.skipUnless(REAL_DEV in ["LLVM", "CPU"], f"no support on {REAL_DEV}") @given(strat.sampled_from([Tensor.exp2, Tensor.log2, Tensor.sin])) def test_approx_jit_timeout(self, op): with Context(TRANSCENDENTAL=2): diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index e8c3675c13..a618fb3462 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -14,6 +14,8 @@ from weakref import WeakKeyDictionary class GraphException(Exception): pass +def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph + def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]: # Split JIT cache into batches for faster graph execution. # This allows the accelerator to run some batches while subsequent graphs are still being updated. @@ -51,8 +53,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer] case ViewOp(): continue # ViewOps are just ignored case _: can_be_graphed = False # Everything else is not graphed and flushes existing graph if it's being constructed - graph_class = can_be_graphed and (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) - is_multigraph = can_be_graphed and issubclass(cast(type, graph_class), MultiGraphRunner) + is_multigraph = can_be_graphed and issubclass(graph_class(ji_graph_dev), MultiGraphRunner) can_share_graph = can_be_graphed and (type(ji_graph_dev) is type(current_device) if is_multigraph else ji_graph_dev == current_device) can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size) if not can_extend_graph_batch and len(current_batch) > 0: flush_batch() diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 8423282a04..e8e10b19f2 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -38,12 +38,12 @@ class Runner: raise NotImplementedError("override this") class CompiledRunner(Runner): - def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None): + def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None, prg=None): if DEBUG >= 4: print(p.src) self.p:ProgramSpec = p self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src) if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib) - self._prg = Device[p.device].runtime(p.function_name, self.lib) + self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg super().__init__(p.name, p.device, p.estimates) def __reduce__(self): return self.__class__, (self.p, self.lib) diff --git a/tinygrad/runtime/graph/remote.py b/tinygrad/runtime/graph/remote.py new file mode 100644 index 0000000000..2ae2f7f325 --- /dev/null +++ b/tinygrad/runtime/graph/remote.py @@ -0,0 +1,28 @@ +from tinygrad.ops import Variable +from tinygrad.engine.jit import GraphRunner +from tinygrad.engine.realize import CompiledRunner, ExecItem +from tinygrad.device import Device, Buffer +from tinygrad.runtime.ops_remote import GraphComputeItem, GraphAlloc, GraphFree, GraphExec +from tinygrad.helpers import unwrap, flatten, dedup, all_same + +class RemoteGraph(GraphRunner): + def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]): + super().__init__(jit_cache, rawbufs, var_vals) + self.devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache])) + assert all_same(self.devices), self.devices + self.iids = sorted(self.input_replace.values()) + def _process_ji(ji: ExecItem): + assert isinstance(ji.prg, CompiledRunner), f'Only compiled runners are supported: {ji.prg}' + return GraphComputeItem(ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs), tuple(ji.prg.p.vars), + tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None, + tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None) + self.graph_num = self.devices[0].graph_num + self.devices[0].graph_num += 1 + self.devices[0].req.q(GraphAlloc(self.graph_num, tuple(_process_ji(ji) for ji in jit_cache), tuple(rawbufs[i]._buf for i in self.iids), var_vals)) + + def __del__(self): + self.devices[0].req.q(GraphFree(self.graph_num)) + + def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False): + self.devices[0].req.q(GraphExec(self.graph_num, tuple(rawbufs[i]._buf for i in self.iids), var_vals, wait)) + if wait: return float(self.devices[0].batch_submit()) diff --git a/tinygrad/runtime/ops_remote.py b/tinygrad/runtime/ops_remote.py index 6ec4d377b8..b4877cb93e 100644 --- a/tinygrad/runtime/ops_remote.py +++ b/tinygrad/runtime/ops_remote.py @@ -5,15 +5,19 @@ # it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC from __future__ import annotations -from typing import Optional, Any +from typing import Callable, Optional, Any, cast from collections import defaultdict from dataclasses import dataclass, field import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib from http.server import HTTPServer, BaseHTTPRequestHandler -from tinygrad.renderer import Renderer -from tinygrad.dtype import dtypes +from tinygrad.renderer import Renderer, ProgramSpec +from tinygrad.dtype import DTYPES_DICT, dtypes +from tinygrad.ops import UOp, Ops, Variable, sint from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing +from tinygrad.engine.jit import GraphRunner, ExecItem, graph_class +from tinygrad.engine.realize import CompiledRunner from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec +from tinygrad.runtime.graph.cpu import CPUGraph # ***** API ***** @@ -42,11 +46,44 @@ class ProgramExec(RemoteRequest): name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702 global_size: Optional[tuple[int, ...]]; local_size: Optional[tuple[int, ...]]; wait: bool # noqa: E702 +@dataclass(frozen=True) +class GraphComputeItem: + name: str + datahash: str + bufs: tuple[int, ...] + vars: tuple[Variable, ...] + global_size: tuple[sint, ...]|None + local_size: tuple[sint, ...]|None + +@dataclass(frozen=True) +class GraphAlloc(RemoteRequest): + graph_num: int + jit_cache: tuple[GraphComputeItem, ...] + bufs: tuple[int, ...] + var_vals: dict[Variable, int] + +@dataclass(frozen=True) +class GraphFree(RemoteRequest): + graph_num: int + +@dataclass(frozen=True) +class GraphExec(RemoteRequest): + graph_num: int + bufs: tuple[int, ...] + var_vals: dict[Variable, int] + wait: bool + # for safe deserialization -whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferSpec]} +eval_globals = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, + GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]} +attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}} eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)), + ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)}, ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}), - ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]} + ast.Name: lambda x: eval_globals[x.id], ast.Attribute: lambda x: safe_getattr(safe_eval(x.value), x.attr)} +def safe_getattr(value, attr): + assert attr in attribute_whitelist.get(value, set()), f'getattr({value}, {repr(attr)}) is not whitelisted' + return getattr(value, attr) def safe_eval(node): return eval_fxns[node.__class__](node) class BatchRequest: @@ -75,6 +112,7 @@ class BatchRequest: @dataclass class RemoteSession: programs: dict[tuple[str, str], Any] = field(default_factory=dict) + graphs: dict[int, GraphRunner] = field(default_factory=dict) buffers: dict[int, Buffer] = field(default_factory=dict) class RemoteHandler(BaseHTTPRequestHandler): @@ -111,9 +149,25 @@ class RemoteHandler(BaseHTTPRequestHandler): extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None} r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args) if r is not None: ret = str(r).encode() + case GraphAlloc(): + graph_fn: Callable = unwrap(Device[RemoteHandler.device].graph) + def _parse_ji(gi: GraphComputeItem): + prg = session.programs[(gi.name, gi.datahash)] + ps = ProgramSpec(gi.name, '', RemoteHandler.device, UOp(Ops.NOOP), vars=list(gi.vars), + global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None, + local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None) + return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [session.buffers[buf] for buf in gi.bufs]) + assert c.graph_num not in session.graphs, f"graph {c.graph_num} already allocated" + session.graphs[c.graph_num] = graph_fn(list(map(_parse_ji, c.jit_cache)), [session.buffers[buf] for buf in c.bufs], c.var_vals) + case GraphFree(): del session.graphs[c.graph_num] + case GraphExec(): + r = session.graphs[c.graph_num]([session.buffers[buf] for buf in c.bufs], c.var_vals, wait=c.wait) + if r is not None: ret = str(r).encode() elif self.path == "/properties" and method == "GET": cls, args = Device[RemoteHandler.device].renderer.__reduce__() - ret = json.dumps({'remotedev': RemoteHandler.device, 'renderer': (cls.__module__, cls.__name__, args)}).encode() + # CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported + graph_cls = gt if (gt:=graph_class(Device[RemoteHandler.device])) is not CPUGraph else None + ret = json.dumps({'remotedev': RemoteHandler.device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode() else: status_code = 404 self.send_response(status_code) self.send_header('Content-Length', str(len(ret))) @@ -170,7 +224,7 @@ class RemoteDevice(Compiled): # state for the connection self.session = binascii.hexlify(os.urandom(0x10)).decode() - self.buffer_num = 0 + self.buffer_num, self.graph_num = 0, 0 self.req: BatchRequest = BatchRequest() if DEBUG >= 1: print(f"remote with host {self.host}") @@ -188,7 +242,8 @@ class RemoteDevice(Compiled): if not renderer[0].startswith("tinygrad.renderer.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}") renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure? if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}") - super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self)) + graph = fromimport('tinygrad.runtime.graph.remote', 'RemoteGraph') if self.properties['graph'] else None + super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self), graph) def __del__(self): # TODO: this is never being called