Remote multihost (#10598)

This commit is contained in:
uuuvn
2025-06-16 20:18:56 +00:00
committed by GitHub
parent 0629e45332
commit 18d936f981
5 changed files with 109 additions and 10 deletions

View File

@@ -856,9 +856,8 @@ jobs:
timeout-minutes: 20
env:
REMOTE: 1
REMOTEDEV: 'AMD'
HOST: 127.0.0.1:6667*6,127.0.0.1:6668*6
PYTHONPATH: ${{ github.workspace }}
MOCKGPU: 1
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -869,6 +868,21 @@ jobs:
deps: testing_minimal
amd: 'true'
llvm: 'true'
- name: Start remote server
run: |
start_server() {
systemd-run --user \
--unit="$1" \
--setenv=REMOTEDEV=AMD \
--setenv=MOCKGPU=1 \
--setenv=PYTHONPATH=. \
--setenv=PORT="$2" \
--working-directory="$(pwd)" \
python tinygrad/runtime/ops_remote.py
}
start_server "remote-server-1" 6667
start_server "remote-server-2" 6668
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
@@ -876,7 +890,12 @@ jobs:
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
- name: Run REMOTE=1 Test
run: |
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_remote.py test/test_tensor_variable.py
- name: Show remote server logs
if: always()
run: |
journalctl --user -u remote-server-1 --no-pager
journalctl --user -u remote-server-2 --no-pager
osxtests:
strategy:

65
test/test_remote.py Normal file
View File

@@ -0,0 +1,65 @@
import numpy as np, unittest, string
from hypothesis import given, strategies as st
from tinygrad import Device, Tensor, TinyJit
from tinygrad.runtime.ops_remote import RemoteDevice, parse_hosts
from tinygrad.helpers import LazySeq, all_same
def multihost_env(devices):
def same_hosts(devices): return all_same([h for h,_ in devices])
return isinstance(devices, list) and len(devices) >= 12 and not same_hosts(devices[0:12]) and same_hosts(devices[0:6]) and same_hosts(devices[6:12])
@unittest.skipUnless(Device.DEFAULT == "REMOTE" and multihost_env(RemoteDevice.devices), "Requires special environment")
class TestRemoteMultiHost(unittest.TestCase):
def test_mutlihost_transfer(self):
a = Tensor.arange(0, 16, device='REMOTE:0').contiguous().realize()
b = a.to('REMOTE:6').contiguous().realize()
np.testing.assert_equal(b.numpy(), np.arange(0, 16))
# NOTE: remote graph currently throws GraphException on host mismatch, this just checks that it is being handled, not that jit graph is being used
def test_multihost_matmul_jit(self):
@TinyJit
def do(a:Tensor, b:Tensor): return (a @ b).contiguous().realize()
ds = ('REMOTE:0', 'REMOTE:1', 'REMOTE:6', 'REMOTE:7')
for _ in range(3):
na, nb = np.random.rand(128, 128).astype(np.float32), np.random.rand(128, 128).astype(np.float32)
a, b = Tensor(na).shard(ds, 0).contiguous().realize(), Tensor(nb).shard(ds, 0).contiguous().realize()
nc = na @ nb
c = do(a, b)
np.testing.assert_allclose(nc, c.numpy(), rtol=3e-2, atol=1e-4) # tolerances from extra/gemm/simple_matmul.py
class TestParseHosts(unittest.TestCase):
def assert_seq(self, result:LazySeq, host:str):
self.assertIsInstance(result, LazySeq)
for i in [0, 1, 5, 10]: self.assertEqual(result[i], (host, i))
@given(st.sampled_from(["", "localhost", "192.168.1.1:8080", "host"]))
def test_single_host_no_count(self, host:str):
self.assert_seq(parse_hosts(host), host)
@given(host=st.sampled_from(["localhost", "host", "192.168.1.1:8080"]), count=st.integers(0, 10))
def test_single_host_with_count(self, host:str, count:int):
self.assertEqual(parse_hosts(f"{host}*{count}"), [(host, i) for i in range(count)])
def test_multiple_hosts_with_counts_simple(self):
self.assertEqual(parse_hosts("host1*2,host2*3"), [("host1", i) for i in range(2)] + [("host2", i) for i in range(3)])
@given(st.lists(st.tuples(st.text(alphabet=string.ascii_letters + string.digits + ".-:"), st.integers(1, 16)), min_size=1))
def test_multiple_hosts_with_counts_sampled(self, host_count_pairs):
hosts_str = ",".join(f"{host}*{count}" for host, count in host_count_pairs)
expected = [(host, i) for host, count in host_count_pairs for i in range(count)]
self.assertEqual(parse_hosts(hosts_str), expected)
@given(st.sampled_from(["host1*2,host2", "a*1,b", "x*3,y*2,z"]))
def test_mixed_hosts_fails(self, hosts):
with self.assertRaises(AssertionError): parse_hosts(hosts)
@given(st.sampled_from(["host*abc", "test*xyz", "a*1.5"]))
def test_invalid_count_fails(self, hosts):
with self.assertRaises(ValueError): parse_hosts(hosts)
@given(st.sampled_from(["host*2*3", "a*1*2*3", "test*x*y"]))
def test_multiple_asterisks_fails(self, hosts):
with self.assertRaises(ValueError): parse_hosts(hosts)
if __name__ == '__main__':
unittest.main()

View File

@@ -75,6 +75,10 @@ def get_child(obj, key):
def word_wrap(x, wrap=80): return x if len(x) <= wrap or '\n' in x[0:wrap] else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's')
class LazySeq(Generic[T]): # NOTE: Mapping requires __iter__ and __len__, Sequence requires supporting __len__ and slicing in __getitem__
def __init__(self, gen:Callable[[int], T]): self.gen = gen
def __getitem__(self, idx:int) -> T: return self.gen(idx)
# for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore

