diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 68611914b7..468f0a432c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -856,9 +856,8 @@ jobs: timeout-minutes: 20 env: REMOTE: 1 - REMOTEDEV: 'AMD' + HOST: 127.0.0.1:6667*6,127.0.0.1:6668*6 PYTHONPATH: ${{ github.workspace }} - MOCKGPU: 1 steps: - name: Checkout Code uses: actions/checkout@v4 @@ -869,6 +868,21 @@ jobs: deps: testing_minimal amd: 'true' llvm: 'true' + - name: Start remote server + run: | + start_server() { + systemd-run --user \ + --unit="$1" \ + --setenv=REMOTEDEV=AMD \ + --setenv=MOCKGPU=1 \ + --setenv=PYTHONPATH=. \ + --setenv=PORT="$2" \ + --working-directory="$(pwd)" \ + python tinygrad/runtime/ops_remote.py + } + + start_server "remote-server-1" 6667 + start_server "remote-server-2" 6668 - name: Check Device.DEFAULT and print some source run: | python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT" @@ -876,7 +890,12 @@ jobs: DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus - name: Run REMOTE=1 Test run: | - python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py + python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_remote.py test/test_tensor_variable.py + - name: Show remote server logs + if: always() + run: | + journalctl --user -u remote-server-1 --no-pager + journalctl --user -u remote-server-2 --no-pager osxtests: strategy: diff --git a/test/test_remote.py b/test/test_remote.py new file mode 100644 index 0000000000..a4e5ed5307 --- /dev/null +++ b/test/test_remote.py @@ -0,0 +1,65 @@ +import numpy as np, unittest, string +from hypothesis import given, strategies as st +from tinygrad import Device, Tensor, TinyJit +from tinygrad.runtime.ops_remote import RemoteDevice, parse_hosts +from tinygrad.helpers import LazySeq, all_same + +def multihost_env(devices): + def same_hosts(devices): return all_same([h for h,_ in devices]) + return isinstance(devices, list) and len(devices) >= 12 and not same_hosts(devices[0:12]) and same_hosts(devices[0:6]) and same_hosts(devices[6:12]) + +@unittest.skipUnless(Device.DEFAULT == "REMOTE" and multihost_env(RemoteDevice.devices), "Requires special environment") +class TestRemoteMultiHost(unittest.TestCase): + def test_mutlihost_transfer(self): + a = Tensor.arange(0, 16, device='REMOTE:0').contiguous().realize() + b = a.to('REMOTE:6').contiguous().realize() + np.testing.assert_equal(b.numpy(), np.arange(0, 16)) + + # NOTE: remote graph currently throws GraphException on host mismatch, this just checks that it is being handled, not that jit graph is being used + def test_multihost_matmul_jit(self): + @TinyJit + def do(a:Tensor, b:Tensor): return (a @ b).contiguous().realize() + ds = ('REMOTE:0', 'REMOTE:1', 'REMOTE:6', 'REMOTE:7') + for _ in range(3): + na, nb = np.random.rand(128, 128).astype(np.float32), np.random.rand(128, 128).astype(np.float32) + a, b = Tensor(na).shard(ds, 0).contiguous().realize(), Tensor(nb).shard(ds, 0).contiguous().realize() + nc = na @ nb + c = do(a, b) + np.testing.assert_allclose(nc, c.numpy(), rtol=3e-2, atol=1e-4) # tolerances from extra/gemm/simple_matmul.py + +class TestParseHosts(unittest.TestCase): + def assert_seq(self, result:LazySeq, host:str): + self.assertIsInstance(result, LazySeq) + for i in [0, 1, 5, 10]: self.assertEqual(result[i], (host, i)) + + @given(st.sampled_from(["", "localhost", "192.168.1.1:8080", "host"])) + def test_single_host_no_count(self, host:str): + self.assert_seq(parse_hosts(host), host) + + @given(host=st.sampled_from(["localhost", "host", "192.168.1.1:8080"]), count=st.integers(0, 10)) + def test_single_host_with_count(self, host:str, count:int): + self.assertEqual(parse_hosts(f"{host}*{count}"), [(host, i) for i in range(count)]) + + def test_multiple_hosts_with_counts_simple(self): + self.assertEqual(parse_hosts("host1*2,host2*3"), [("host1", i) for i in range(2)] + [("host2", i) for i in range(3)]) + + @given(st.lists(st.tuples(st.text(alphabet=string.ascii_letters + string.digits + ".-:"), st.integers(1, 16)), min_size=1)) + def test_multiple_hosts_with_counts_sampled(self, host_count_pairs): + hosts_str = ",".join(f"{host}*{count}" for host, count in host_count_pairs) + expected = [(host, i) for host, count in host_count_pairs for i in range(count)] + self.assertEqual(parse_hosts(hosts_str), expected) + + @given(st.sampled_from(["host1*2,host2", "a*1,b", "x*3,y*2,z"])) + def test_mixed_hosts_fails(self, hosts): + with self.assertRaises(AssertionError): parse_hosts(hosts) + + @given(st.sampled_from(["host*abc", "test*xyz", "a*1.5"])) + def test_invalid_count_fails(self, hosts): + with self.assertRaises(ValueError): parse_hosts(hosts) + + @given(st.sampled_from(["host*2*3", "a*1*2*3", "test*x*y"])) + def test_multiple_asterisks_fails(self, hosts): + with self.assertRaises(ValueError): parse_hosts(hosts) + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 1c2f7ff3f8..57743e76eb 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -75,6 +75,10 @@ def get_child(obj, key): def word_wrap(x, wrap=80): return x if len(x) <= wrap or '\n' in x[0:wrap] else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap)) def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's') +class LazySeq(Generic[T]): # NOTE: Mapping requires __iter__ and __len__, Sequence requires supporting __len__ and slicing in __getitem__ + def __init__(self, gen:Callable[[int], T]): self.gen = gen + def __getitem__(self, idx:int) -> T: return self.gen(idx) + # for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1] def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore diff --git a/tinygrad/runtime/graph/remote.py b/tinygrad/runtime/graph/remote.py index 04c6dbdbc9..f086dc1ffc 100644 --- a/tinygrad/runtime/graph/remote.py +++ b/tinygrad/runtime/graph/remote.py @@ -28,7 +28,8 @@ class RemoteGraph(GraphRunner): self.devices[0].q(GraphAlloc(self.graph_num, tuple(_process_ji(ji) for ji in jit_cache), self.map_rawbufs(rawbufs), var_vals)) def __del__(self): - self.devices[0].q(GraphFree(self.graph_num)) + # This can happen if `GraphException` is thrown at the very start + if hasattr(self, "graph_num"): self.devices[0].q(GraphFree(self.graph_num)) def map_rawbufs(self, rawbufs:list[Buffer]): return tuple((cast(RemoteDevice, Device[rawbufs[i].device]).session, rawbufs[i]._buf) for i in self.iids) diff --git a/tinygrad/runtime/ops_remote.py b/tinygrad/runtime/ops_remote.py index 991da324c6..18d93672ab 100644 --- a/tinygrad/runtime/ops_remote.py +++ b/tinygrad/runtime/ops_remote.py @@ -12,7 +12,7 @@ import multiprocessing, functools, itertools, asyncio, http, http.client, hashli from tinygrad.renderer import Renderer, ProgramSpec from tinygrad.dtype import DTYPES_DICT, dtypes from tinygrad.uop.ops import UOp, Ops, Variable, sint -from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing +from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, LazySeq, Timing from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class from tinygrad.engine.realize import CompiledRunner, BufferXfer from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec @@ -301,12 +301,22 @@ class RemoteConnection: self.req = BatchRequest() return ret -class RemoteDevice(Compiled): - def __init__(self, device:str): - self.conn: RemoteConnection = RemoteConnection(getenv("HOST", "") or RemoteDevice.local_server()) +def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]: + hosts = [(unwrap(h), int(c) if c is not None else c) for h,c in ((h.split("*", maxsplit=1)+[None,])[:2] for h in hs.split(","))] + if len(hosts) == 1 and hosts[0][1] is None: return LazySeq(lambda idx: (hosts[0][0], idx)) + return [(h, i) for h,c in hosts for i in range(unwrap(c))] - # state for the connection - self.session = (binascii.hexlify(os.urandom(0x10)).decode(), int(device.split(":")[1]) if ":" in device else 0) +class RemoteDevice(Compiled): + devices = parse_hosts(getenv("HOST", "")) + + def __init__(self, device:str): + host, idx = RemoteDevice.devices[int(device.split(":")[1]) if ":" in device else 0] + + # connection is shared between sessions on the same host + self.conn: RemoteConnection = RemoteConnection(host or RemoteDevice.local_server()) + + # state for the session + self.session = (binascii.hexlify(os.urandom(0x10)).decode(), idx) self.buffer_num: Iterator[int] = itertools.count(0) self.graph_num: Iterator[int] = itertools.count(0)