mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Remote multihost (#10598)
This commit is contained in:
25
.github/workflows/test.yml
vendored
25
.github/workflows/test.yml
vendored
@@ -856,9 +856,8 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
env:
|
||||
REMOTE: 1
|
||||
REMOTEDEV: 'AMD'
|
||||
HOST: 127.0.0.1:6667*6,127.0.0.1:6668*6
|
||||
PYTHONPATH: ${{ github.workspace }}
|
||||
MOCKGPU: 1
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
@@ -869,6 +868,21 @@ jobs:
|
||||
deps: testing_minimal
|
||||
amd: 'true'
|
||||
llvm: 'true'
|
||||
- name: Start remote server
|
||||
run: |
|
||||
start_server() {
|
||||
systemd-run --user \
|
||||
--unit="$1" \
|
||||
--setenv=REMOTEDEV=AMD \
|
||||
--setenv=MOCKGPU=1 \
|
||||
--setenv=PYTHONPATH=. \
|
||||
--setenv=PORT="$2" \
|
||||
--working-directory="$(pwd)" \
|
||||
python tinygrad/runtime/ops_remote.py
|
||||
}
|
||||
|
||||
start_server "remote-server-1" 6667
|
||||
start_server "remote-server-2" 6668
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
|
||||
@@ -876,7 +890,12 @@ jobs:
|
||||
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run REMOTE=1 Test
|
||||
run: |
|
||||
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
|
||||
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_remote.py test/test_tensor_variable.py
|
||||
- name: Show remote server logs
|
||||
if: always()
|
||||
run: |
|
||||
journalctl --user -u remote-server-1 --no-pager
|
||||
journalctl --user -u remote-server-2 --no-pager
|
||||
|
||||
osxtests:
|
||||
strategy:
|
||||
|
||||
65
test/test_remote.py
Normal file
65
test/test_remote.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import numpy as np, unittest, string
|
||||
from hypothesis import given, strategies as st
|
||||
from tinygrad import Device, Tensor, TinyJit
|
||||
from tinygrad.runtime.ops_remote import RemoteDevice, parse_hosts
|
||||
from tinygrad.helpers import LazySeq, all_same
|
||||
|
||||
def multihost_env(devices):
  """Return whether `devices` describes the two-host layout the multihost tests need.

  True only for a concrete list of at least 12 `(host, idx)` pairs where entries 0-5 share a host,
  entries 6-11 share a host, and the first twelve do not all share one host. Anything that is not a
  list (e.g. a lazily generated device sequence) is rejected.
  """
  def same_hosts(devs): return all_same([host for host,_ in devs])
  if not isinstance(devices, list) or len(devices) < 12: return False
  # the first twelve entries must span more than one host...
  if same_hosts(devices[0:12]): return False
  # ...while each group of six stays on a single host
  return same_hosts(devices[0:6]) and same_hosts(devices[6:12])
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "REMOTE" and multihost_env(RemoteDevice.devices), "Requires special environment")
class TestRemoteMultiHost(unittest.TestCase):
  """Cross-host tests, skipped unless the default device is REMOTE and `multihost_env` accepts the configured devices."""
  def test_mutlihost_transfer(self):
    # under the skip condition, device indices 0 and 6 belong to different hosts
    src = Tensor.arange(0, 16, device='REMOTE:0').contiguous().realize()
    dst = src.to('REMOTE:6').contiguous().realize()
    np.testing.assert_equal(dst.numpy(), np.arange(0, 16))

  # NOTE: remote graph currently throws GraphException on host mismatch, this just checks that it is being handled, not that jit graph is being used
  def test_multihost_matmul_jit(self):
    @TinyJit
    def do(a:Tensor, b:Tensor): return (a @ b).contiguous().realize()

    # two devices from each of the two host groups
    ds = ('REMOTE:0', 'REMOTE:1', 'REMOTE:6', 'REMOTE:7')
    def sharded(arr): return Tensor(arr).shard(ds, 0).contiguous().realize()

    # several iterations so the jitted function is called repeatedly with fresh inputs
    for _ in range(3):
      na = np.random.rand(128, 128).astype(np.float32)
      nb = np.random.rand(128, 128).astype(np.float32)
      a = sharded(na)
      b = sharded(nb)
      nc = na @ nb
      c = do(a, b)
      np.testing.assert_allclose(nc, c.numpy(), rtol=3e-2, atol=1e-4) # tolerances from extra/gemm/simple_matmul.py
