Use itertools.count instead of manual increment in remote (#10389)

Similar to how it's done with `UOp.unique_num`, looks a bit nicer
This commit is contained in:
uuuvn
2025-05-18 12:15:37 +05:00
committed by GitHub
parent 0294bfe507
commit f20c5aac1f
2 changed files with 9 additions and 12 deletions

View File

@@ -24,8 +24,7 @@ class RemoteGraph(GraphRunner):
assert dest is not None and src is not None, ji
return Transfer(session=cast(RemoteDevice, Device[dest.device]).session, buffer_num=dest._buf,
ssession=cast(RemoteDevice, Device[src.device]).session, sbuffer_num=src._buf)
self.graph_num = self.devices[0].graph_num
self.devices[0].graph_num += 1
self.graph_num = next(self.devices[0].graph_num)
self.devices[0].q(GraphAlloc(self.graph_num, tuple(_process_ji(ji) for ji in jit_cache), self.map_rawbufs(rawbufs), var_vals))
def __del__(self):

View File

@@ -5,10 +5,10 @@
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
from __future__ import annotations
from typing import Callable, Optional, Any, cast
from typing import Callable, Iterator, Optional, Any, cast
from collections import defaultdict
from dataclasses import dataclass, field, replace
import multiprocessing, functools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref
import multiprocessing, functools, itertools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.dtype import DTYPES_DICT, dtypes
from tinygrad.ops import UOp, Ops, Variable, sint
@@ -240,9 +240,8 @@ class RemoteAllocator(Allocator['RemoteDevice']):
super().__init__(dev)
# TODO: ideally we shouldn't have to deal with images here
def _alloc(self, size:int, options:BufferSpec) -> int:
self.dev.buffer_num += 1
self.dev.q(BufferAlloc(self.dev.buffer_num, size, options))
return self.dev.buffer_num
self.dev.q(BufferAlloc(buffer_num:=next(self.dev.buffer_num), size, options))
return buffer_num
# TODO: options should not be here in any Allocator
def _free(self, opaque:int, options): self.dev.q(BufferFree(opaque))
def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(bytes(src))))
@@ -257,9 +256,8 @@ class RemoteAllocator(Allocator['RemoteDevice']):
src_dev.allocator._copyout(tmp:=memoryview(bytearray(sz)), src)
dest_dev.allocator._copyin(dest, tmp)
def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
self.dev.buffer_num += 1
self.dev.q(BufferOffset(self.dev.buffer_num, size, offset, opaque))
return self.dev.buffer_num
self.dev.q(BufferOffset(buffer_num:=next(self.dev.buffer_num), size, offset, opaque))
return buffer_num
class RemoteProgram:
def __init__(self, dev:RemoteDevice, name:str, lib:bytes):
@@ -306,8 +304,8 @@ class RemoteDevice(Compiled):
# state for the connection
self.session = (binascii.hexlify(os.urandom(0x10)).decode(), int(device.split(":")[1]) if ":" in device else 0)
self.buffer_num: int = 0
self.graph_num: int = 0
self.buffer_num: Iterator[int] = itertools.count(0)
self.graph_num: Iterator[int] = itertools.count(0)
self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")