Cloud graph (#9876)

Author: uuuvn
Date: 2025-05-07 23:41:41 +05:00
Committed by: GitHub
Commit: 10c9ede6b7 (parent: 2891892834)
7 changed files with 102 additions and 15 deletions

test/helpers.py

@@ -66,3 +66,6 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
 def not_support_multi_device():
   # REMOTE doesn't support multi device anywhere, GPU, CUDA and METAL don't support multi device if in CI
   return Device.DEFAULT == "REMOTE" or (CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"))
+
+# NOTE: This will open REMOTE if it's the default device
+REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev'])
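REAL_DEV resolves the device that will actually execute kernels: when the default device is REMOTE, it reads the backing device's name out of the server's properties. A small illustration of how a test can key on it (the device name and test class here are hypothetical):

import unittest
from test.helpers import REAL_DEV

# with REMOTE backed by e.g. a METAL server, Device.DEFAULT == "REMOTE" but REAL_DEV == "METAL"
@unittest.skipUnless(REAL_DEV == "METAL", "needs a METAL backing device")
class TestMetalOnly(unittest.TestCase):
  def test_something(self): pass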

test/test_image_dtype.py

@@ -5,9 +5,9 @@ from tinygrad.device import LRUAllocator, is_dtype_supported
 from tinygrad.dtype import ImageDType
 from tinygrad.engine.realize import lower_schedule
 from tinygrad.helpers import prod, unwrap
+from test.helpers import REAL_DEV
 IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU")
-REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev'])
 @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
 class TestImageCopy(unittest.TestCase):

test/test_jit.py

@@ -3,7 +3,7 @@ import unittest, functools
 import numpy as np
 from hypothesis import given, settings, strategies as strat
-from test.helpers import assert_jit_cache_len, not_support_multi_device
+from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
 from tinygrad.tensor import Tensor
 from tinygrad.engine.jit import TinyJit
 from tinygrad.device import Device
@@ -22,7 +22,7 @@ def _simple_test(add, extract=lambda x: x, N=10):
 class TestJit(unittest.TestCase):
   @settings(deadline=2e4)
-  @unittest.skipUnless(Device.DEFAULT in ["LLVM", "CPU"], f"no support on {Device.DEFAULT}")
+  @unittest.skipUnless(REAL_DEV in ["LLVM", "CPU"], f"no support on {REAL_DEV}")
   @given(strat.sampled_from([Tensor.exp2, Tensor.log2, Tensor.sin]))
   def test_approx_jit_timeout(self, op):
     with Context(TRANSCENDENTAL=2):

tinygrad/engine/jit.py

@@ -14,6 +14,8 @@ from weakref import WeakKeyDictionary
 class GraphException(Exception): pass
+def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
 def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.
@@ -51,8 +53,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
       case ViewOp(): continue # ViewOps are just ignored
       case _: can_be_graphed = False # Everything else is not graphed and flushes existing graph if it's being constructed
-    graph_class = can_be_graphed and (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph)
-    is_multigraph = can_be_graphed and issubclass(cast(type, graph_class), MultiGraphRunner)
+    is_multigraph = can_be_graphed and issubclass(graph_class(ji_graph_dev), MultiGraphRunner)
     can_share_graph = can_be_graphed and (type(ji_graph_dev) is type(current_device) if is_multigraph else ji_graph_dev == current_device)
     can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
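graph_class is now a module-level helper that unwraps a functools.partial around a device's graph attribute, so other modules (ops_remote below) can ask which graph runner a device would use without duplicating the check. A rough sketch of the call, assuming an opened device; dev.graph may be None, in which case there is no graph runner:

from tinygrad.device import Device
from tinygrad.engine.jit import graph_class, MultiGraphRunner

dev = Device[Device.DEFAULT]
gc = graph_class(dev)   # the underlying class, even if dev.graph is a functools.partial
if gc is not None: print(gc.__name__, issubclass(gc, MultiGraphRunner))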

tinygrad/engine/realize.py

@@ -38,12 +38,12 @@ class Runner:
     raise NotImplementedError("override this")
 class CompiledRunner(Runner):
-  def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None):
+  def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None, prg=None):
     if DEBUG >= 4: print(p.src)
     self.p:ProgramSpec = p
     self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
     if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
-    self._prg = Device[p.device].runtime(p.function_name, self.lib)
+    self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
     super().__init__(p.name, p.device, p.estimates)
   def __reduce__(self): return self.__class__, (self.p, self.lib)
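The new prg parameter lets a caller hand CompiledRunner a program object that is already loaded, skipping both compile_cached and the device runtime load; the remote handler below uses it to wrap programs it previously received via ProgramAlloc. Illustrative call shape (already_loaded_prg is a placeholder):

runner = CompiledRunner(ps, precompiled=b'', prg=already_loaded_prg)  # no compile, no runtime load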

tinygrad/runtime/graph/remote.py (new file)

@@ -0,0 +1,28 @@
+from tinygrad.ops import Variable
+from tinygrad.engine.jit import GraphRunner
+from tinygrad.engine.realize import CompiledRunner, ExecItem
+from tinygrad.device import Device, Buffer
+from tinygrad.runtime.ops_remote import GraphComputeItem, GraphAlloc, GraphFree, GraphExec
+from tinygrad.helpers import unwrap, flatten, dedup, all_same
+
+class RemoteGraph(GraphRunner):
+  def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]):
+    super().__init__(jit_cache, rawbufs, var_vals)
+    self.devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
+    assert all_same(self.devices), self.devices
+    self.iids = sorted(self.input_replace.values())
+    def _process_ji(ji: ExecItem):
+      assert isinstance(ji.prg, CompiledRunner), f'Only compiled runners are supported: {ji.prg}'
+      return GraphComputeItem(ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs), tuple(ji.prg.p.vars),
+                              tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
+                              tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
+    self.graph_num = self.devices[0].graph_num
+    self.devices[0].graph_num += 1
+    self.devices[0].req.q(GraphAlloc(self.graph_num, tuple(_process_ji(ji) for ji in jit_cache), tuple(rawbufs[i]._buf for i in self.iids), var_vals))
+
+  def __del__(self):
+    self.devices[0].req.q(GraphFree(self.graph_num))
+
+  def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
+    self.devices[0].req.q(GraphExec(self.graph_num, tuple(rawbufs[i]._buf for i in self.iids), var_vals, wait))
+    if wait: return float(self.devices[0].batch_submit())
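RemoteGraph never runs anything locally; it only queues requests on the connection of its (single, asserted-identical) device. A rough lifecycle, with placeholder variable names:

graph = RemoteGraph(jit_cache, input_rawbuffers, var_vals)  # __init__ queues GraphAlloc with the captured kernels
et = graph(input_rawbuffers, var_vals, wait=True)           # each call queues GraphExec; wait=True returns the time reported by batch_submit()
del graph                                                   # __del__ queues GraphFree(graph_num)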

tinygrad/runtime/ops_remote.py

@@ -5,15 +5,19 @@
 # it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
 from __future__ import annotations
-from typing import Optional, Any
+from typing import Callable, Optional, Any, cast
 from collections import defaultdict
 from dataclasses import dataclass, field
 import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
 from http.server import HTTPServer, BaseHTTPRequestHandler
-from tinygrad.renderer import Renderer
-from tinygrad.dtype import dtypes
+from tinygrad.renderer import Renderer, ProgramSpec
+from tinygrad.dtype import DTYPES_DICT, dtypes
 from tinygrad.ops import UOp, Ops, Variable, sint
 from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
+from tinygrad.engine.jit import GraphRunner, ExecItem, graph_class
+from tinygrad.engine.realize import CompiledRunner
 from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec
+from tinygrad.runtime.graph.cpu import CPUGraph
# ***** API *****
@@ -42,11 +46,44 @@ class ProgramExec(RemoteRequest):
   name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702
   global_size: Optional[tuple[int, ...]]; local_size: Optional[tuple[int, ...]]; wait: bool # noqa: E702

+@dataclass(frozen=True)
+class GraphComputeItem:
+  name: str
+  datahash: str
+  bufs: tuple[int, ...]
+  vars: tuple[Variable, ...]
+  global_size: tuple[sint, ...]|None
+  local_size: tuple[sint, ...]|None
+
+@dataclass(frozen=True)
+class GraphAlloc(RemoteRequest):
+  graph_num: int
+  jit_cache: tuple[GraphComputeItem, ...]
+  bufs: tuple[int, ...]
+  var_vals: dict[Variable, int]
+
+@dataclass(frozen=True)
+class GraphFree(RemoteRequest):
+  graph_num: int
+
+@dataclass(frozen=True)
+class GraphExec(RemoteRequest):
+  graph_num: int
+  bufs: tuple[int, ...]
+  var_vals: dict[Variable, int]
+  wait: bool
+
 # for safe deserialization
-whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferSpec]}
+eval_globals = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem,
+                                       GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]}
+attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
 eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
              ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
              ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
-             ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]}
+             ast.Name: lambda x: eval_globals[x.id], ast.Attribute: lambda x: safe_getattr(safe_eval(x.value), x.attr)}
+def safe_getattr(value, attr):
+  assert attr in attribute_whitelist.get(value, set()), f'getattr({value}, {repr(attr)}) is not whitelisted'
+  return getattr(value, attr)
 def safe_eval(node): return eval_fxns[node.__class__](node)

 class BatchRequest:
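The old flat whitelist becomes eval_globals plus an attribute_whitelist, so serialized requests can name dtypes and Ops members through ordinary attribute access while anything else trips the safe_getattr assertion. Illustrative round-trips against the definitions above:

import ast
safe_eval(ast.parse("dtypes.float32", mode="eval").body)    # ok: 'float32' is a DTYPES_DICT key
safe_eval(ast.parse("Ops.NOOP", mode="eval").body)          # ok: every Ops name is whitelisted
safe_eval(ast.parse("dtypes.__class__", mode="eval").body)  # AssertionError: not whitelisted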
@@ -75,6 +112,7 @@ class BatchRequest:
 @dataclass
 class RemoteSession:
   programs: dict[tuple[str, str], Any] = field(default_factory=dict)
+  graphs: dict[int, GraphRunner] = field(default_factory=dict)
   buffers: dict[int, Buffer] = field(default_factory=dict)

 class RemoteHandler(BaseHTTPRequestHandler):
@@ -111,9 +149,25 @@ class RemoteHandler(BaseHTTPRequestHandler):
             extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
             r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
             if r is not None: ret = str(r).encode()
+          case GraphAlloc():
+            graph_fn: Callable = unwrap(Device[RemoteHandler.device].graph)
+            def _parse_ji(gi: GraphComputeItem):
+              prg = session.programs[(gi.name, gi.datahash)]
+              ps = ProgramSpec(gi.name, '', RemoteHandler.device, UOp(Ops.NOOP), vars=list(gi.vars),
+                               global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None,
+                               local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None)
+              return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [session.buffers[buf] for buf in gi.bufs])
+            assert c.graph_num not in session.graphs, f"graph {c.graph_num} already allocated"
+            session.graphs[c.graph_num] = graph_fn(list(map(_parse_ji, c.jit_cache)), [session.buffers[buf] for buf in c.bufs], c.var_vals)
+          case GraphFree(): del session.graphs[c.graph_num]
+          case GraphExec():
+            r = session.graphs[c.graph_num]([session.buffers[buf] for buf in c.bufs], c.var_vals, wait=c.wait)
+            if r is not None: ret = str(r).encode()
     elif self.path == "/properties" and method == "GET":
       cls, args = Device[RemoteHandler.device].renderer.__reduce__()
-      ret = json.dumps({'remotedev': RemoteHandler.device, 'renderer': (cls.__module__, cls.__name__, args)}).encode()
+      # CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported
+      graph_cls = gt if (gt:=graph_class(Device[RemoteHandler.device])) is not CPUGraph else None
+      ret = json.dumps({'remotedev': RemoteHandler.device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode()
     else: status_code = 404
     self.send_response(status_code)
     self.send_header('Content-Length', str(len(ret)))
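/properties now also reports whether the backing device has a usable graph runner; CPUGraph is excluded because it re-renders kernels from UOps that the reconstructed CompiledRunner does not carry. The client uses this flag to decide whether to install RemoteGraph. A response might look roughly like this (the renderer entry is device-specific and shown here with placeholder args):

{"remotedev": "METAL", "renderer": ["tinygrad.renderer.cstyle", "MetalRenderer", []], "graph": true}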
@@ -170,7 +224,7 @@ class RemoteDevice(Compiled):
     # state for the connection
     self.session = binascii.hexlify(os.urandom(0x10)).decode()
-    self.buffer_num = 0
+    self.buffer_num, self.graph_num = 0, 0
     self.req: BatchRequest = BatchRequest()

     if DEBUG >= 1: print(f"remote with host {self.host}")
@@ -188,7 +242,8 @@ class RemoteDevice(Compiled):
     if not renderer[0].startswith("tinygrad.renderer.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
     renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
     if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
-    super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self))
+    graph = fromimport('tinygrad.runtime.graph.remote', 'RemoteGraph') if self.properties['graph'] else None
+    super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self), graph)

   def __del__(self):
     # TODO: this is never being called
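With the graph runner passed to Compiled.__init__, a jitted workload on a REMOTE device now batches its kernels into server-side graphs instead of replaying individual ProgramExec requests. A minimal client-side sketch, assuming the device is selected with the usual environment variable and a server is reachable (host/server setup not shown):

# run with e.g. REMOTE=1
from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor: return (x * 2 + 1).realize()

for i in range(4): step(Tensor.ones(16) * i)  # after JIT capture, kernels go out as GraphAlloc/GraphExec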