mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
Use itertools.count instead of manual increment in remote (#10389)
Similar to how it's done with `UOp.unique_num`, looks a bit nicer
This commit is contained in:
@@ -24,8 +24,7 @@ class RemoteGraph(GraphRunner):
|
||||
assert dest is not None and src is not None, ji
|
||||
return Transfer(session=cast(RemoteDevice, Device[dest.device]).session, buffer_num=dest._buf,
|
||||
ssession=cast(RemoteDevice, Device[src.device]).session, sbuffer_num=src._buf)
|
||||
self.graph_num = self.devices[0].graph_num
|
||||
self.devices[0].graph_num += 1
|
||||
self.graph_num = next(self.devices[0].graph_num)
|
||||
self.devices[0].q(GraphAlloc(self.graph_num, tuple(_process_ji(ji) for ji in jit_cache), self.map_rawbufs(rawbufs), var_vals))
|
||||
|
||||
def __del__(self):
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Callable, Optional, Any, cast
|
||||
from typing import Callable, Iterator, Optional, Any, cast
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, replace
|
||||
import multiprocessing, functools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref
|
||||
import multiprocessing, functools, itertools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
from tinygrad.dtype import DTYPES_DICT, dtypes
|
||||
from tinygrad.ops import UOp, Ops, Variable, sint
|
||||
@@ -240,9 +240,8 @@ class RemoteAllocator(Allocator['RemoteDevice']):
|
||||
super().__init__(dev)
|
||||
# TODO: ideally we shouldn't have to deal with images here
|
||||
def _alloc(self, size:int, options:BufferSpec) -> int:
|
||||
self.dev.buffer_num += 1
|
||||
self.dev.q(BufferAlloc(self.dev.buffer_num, size, options))
|
||||
return self.dev.buffer_num
|
||||
self.dev.q(BufferAlloc(buffer_num:=next(self.dev.buffer_num), size, options))
|
||||
return buffer_num
|
||||
# TODO: options should not be here in any Allocator
|
||||
def _free(self, opaque:int, options): self.dev.q(BufferFree(opaque))
|
||||
def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(bytes(src))))
|
||||
@@ -257,9 +256,8 @@ class RemoteAllocator(Allocator['RemoteDevice']):
|
||||
src_dev.allocator._copyout(tmp:=memoryview(bytearray(sz)), src)
|
||||
dest_dev.allocator._copyin(dest, tmp)
|
||||
def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
|
||||
self.dev.buffer_num += 1
|
||||
self.dev.q(BufferOffset(self.dev.buffer_num, size, offset, opaque))
|
||||
return self.dev.buffer_num
|
||||
self.dev.q(BufferOffset(buffer_num:=next(self.dev.buffer_num), size, offset, opaque))
|
||||
return buffer_num
|
||||
|
||||
class RemoteProgram:
|
||||
def __init__(self, dev:RemoteDevice, name:str, lib:bytes):
|
||||
@@ -306,8 +304,8 @@ class RemoteDevice(Compiled):
|
||||
|
||||
# state for the connection
|
||||
self.session = (binascii.hexlify(os.urandom(0x10)).decode(), int(device.split(":")[1]) if ":" in device else 0)
|
||||
self.buffer_num: int = 0
|
||||
self.graph_num: int = 0
|
||||
self.buffer_num: Iterator[int] = itertools.count(0)
|
||||
self.graph_num: Iterator[int] = itertools.count(0)
|
||||
|
||||
self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
|
||||
if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
|
||||
|
||||
Reference in New Issue
Block a user