diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1ad7ccf123..f75f5d4fdd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -473,7 +473,7 @@ jobs: run: CPU=1 PYTHONPATH=. python3 test/test_quantize_onnx.py - name: Run REMOTE=1 Test run: | - REMOTEDEV=CPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_jit.py + REMOTEDEV=CPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_jit.py test/test_multitensor.py REMOTEDEV=GPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py test/test_jit.py REMOTEDEV=GPU IMAGE=2 REMOTE=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py - name: Test Optimization Helpers @@ -795,7 +795,7 @@ jobs: DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus - name: Run REMOTE=1 Test run: | - python3 -m pytest test/test_tiny.py test/test_jit.py + python3 -m pytest test/test_tiny.py test/test_jit.py test/test_multitensor.py osxtests: strategy: diff --git a/test/helpers.py b/test/helpers.py index f51a77be79..95143d18b7 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -64,8 +64,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None): return out_buf.cast(uop.dtype.fmt).tolist()[0] def not_support_multi_device(): - # REMOTE doesn't support multi device anywhere, GPU and CUDA don't support multi device if in CI - return Device.DEFAULT == "REMOTE" or (CI and Device.DEFAULT in ("GPU", "CUDA")) + # GPU and CUDA don't support multi device if in CI + return CI and REAL_DEV in ("GPU", "CUDA") # NOTE: This will open REMOTE if it's the default device REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev']) diff --git a/tinygrad/runtime/ops_remote.py b/tinygrad/runtime/ops_remote.py index 7f313c181f..ec77d5cd21 100644 --- a/tinygrad/runtime/ops_remote.py +++ b/tinygrad/runtime/ops_remote.py @@ -21,7 +21,7 @@ from tinygrad.runtime.graph.cpu import CPUGraph # ***** API ***** @dataclass(frozen=True) -class RemoteRequest: session: str|None = field(default=None, kw_only=True) +class RemoteRequest: session: tuple[str, int]|None = field(default=None, kw_only=True) @dataclass(frozen=True) class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702 @@ -116,9 +116,9 @@ class RemoteSession: buffers: dict[int, Buffer] = field(default_factory=dict) class RemoteHandler: - def __init__(self, device: str): - self.device = device - self.sessions: defaultdict[str, RemoteSession] = defaultdict(RemoteSession) + def __init__(self, base_device: str): + self.base_device = base_device + self.sessions: defaultdict[tuple[str, int], RemoteSession] = defaultdict(RemoteSession) async def __call__(self, reader:asyncio.StreamReader, writer:asyncio.StreamWriter): while (req_hdr:=(await reader.readline()).decode().strip()): @@ -139,17 +139,17 @@ class RemoteHandler: # the cmds are always last (currently in datahash) for c in req._q: if DEBUG >= 1: print(c) - session = self.sessions[unwrap(c.session)] + session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session)[1]}"] match c: case BufferAlloc(): assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated" - session.buffers[c.buffer_num] = Buffer(self.device, c.size, dtypes.uint8, options=c.options, preallocate=True) + session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True) case BufferFree(): del session.buffers[c.buffer_num] case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash]))) case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes))) case ProgramAlloc(): - lib = Device[self.device].compiler.compile_cached(req._h[c.datahash].decode()) - session.programs[(c.name, c.datahash)] = Device[self.device].runtime(c.name, lib) + lib = dev.compiler.compile_cached(req._h[c.datahash].decode()) + session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib) case ProgramFree(): del session.programs[(c.name, c.datahash)] case ProgramExec(): bufs = [session.buffers[x]._buf for x in c.bufs] @@ -157,10 +157,10 @@ class RemoteHandler: r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args) if r is not None: ret = str(r).encode() case GraphAlloc(): - graph_fn: Callable = unwrap(Device[self.device].graph) + graph_fn: Callable = unwrap(dev.graph) def _parse_ji(gi: GraphComputeItem): prg = session.programs[(gi.name, gi.datahash)] - ps = ProgramSpec(gi.name, '', self.device, UOp(Ops.NOOP), vars=list(gi.vars), + ps = ProgramSpec(gi.name, '', dev.device, UOp(Ops.NOOP), vars=list(gi.vars), global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None, local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None) return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [session.buffers[buf] for buf in gi.bufs]) @@ -171,10 +171,10 @@ class RemoteHandler: r = session.graphs[c.graph_num]([session.buffers[buf] for buf in c.bufs], c.var_vals, wait=c.wait) if r is not None: ret = str(r).encode() elif path == "/properties" and method == "GET": - cls, args = Device[self.device].renderer.__reduce__() + cls, args = Device[self.base_device].renderer.__reduce__() # CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported - graph_cls = gt if (gt:=graph_class(Device[self.device])) is not CPUGraph else None - ret = json.dumps({'remotedev': self.device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode() + graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None + ret = json.dumps({'remotedev': self.base_device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode() else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found" return status, ret @@ -219,14 +219,12 @@ class RemoteProgram: class RemoteDevice(Compiled): def __init__(self, device:str): - if (host:=getenv("HOST", "")) != "": self.host = host - else: - multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start() - self.host = "127.0.0.1:6667" + self.host = getenv("HOST", "") or RemoteDevice.local_server() # state for the connection - self.session = binascii.hexlify(os.urandom(0x10)).decode() - self.buffer_num, self.graph_num = 0, 0 + self.session = (binascii.hexlify(os.urandom(0x10)).decode(), int(device.split(":")[1]) if ":" in device else 0) + self.buffer_num: int = 0 + self.graph_num: int = 0 self.req: BatchRequest = BatchRequest() if DEBUG >= 1: print(f"remote with host {self.host}") @@ -254,6 +252,12 @@ class RemoteDevice(Compiled): def q(self, x:RemoteRequest): self.req.q(replace(x, session=self.session)) + @functools.cache + @staticmethod + def local_server(): + multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start() + return "127.0.0.1:6667" + def batch_submit(self): data = self.req.serialize() with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):