View File

@@ -28,7 +28,8 @@ class RemoteGraph(GraphRunner):
self.devices[0].q(GraphAlloc(self.graph_num, tuple(_process_ji(ji) for ji in jit_cache), self.map_rawbufs(rawbufs), var_vals))
def __del__(self):
self.devices[0].q(GraphFree(self.graph_num))
# This can happen if `GraphException` is thrown at the very start
if hasattr(self, "graph_num"): self.devices[0].q(GraphFree(self.graph_num))
def map_rawbufs(self, rawbufs:list[Buffer]):
return tuple((cast(RemoteDevice, Device[rawbufs[i].device]).session, rawbufs[i]._buf) for i in self.iids)

View File

@@ -12,7 +12,7 @@ import multiprocessing, functools, itertools, asyncio, http, http.client, hashli
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.dtype import DTYPES_DICT, dtypes
from tinygrad.uop.ops import UOp, Ops, Variable, sint
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, LazySeq, Timing
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class
from tinygrad.engine.realize import CompiledRunner, BufferXfer
from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec
@@ -301,12 +301,22 @@ class RemoteConnection:
self.req = BatchRequest()
return ret
class RemoteDevice(Compiled):
def __init__(self, device:str):
self.conn: RemoteConnection = RemoteConnection(getenv("HOST", "") or RemoteDevice.local_server())
def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]:
hosts = [(unwrap(h), int(c) if c is not None else c) for h,c in ((h.split("*", maxsplit=1)+[None,])[:2] for h in hs.split(","))]
if len(hosts) == 1 and hosts[0][1] is None: return LazySeq(lambda idx: (hosts[0][0], idx))
return [(h, i) for h,c in hosts for i in range(unwrap(c))]
# state for the connection
self.session = (binascii.hexlify(os.urandom(0x10)).decode(), int(device.split(":")[1]) if ":" in device else 0)
class RemoteDevice(Compiled):
devices = parse_hosts(getenv("HOST", ""))
def __init__(self, device:str):
host, idx = RemoteDevice.devices[int(device.split(":")[1]) if ":" in device else 0]
# connection is shared between sessions on the same host
self.conn: RemoteConnection = RemoteConnection(host or RemoteDevice.local_server())
# state for the session
self.session = (binascii.hexlify(os.urandom(0x10)).decode(), idx)
self.buffer_num: Iterator[int] = itertools.count(0)
self.graph_num: Iterator[int] = itertools.count(0)