From f20c5aac1f87169f75b738d5cd23a052471bdcb8 Mon Sep 17 00:00:00 2001 From: uuuvn <83587632+uuuvn@users.noreply.github.com> Date: Sun, 18 May 2025 12:15:37 +0500 Subject: [PATCH] Use `itertools.count` instead of manual increment in remote (#10389) Similar to how it's done with `UOp.unique_num`, looks a bit nicer --- tinygrad/runtime/graph/remote.py | 3 +-- tinygrad/runtime/ops_remote.py | 18 ++++++++---------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/tinygrad/runtime/graph/remote.py b/tinygrad/runtime/graph/remote.py index cd6f41d2b8..7926e23946 100644 --- a/tinygrad/runtime/graph/remote.py +++ b/tinygrad/runtime/graph/remote.py @@ -24,8 +24,7 @@ class RemoteGraph(GraphRunner): assert dest is not None and src is not None, ji return Transfer(session=cast(RemoteDevice, Device[dest.device]).session, buffer_num=dest._buf, ssession=cast(RemoteDevice, Device[src.device]).session, sbuffer_num=src._buf) - self.graph_num = self.devices[0].graph_num - self.devices[0].graph_num += 1 + self.graph_num = next(self.devices[0].graph_num) self.devices[0].q(GraphAlloc(self.graph_num, tuple(_process_ji(ji) for ji in jit_cache), self.map_rawbufs(rawbufs), var_vals)) def __del__(self): diff --git a/tinygrad/runtime/ops_remote.py b/tinygrad/runtime/ops_remote.py index 35a2dc7db4..afc75551d9 100644 --- a/tinygrad/runtime/ops_remote.py +++ b/tinygrad/runtime/ops_remote.py @@ -5,10 +5,10 @@ # it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC from __future__ import annotations -from typing import Callable, Optional, Any, cast +from typing import Callable, Iterator, Optional, Any, cast from collections import defaultdict from dataclasses import dataclass, field, replace -import multiprocessing, functools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref +import multiprocessing, functools, itertools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref from tinygrad.renderer import Renderer, ProgramSpec from tinygrad.dtype import DTYPES_DICT, dtypes from tinygrad.ops import UOp, Ops, Variable, sint @@ -240,9 +240,8 @@ class RemoteAllocator(Allocator['RemoteDevice']): super().__init__(dev) # TODO: ideally we shouldn't have to deal with images here def _alloc(self, size:int, options:BufferSpec) -> int: - self.dev.buffer_num += 1 - self.dev.q(BufferAlloc(self.dev.buffer_num, size, options)) - return self.dev.buffer_num + self.dev.q(BufferAlloc(buffer_num:=next(self.dev.buffer_num), size, options)) + return buffer_num # TODO: options should not be here in any Allocator def _free(self, opaque:int, options): self.dev.q(BufferFree(opaque)) def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(bytes(src)))) @@ -257,9 +256,8 @@ class RemoteAllocator(Allocator['RemoteDevice']): src_dev.allocator._copyout(tmp:=memoryview(bytearray(sz)), src) dest_dev.allocator._copyin(dest, tmp) def _dyn_offset(self, opaque:int, size:int, offset:int) -> int: - self.dev.buffer_num += 1 - self.dev.q(BufferOffset(self.dev.buffer_num, size, offset, opaque)) - return self.dev.buffer_num + self.dev.q(BufferOffset(buffer_num:=next(self.dev.buffer_num), size, offset, opaque)) + return buffer_num class RemoteProgram: def __init__(self, dev:RemoteDevice, name:str, lib:bytes): @@ -306,8 +304,8 @@ class RemoteDevice(Compiled): # state for the connection self.session = (binascii.hexlify(os.urandom(0x10)).decode(), int(device.split(":")[1]) if ":" in device else 0) - self.buffer_num: int = 0 - self.graph_num: int = 0 + self.buffer_num: Iterator[int] = itertools.count(0) + self.graph_num: Iterator[int] = itertools.count(0) self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body) if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")