From ddff9857b8527c16e5917eeec2dc55da69fa74d2 Mon Sep 17 00:00:00 2001 From: uuuvn <83587632+uuuvn@users.noreply.github.com> Date: Tue, 13 May 2025 23:56:58 +0500 Subject: [PATCH] Remote properties is a dataclass (#10283) Not strictly required for anything but soon there will be like 4 new properties and having it be a huge json just seems like a bad taste. It also seems right to not have a separate endpoint for this, just `GetProperties` request that returns a repr of this similar to how requests are sent in `BatchRequest`. This will also make a switch to anything other than http much simpler if it will be required for any reason, like just a tcp stream of `BatchRequest`s --- .github/workflows/test.yml | 2 +- test/helpers.py | 2 +- tinygrad/runtime/ops_remote.py | 46 +++++++++++++++++++--------------- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f75f5d4fdd..39520be7e2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -791,7 +791,7 @@ jobs: - name: Check Device.DEFAULT and print some source run: | python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT" - python -c "from tinygrad import Device; assert Device.default.properties['remotedev'] == 'METAL', Device.default.properties['remotedev']" + python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device" DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus - name: Run REMOTE=1 Test run: | diff --git a/test/helpers.py b/test/helpers.py index 95143d18b7..ba821ce784 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -68,4 +68,4 @@ def not_support_multi_device(): return CI and REAL_DEV in ("GPU", "CUDA") # NOTE: This will open REMOTE if it's the default device -REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev']) +REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device) diff --git a/tinygrad/runtime/ops_remote.py b/tinygrad/runtime/ops_remote.py index ec77d5cd21..113a62d0c0 100644 --- a/tinygrad/runtime/ops_remote.py +++ b/tinygrad/runtime/ops_remote.py @@ -8,7 +8,7 @@ from __future__ import annotations from typing import Callable, Optional, Any, cast from collections import defaultdict from dataclasses import dataclass, field, replace -import multiprocessing, functools, asyncio, http, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib +import multiprocessing, functools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib from tinygrad.renderer import Renderer, ProgramSpec from tinygrad.dtype import DTYPES_DICT, dtypes from tinygrad.ops import UOp, Ops, Variable, sint @@ -23,6 +23,15 @@ from tinygrad.runtime.graph.cpu import CPUGraph @dataclass(frozen=True) class RemoteRequest: session: tuple[str, int]|None = field(default=None, kw_only=True) +@dataclass(frozen=True) +class RemoteProperties: + real_device: str + renderer: tuple[str, str, tuple[Any, ...]] + graph_supported: bool + +@dataclass(frozen=True) +class GetProperties(RemoteRequest): pass + @dataclass(frozen=True) class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702 @@ -74,8 +83,8 @@ class GraphExec(RemoteRequest): wait: bool # for safe deserialization -eval_globals = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, - GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]} +eval_globals = {x.__name__:x for x in [RemoteProperties, GetProperties, BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, + ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]} attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}} eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)), ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)}, @@ -141,6 +150,11 @@ class RemoteHandler: if DEBUG >= 1: print(c) session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session)[1]}"] match c: + case GetProperties(): + cls, args = dev.renderer.__reduce__() + # CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported + graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None + ret = repr(RemoteProperties(dev.device, (cls.__module__, cls.__name__, args), graph_cls is not None)).encode() case BufferAlloc(): assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated" session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True) @@ -170,11 +184,6 @@ class RemoteHandler: case GraphExec(): r = session.graphs[c.graph_num]([session.buffers[buf] for buf in c.bufs], c.var_vals, wait=c.wait) if r is not None: ret = str(r).encode() - elif path == "/properties" and method == "GET": - cls, args = Device[self.base_device].renderer.__reduce__() - # CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported - graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None - ret = json.dumps({'remotedev': self.base_device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode() else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found" return status, ret @@ -231,18 +240,19 @@ class RemoteDevice(Compiled): while 1: try: self.conn = http.client.HTTPConnection(self.host, timeout=60.0) - self.properties = json.loads(self.send("GET", "properties").decode()) + self.q(GetProperties()) + self.properties = safe_eval(ast.parse(self.batch_submit(), mode="eval").body) break except Exception as e: print(e) time.sleep(0.1) - if DEBUG >= 1: print(f"remote has device {self.properties['remotedev']}") + if DEBUG >= 1: print(f"remote has device {self.properties.real_device}") # TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer - renderer = self.properties['renderer'] + renderer = self.properties.renderer if not renderer[0].startswith("tinygrad.renderer.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}") renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure? if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}") - graph = fromimport('tinygrad.runtime.graph.remote', 'RemoteGraph') if self.properties['graph'] else None + graph = fromimport('tinygrad.runtime.graph.remote', 'RemoteGraph') if self.properties.graph_supported else None super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self), graph) def __del__(self): @@ -261,15 +271,11 @@ class RemoteDevice(Compiled): def batch_submit(self): data = self.req.serialize() with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1): - ret = self.send("POST", "batch", data) + self.conn.request("POST", "/batch", data) + response = self.conn.getresponse() + assert response.status == 200, f"POST /batch failed: {response}" + ret = response.read() self.req = BatchRequest() return ret - def send(self, method, path, data:Optional[bytes]=None) -> bytes: - # TODO: retry logic - self.conn.request(method, "/"+path, data) - response = self.conn.getresponse() - assert response.status == 200, f"failed on {method} {path}" - return response.read() - if __name__ == "__main__": remote_server(getenv("PORT", 6667))