Cloud graph (#9876)
@@ -66,3 +66,6 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
def not_support_multi_device():
  # REMOTE doesn't support multi device anywhere; GPU, CUDA and METAL don't support multi device in CI
  return Device.DEFAULT == "REMOTE" or (CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"))

# NOTE: This will open REMOTE if it's the default device
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev'])
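
For orientation, a test module would typically consume these helpers roughly as below. This is an illustrative sketch, not a line from this diff; the test class name is made up.

import unittest
from test.helpers import not_support_multi_device, REAL_DEV

@unittest.skipIf(not_support_multi_device(), "this backend has no multi-device support in this environment")
class TestHypotheticalMulti(unittest.TestCase):  # hypothetical test class, for illustration only
  def test_real_dev_is_a_device_name(self):
    # REAL_DEV resolves "REMOTE" to the name of the device actually backing the remote server
    self.assertIsInstance(REAL_DEV, str)
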
@@ -5,9 +5,9 @@ from tinygrad.device import LRUAllocator, is_dtype_supported
from tinygrad.dtype import ImageDType
from tinygrad.engine.realize import lower_schedule
from tinygrad.helpers import prod, unwrap
from test.helpers import REAL_DEV

IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU")
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev'])

@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
class TestImageCopy(unittest.TestCase):

@@ -3,7 +3,7 @@ import unittest, functools
import numpy as np

from hypothesis import given, settings, strategies as strat
from test.helpers import assert_jit_cache_len, not_support_multi_device
from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.device import Device
@@ -22,7 +22,7 @@ def _simple_test(add, extract=lambda x: x, N=10):
class TestJit(unittest.TestCase):

  @settings(deadline=2e4)
  @unittest.skipUnless(Device.DEFAULT in ["LLVM", "CPU"], f"no support on {Device.DEFAULT}")
  @unittest.skipUnless(REAL_DEV in ["LLVM", "CPU"], f"no support on {REAL_DEV}")
  @given(strat.sampled_from([Tensor.exp2, Tensor.log2, Tensor.sin]))
  def test_approx_jit_timeout(self, op):
    with Context(TRANSCENDENTAL=2):

@@ -14,6 +14,8 @@ from weakref import WeakKeyDictionary

class GraphException(Exception): pass

def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph

def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
  # Split JIT cache into batches for faster graph execution.
  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
@@ -51,8 +53,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
      case ViewOp(): continue # ViewOps are just ignored
      case _: can_be_graphed = False # Everything else is not graphed and flushes existing graph if it's being constructed

    graph_class = can_be_graphed and (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph)
    is_multigraph = can_be_graphed and issubclass(cast(type, graph_class), MultiGraphRunner)
    is_multigraph = can_be_graphed and issubclass(graph_class(ji_graph_dev), MultiGraphRunner)
    can_share_graph = can_be_graphed and (type(ji_graph_dev) is type(current_device) if is_multigraph else ji_graph_dev == current_device)
    can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
    if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
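
The two comments above describe the batching policy apply_graph_to_jit follows: consecutive graphable items on a compatible device are grouped, and the current batch is flushed when an item can't join it. Below is a simplified standalone sketch of that policy (generic names, no multigraph/device-type check), not the tinygrad implementation itself.

# Standalone sketch of the batch/flush policy described in apply_graph_to_jit's comments.
# Items are (device, graphable) pairs; names here are generic, not tinygrad's.
def batch_policy(items: list[tuple[str, bool]], max_batch_size: int = 0) -> list[tuple[str, list]]:
  out, current_batch, current_device = [], [], None
  def flush_batch():
    nonlocal current_batch, current_device
    if current_batch: out.append(("graphed", current_batch))
    current_batch, current_device = [], None
  for dev, can_be_graphed in items:
    can_share_graph = can_be_graphed and (current_device is None or dev == current_device)
    can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
    if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
    if can_be_graphed:
      current_batch.append(dev)
      current_device = dev
    else:
      out.append(("plain", [dev]))
  flush_batch()
  return out

# e.g. two GPU items batch together, the ungraphable item flushes the batch and passes through
assert batch_policy([("GPU", True), ("GPU", True), ("GPU", False), ("GPU", True)]) == \
       [("graphed", ["GPU", "GPU"]), ("plain", ["GPU"]), ("graphed", ["GPU"])]
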
@@ -38,12 +38,12 @@ class Runner:
    raise NotImplementedError("override this")

class CompiledRunner(Runner):
  def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None):
  def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None, prg=None):
    if DEBUG >= 4: print(p.src)
    self.p:ProgramSpec = p
    self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
    if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
    self._prg = Device[p.device].runtime(p.function_name, self.lib)
    self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
    super().__init__(p.name, p.device, p.estimates)

  def __reduce__(self): return self.__class__, (self.p, self.lib)

tinygrad/runtime/graph/remote.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from tinygrad.ops import Variable
from tinygrad.engine.jit import GraphRunner
from tinygrad.engine.realize import CompiledRunner, ExecItem
from tinygrad.device import Device, Buffer
from tinygrad.runtime.ops_remote import GraphComputeItem, GraphAlloc, GraphFree, GraphExec
from tinygrad.helpers import unwrap, flatten, dedup, all_same

class RemoteGraph(GraphRunner):
  def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]):
    super().__init__(jit_cache, rawbufs, var_vals)
    self.devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
    assert all_same(self.devices), self.devices
    self.iids = sorted(self.input_replace.values())
    def _process_ji(ji: ExecItem):
      assert isinstance(ji.prg, CompiledRunner), f'Only compiled runners are supported: {ji.prg}'
      return GraphComputeItem(ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs), tuple(ji.prg.p.vars),
                              tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
                              tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
    self.graph_num = self.devices[0].graph_num
    self.devices[0].graph_num += 1
    self.devices[0].req.q(GraphAlloc(self.graph_num, tuple(_process_ji(ji) for ji in jit_cache), tuple(rawbufs[i]._buf for i in self.iids), var_vals))

  def __del__(self):
    self.devices[0].req.q(GraphFree(self.graph_num))

  def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
    self.devices[0].req.q(GraphExec(self.graph_num, tuple(rawbufs[i]._buf for i in self.iids), var_vals, wait))
    if wait: return float(self.devices[0].batch_submit())
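
RemoteGraph itself is thin: it takes a graph_num, enqueues one GraphAlloc describing the whole jit_cache up front, replays it with a GraphExec per call, and frees it with GraphFree when collected. The following is a self-contained sketch of that allocate/execute/free lifecycle; the queue and class names are stand-ins, not tinygrad APIs.

# Stand-in sketch of the graph_num lifecycle RemoteGraph drives over the request queue.
# FakeQueue and GraphLifecycle are illustrative only; the real messages are the
# GraphAlloc/GraphExec/GraphFree dataclasses defined in ops_remote.py.
class FakeQueue:
  def __init__(self): self.sent = []
  def q(self, kind, graph_num): self.sent.append((kind, graph_num))

class GraphLifecycle:
  _next_graph_num = 0
  def __init__(self, req: FakeQueue):
    self.req, self.graph_num = req, GraphLifecycle._next_graph_num
    GraphLifecycle._next_graph_num += 1
    self.req.q("GraphAlloc", self.graph_num)   # describe the whole batch once, at build time
  def __call__(self):
    self.req.q("GraphExec", self.graph_num)    # each call only resends the graph number (plus bufs/var_vals in the real thing)
  def __del__(self):
    self.req.q("GraphFree", self.graph_num)    # server drops its copy when the runner dies

req = FakeQueue()
g = GraphLifecycle(req)
g(); g()
del g
assert [kind for kind, _ in req.sent] == ["GraphAlloc", "GraphExec", "GraphExec", "GraphFree"]
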
@@ -5,15 +5,19 @@
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC

from __future__ import annotations
from typing import Optional, Any
from typing import Callable, Optional, Any, cast
from collections import defaultdict
from dataclasses import dataclass, field
import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.dtype import DTYPES_DICT, dtypes
from tinygrad.ops import UOp, Ops, Variable, sint
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
from tinygrad.engine.jit import GraphRunner, ExecItem, graph_class
from tinygrad.engine.realize import CompiledRunner
from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec
from tinygrad.runtime.graph.cpu import CPUGraph

# ***** API *****

@@ -42,11 +46,44 @@ class ProgramExec(RemoteRequest):
  name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702
  global_size: Optional[tuple[int, ...]]; local_size: Optional[tuple[int, ...]]; wait: bool # noqa: E702

@dataclass(frozen=True)
class GraphComputeItem:
  name: str
  datahash: str
  bufs: tuple[int, ...]
  vars: tuple[Variable, ...]
  global_size: tuple[sint, ...]|None
  local_size: tuple[sint, ...]|None

@dataclass(frozen=True)
class GraphAlloc(RemoteRequest):
  graph_num: int
  jit_cache: tuple[GraphComputeItem, ...]
  bufs: tuple[int, ...]
  var_vals: dict[Variable, int]

@dataclass(frozen=True)
class GraphFree(RemoteRequest):
  graph_num: int

@dataclass(frozen=True)
class GraphExec(RemoteRequest):
  graph_num: int
  bufs: tuple[int, ...]
  var_vals: dict[Variable, int]
  wait: bool

# for safe deserialization
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferSpec]}
eval_globals = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem,
                                       GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]}
attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
             ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
             ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
             ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]}
             ast.Name: lambda x: eval_globals[x.id], ast.Attribute: lambda x: safe_getattr(safe_eval(x.value), x.attr)}
def safe_getattr(value, attr):
  assert attr in attribute_whitelist.get(value, set()), f'getattr({value}, {repr(attr)}) is not whitelisted'
  return getattr(value, attr)
def safe_eval(node): return eval_fxns[node.__class__](node)

class BatchRequest:
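
The serialization boundary stays eval-free: requests travel as their repr() and are rebuilt by walking the AST, with names restricted to eval_globals and attribute access restricted to attribute_whitelist. Below is a minimal standalone version of the same round-trip technique; Ping is a made-up request type, and this is not the tinygrad code itself.

# Minimal standalone sketch of the whitelist-based repr()/ast round-trip used above.
# Ping is a hypothetical request type; only whitelisted names and node types evaluate.
import ast
from dataclasses import dataclass

@dataclass(frozen=True)
class Ping:
  n: int
  tag: str

eval_globals = {"Ping": Ping}
eval_fxns = {ast.Constant: lambda x: x.value,
             ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)),
             ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(a) for a in x.args],
                                                   **{kw.arg: safe_eval(kw.value) for kw in x.keywords}),
             ast.Name: lambda x: eval_globals[x.id]}  # unknown names or node types raise KeyError
def safe_eval(node): return eval_fxns[node.__class__](node)

wire = repr(Ping(3, "hello"))                         # what travels over HTTP: a plain string
assert safe_eval(ast.parse(wire, mode="eval").body) == Ping(3, "hello")
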
@@ -75,6 +112,7 @@ class BatchRequest:
@dataclass
class RemoteSession:
  programs: dict[tuple[str, str], Any] = field(default_factory=dict)
  graphs: dict[int, GraphRunner] = field(default_factory=dict)
  buffers: dict[int, Buffer] = field(default_factory=dict)

class RemoteHandler(BaseHTTPRequestHandler):
@@ -111,9 +149,25 @@ class RemoteHandler(BaseHTTPRequestHandler):
            extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
            r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
            if r is not None: ret = str(r).encode()
          case GraphAlloc():
            graph_fn: Callable = unwrap(Device[RemoteHandler.device].graph)
            def _parse_ji(gi: GraphComputeItem):
              prg = session.programs[(gi.name, gi.datahash)]
              ps = ProgramSpec(gi.name, '', RemoteHandler.device, UOp(Ops.NOOP), vars=list(gi.vars),
                               global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None,
                               local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None)
              return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [session.buffers[buf] for buf in gi.bufs])
            assert c.graph_num not in session.graphs, f"graph {c.graph_num} already allocated"
            session.graphs[c.graph_num] = graph_fn(list(map(_parse_ji, c.jit_cache)), [session.buffers[buf] for buf in c.bufs], c.var_vals)
          case GraphFree(): del session.graphs[c.graph_num]
          case GraphExec():
            r = session.graphs[c.graph_num]([session.buffers[buf] for buf in c.bufs], c.var_vals, wait=c.wait)
            if r is not None: ret = str(r).encode()
    elif self.path == "/properties" and method == "GET":
      cls, args = Device[RemoteHandler.device].renderer.__reduce__()
      ret = json.dumps({'remotedev': RemoteHandler.device, 'renderer': (cls.__module__, cls.__name__, args)}).encode()
      # CPUGraph re-renders the kernel from the uops in the CompiledRunner; this is not supported
      graph_cls = gt if (gt:=graph_class(Device[RemoteHandler.device])) is not CPUGraph else None
      ret = json.dumps({'remotedev': RemoteHandler.device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode()
    else: status_code = 404
    self.send_response(status_code)
    self.send_header('Content-Length', str(len(ret)))
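
With the 'graph' flag added, the client decides whether to wire up RemoteGraph purely from this JSON. Purely for illustration, a response could look like the dict below; the concrete renderer and values are assumptions, not taken from this diff.

# Hypothetical /properties payload shape under this change; the values are assumed, but the
# keys ('remotedev', 'renderer', 'graph') are the ones the handler above emits.
example_properties = {
  "remotedev": "METAL",                                            # device backing the remote server
  "renderer": ["tinygrad.renderer.cstyle", "MetalRenderer", []],   # (module, class name, ctor args) after the JSON round-trip
  "graph": True,                                                   # False when the server-side graph is CPUGraph or absent
}
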
@@ -170,7 +224,7 @@ class RemoteDevice(Compiled):

    # state for the connection
    self.session = binascii.hexlify(os.urandom(0x10)).decode()
    self.buffer_num = 0
    self.buffer_num, self.graph_num = 0, 0
    self.req: BatchRequest = BatchRequest()

    if DEBUG >= 1: print(f"remote with host {self.host}")
@@ -188,7 +242,8 @@ class RemoteDevice(Compiled):
    if not renderer[0].startswith("tinygrad.renderer.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
    renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
    if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
    super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self))
    graph = fromimport('tinygrad.runtime.graph.remote', 'RemoteGraph') if self.properties['graph'] else None
    super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self), graph)

  def __del__(self):
    # TODO: this is never being called