|
||||
|
||||
class TestParseHosts(unittest.TestCase):
  """Tests for `parse_hosts`: a lone host without a count, `host*count` lists, and malformed specs."""
  def assert_seq(self, result:LazySeq, host:str):
    # a lone host with no `*count` must come back as a LazySeq giving (host, idx) for whatever index is asked
    self.assertIsInstance(result, LazySeq)
    for idx in (0, 1, 5, 10):
      self.assertEqual(result[idx], (host, idx))

  @given(st.sampled_from(["", "localhost", "192.168.1.1:8080", "host"]))
  def test_single_host_no_count(self, host:str):
    self.assert_seq(parse_hosts(host), host)

  @given(host=st.sampled_from(["localhost", "host", "192.168.1.1:8080"]), count=st.integers(0, 10))
  def test_single_host_with_count(self, host:str, count:int):
    expected = [(host, n) for n in range(count)]
    self.assertEqual(parse_hosts(f"{host}*{count}"), expected)

  def test_multiple_hosts_with_counts_simple(self):
    expected = [("host1", 0), ("host1", 1), ("host2", 0), ("host2", 1), ("host2", 2)]
    self.assertEqual(parse_hosts("host1*2,host2*3"), expected)

  @given(st.lists(st.tuples(st.text(alphabet=string.ascii_letters + string.digits + ".-:"), st.integers(1, 16)), min_size=1))
  def test_multiple_hosts_with_counts_sampled(self, host_count_pairs):
    spec = ",".join([f"{name}*{n}" for name, n in host_count_pairs])
    expected = []
    for name, n in host_count_pairs: expected.extend((name, i) for i in range(n))
    self.assertEqual(parse_hosts(spec), expected)

  @given(st.sampled_from(["host1*2,host2", "a*1,b", "x*3,y*2,z"]))
  def test_mixed_hosts_fails(self, hosts):
    # with more than one entry, every entry needs an explicit count
    self.assertRaises(AssertionError, parse_hosts, hosts)

  @given(st.sampled_from(["host*abc", "test*xyz", "a*1.5"]))
  def test_invalid_count_fails(self, hosts):
    self.assertRaises(ValueError, parse_hosts, hosts)

  @given(st.sampled_from(["host*2*3", "a*1*2*3", "test*x*y"]))
  def test_multiple_asterisks_fails(self, hosts):
    self.assertRaises(ValueError, parse_hosts, hosts)
|
||||
|
||||
# run the suite when this file is executed directly
if __name__ == "__main__":
  unittest.main()
|
||||
@@ -75,6 +75,10 @@ def get_child(obj, key):
|
||||
def word_wrap(x, wrap=80):
  """Hard-wrap `x` by inserting a newline after every `wrap` characters.

  Wrapping stops, and the remainder is appended untouched, as soon as the remaining text fits in `wrap`
  characters or already contains a newline within its next `wrap` characters.

  Raises ValueError if wrapping is required but `wrap` is not positive (that case could never terminate).
  """
  chunks, pos = [], 0
  # iterative so very long inputs cannot exhaust the recursion limit; indexed so the tail is not re-copied per line
  while len(x) - pos > wrap and '\n' not in x[pos:pos+wrap]:
    if wrap <= 0: raise ValueError(f"wrap must be positive, got {wrap}")
    chunks.append(x[pos:pos+wrap])
    pos += wrap
  chunks.append(x[pos:])
  return "\n".join(chunks)
|
||||
def pluralize(st:str, cnt:int):
  """Format `cnt` followed by `st`, appending an "s" to `st` unless `cnt` is exactly 1."""
  suffix = "s" if cnt != 1 else ""
  return f"{cnt} {st}{suffix}"
|
||||
|
||||
class LazySeq(Generic[T]): # NOTE: Mapping requires __iter__ and __len__, Sequence requires supporting __len__ and slicing in __getitem__
  """Index-only view whose item at `idx` is computed on demand as `gen(idx)`.

  Only `__getitem__` is defined: there is no `__len__`, and the index is handed to `gen` unchanged.
  Because no `__iter__` is defined, Python's fallback iteration would keep calling `gen(0)`, `gen(1)`, ...
  until `gen` raises IndexError.
  """
  def __init__(self, gen:Callable[[int], T]):
    self.gen = gen

  def __getitem__(self, idx:int) -> T:
    return self.gen(idx)
|
||||
|
||||
# for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
|
||||
def polyN(x:T, p:list[float]) -> T:
  """Evaluate the polynomial with coefficients `p` (highest power first) at `x` using Horner's scheme; 0.0 for empty `p`."""
  acc = 0.0
  for coeff in p: acc = acc*x + coeff
  return acc  # type: ignore
|
||||
|
||||
|
||||
@@ -28,7 +28,8 @@ class RemoteGraph(GraphRunner):
|
||||
self.devices[0].q(GraphAlloc(self.graph_num, tuple(_process_ji(ji) for ji in jit_cache), self.map_rawbufs(rawbufs), var_vals))
|
||||
|
||||
def __del__(self):
|
||||
self.devices[0].q(GraphFree(self.graph_num))
|
||||
# This can happen if `GraphException` is thrown at the very start
|
||||
if hasattr(self, "graph_num"): self.devices[0].q(GraphFree(self.graph_num))
|
||||
|
||||
def map_rawbufs(self, rawbufs:list[Buffer]):
|
||||
return tuple((cast(RemoteDevice, Device[rawbufs[i].device]).session, rawbufs[i]._buf) for i in self.iids)
|
||||
|
||||
@@ -12,7 +12,7 @@ import multiprocessing, functools, itertools, asyncio, http, http.client, hashli
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
from tinygrad.dtype import DTYPES_DICT, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, Variable, sint
|
||||
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
|
||||
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, LazySeq, Timing
|
||||
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class
|
||||
from tinygrad.engine.realize import CompiledRunner, BufferXfer
|
||||
from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec
|
||||
@@ -301,12 +301,22 @@ class RemoteConnection:
|
||||
self.req = BatchRequest()
|
||||
return ret
|
||||
|
||||
class RemoteDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
self.conn: RemoteConnection = RemoteConnection(getenv("HOST", "") or RemoteDevice.local_server())
|
||||
def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]:
  """Expand a comma-separated host spec into `(host, idx)` device entries.

  Each entry is `host` or `host*count`; only the first `*` separates the host from the count.
  A single entry with no count yields a LazySeq mapping any index `i` to `(host, i)`.
  Otherwise every entry must carry a count and the result is the concatenation of `(host, 0..count-1)`.

  Raises ValueError (from `int`) when a count is not an integer literal. An entry without a count in a
  multi-entry spec is rejected through `unwrap`; the tests in this change expect AssertionError there.
  """
  hosts = []
  for spec in hs.split(","):
    name, star, cnt = spec.partition("*")
    hosts.append((unwrap(name), int(cnt) if star else None))
  if len(hosts) == 1 and hosts[0][1] is None:
    sole = hosts[0][0]
    return LazySeq(lambda idx: (sole, idx))
  devs = []
  for name, cnt in hosts: devs.extend((name, i) for i in range(unwrap(cnt)))
  return devs
|
||||
|
||||
# state for the connection
|
||||
self.session = (binascii.hexlify(os.urandom(0x10)).decode(), int(device.split(":")[1]) if ":" in device else 0)
|
||||
class RemoteDevice(Compiled):
|
||||
devices = parse_hosts(getenv("HOST", ""))
|
||||
|
||||
def __init__(self, device:str):
|
||||
host, idx = RemoteDevice.devices[int(device.split(":")[1]) if ":" in device else 0]
|
||||
|
||||
# connection is shared between sessions on the same host
|
||||
self.conn: RemoteConnection = RemoteConnection(host or RemoteDevice.local_server())
|
||||
|
||||
# state for the session
|
||||
self.session = (binascii.hexlify(os.urandom(0x10)).decode(), idx)
|
||||
self.buffer_num: Iterator[int] = itertools.count(0)
|
||||
self.graph_num: Iterator[int] = itertools.count(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user