diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 5ded2e60af..4d00ea1190 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -554,8 +554,6 @@ jobs: run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt - name: Run full CIFAR training steps w 6 GPUS run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt - - name: Run full CIFAR training steps w 6 GPUS (REMOTE) - run: time BENCHMARK_LOG=cifar_6gpu_remote REMOTE=1 REMOTEDEV=AMD DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu_remote.txt - uses: actions/upload-artifact@v4 with: name: Speed (AMD Training) @@ -567,7 +565,6 @@ jobs: train_cifar_wino.txt train_cifar_one_gpu.txt train_cifar_six_gpu.txt - train_cifar_six_gpu_remote.txt - name: Run process replay tests run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 27d97fd484..8be43cae8b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -721,71 +721,6 @@ jobs: - name: Run process replay tests uses: ./.github/actions/process-replay - amdremote: - name: Linux (remote) - runs-on: ubuntu-22.04 - timeout-minutes: 20 - env: - REMOTE: 1 - steps: - - name: Checkout Code - uses: actions/checkout@v4 - - name: Setup Environment - uses: ./.github/actions/setup-tinygrad - with: - key: linux-remote - deps: testing_minimal - amd: 'true' - llvm: 'true' - opencl: 'true' - - name: Start remote server - run: | - start_server() { - systemd-run --user \ - --unit="$1" \ - --setenv=REMOTEDEV="$2" \ - --setenv=MOCKGPU=1 \ - --setenv=PYTHONPATH=. \ - --setenv=PORT="$3" \ - --working-directory="$(pwd)" \ - python tinygrad/runtime/ops_remote.py - } - - start_server "remote-server-amd-1" "AMD" 6667 - start_server "remote-server-amd-2" "AMD" 6668 - start_server "remote-server-gpu" "CL" 7667 - start_server "remote-server-cpu" "CPU" 8667 - - name: Check Device.DEFAULT and print some source - env: - HOST: 127.0.0.1:6667*6,127.0.0.1:6668*6 - run: | - python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT" - python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'AMD', Device.default.properties.real_device" - DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus - - name: Run REMOTE=1 Test (AMD) - env: - HOST: 127.0.0.1:6667*6,127.0.0.1:6668*6 - run: | - python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_remote.py test/test_tensor_variable.py --durations 20 - - name: Run REMOTE=1 Test (CL) - env: - HOST: 127.0.0.1:7667*6 - run: | - python3 -m pytest test/test_tiny.py test/test_image_dtype.py test/test_jit.py --durations 20 - IMAGE=2 python3 -m pytest test/test_tiny.py test/test_image_dtype.py - - name: Run REMOTE=1 Test (CPU) - env: - HOST: 127.0.0.1:8667*6 - run: | - python3 -m pytest test/test_tiny.py test/test_jit.py test/test_multitensor.py --durations 20 - - name: Show remote server logs - if: always() - run: | - journalctl --user -u remote-server-amd-1 --no-pager - journalctl --user -u remote-server-amd-2 --no-pager - journalctl --user -u remote-server-gpu --no-pager - journalctl --user -u remote-server-cpu --no-pager - # ****** OSX Tests ****** testmetal: @@ -883,30 +818,6 @@ jobs: - name: Test ONNX Runner (WEBGPU) run: WEBGPU=1 python3 test/external/external_test_onnx_runner.py - osxremote: - name: MacOS (remote metal) - runs-on: macos-15 - timeout-minutes: 10 - env: - REMOTE: 1 - REMOTEDEV: METAL - steps: - - name: Checkout Code - uses: actions/checkout@v4 - - name: Setup Environment - uses: ./.github/actions/setup-tinygrad - with: - key: macos-remote - deps: testing_minimal - - name: Check Device.DEFAULT and print some source - run: | - python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT" - python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device" - DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus - - name: Run REMOTE=1 Test - run: | - python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py - osxtests: strategy: fail-fast: false diff --git a/test/external/external_test_example.py b/test/external/external_test_example.py index 1114a5917d..73c74a10e2 100644 --- a/test/external/external_test_example.py +++ b/test/external/external_test_example.py @@ -8,7 +8,7 @@ def multidevice_test(fxn): def ret(self): for device in Device._devices: # broken on OSX USB AMD, why? - if device in ["REMOTE", "DISK", "NPY", "FAKE", "DSP", "NULL"] or (OSX and device in ["AMD"]): continue + if device in ["DISK", "NPY", "FAKE", "DSP", "NULL"] or (OSX and device in ["AMD"]): continue if not CI: print(device) if device in exclude_devices: if not CI: print(f"WARNING: {device} test is excluded") diff --git a/test/helpers.py b/test/helpers.py index 03dd567a6b..0e9661bf41 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -69,5 +69,4 @@ def needs_second_gpu(fn): return fn(self, *args, **kwargs) return wrapper -# NOTE: This will open REMOTE if it's the default device -REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device) +REAL_DEV = Device.DEFAULT diff --git a/test/test_remote.py b/test/test_remote.py deleted file mode 100644 index 66533ea851..0000000000 --- a/test/test_remote.py +++ /dev/null @@ -1,101 +0,0 @@ -import numpy as np, unittest, string -from hypothesis import given, strategies as st -from tinygrad import Device, Tensor, TinyJit, dtypes -from tinygrad.runtime.ops_remote import RemoteDevice, parse_hosts -from tinygrad.runtime.graph.remote import RemoteGraph -from tinygrad.helpers import LazySeq, all_same, Context - -def multihost_env(devices): - def same_hosts(devices): return all_same([h for h,_ in devices]) - return isinstance(devices, list) and len(devices) >= 12 and not same_hosts(devices[0:12]) and same_hosts(devices[0:6]) and same_hosts(devices[6:12]) - -@unittest.skipUnless(Device.DEFAULT == "REMOTE" and multihost_env(RemoteDevice.devices), "Requires special environment") -class TestRemoteMultiHost(unittest.TestCase): - def test_mutlihost_transfer(self): - a = Tensor.arange(0, 16, device='REMOTE:0').contiguous().realize() - b = a.to('REMOTE:6').contiguous().realize() - np.testing.assert_equal(b.numpy(), np.arange(0, 16)) - - @Context(JIT_BATCH_SIZE=2**32) - @unittest.skip("kernel must all be multibuffer") - def test_multihost_matmul_jit_graph(self): - @TinyJit - def do(a:Tensor, b:Tensor): return (a @ b).contiguous().realize() - - ds = ('REMOTE:0', 'REMOTE:1', 'REMOTE:6', 'REMOTE:7') - for _ in range(3): - na, nb = np.random.rand(128, 128).astype(np.float32), np.random.rand(128, 128).astype(np.float32) - a, b = Tensor(na).shard(ds, 0).contiguous().realize(), Tensor(nb).shard(ds, 0).contiguous().realize() - nc = na @ nb - c = do(a, b) - np.testing.assert_allclose(nc, c.numpy(), rtol=3e-2, atol=1e-4) # tolerances from extra/gemm/simple_matmul.py - - # Verify that everything is in one big cross-host graph - assert len(do.captured._jit_cache) == 1 and isinstance(do.captured._jit_cache[0].prg, RemoteGraph), repr(do.captured) - - @Context(JIT_BATCH_SIZE=2**32) - @unittest.skip("assign target and input devices mismatch") - def test_multihost_aware_schedule(self): - @TinyJit - def do(*ts:Tensor): - acc = Tensor.zeros(1, dtype=dtypes.float32).contiguous().realize() - for t in ts: acc += t.sum() - return acc.realize() - - def do_np(*ts:np.ndarray): - acc = np.zeros(1, np.float32) - for t in ts: acc += t.sum() - return acc - - ds = ('REMOTE:0', 'REMOTE:1', 'REMOTE:6', 'REMOTE:7') - TS = 64 - for _ in range(3): - inp_np = [np.random.rand(256).astype(np.float32) for _ in range(TS)] - inp = [Tensor(inp).shard(ds, 0).contiguous().realize() for inp in inp_np] - out_np = do_np(*inp_np) - out = do(*inp) - np.testing.assert_allclose(out_np, out.numpy(), rtol=3e-2, atol=1e-4) - - # Verify that everything is in one big cross-host graph and that the scheduling is reasonable - assert len(do.captured._jit_cache) == 1 and isinstance(do.captured._jit_cache[0].prg, RemoteGraph), repr(do.captured) - # At the time of writing this: 2050 graph breaks without multihost aware scheduling, 14 with it. I've set fail threshold to 28 to not fail on - # unrelated scheduling changes. Maybe 2x is a bit too pessimistic, but remote should perform just fine as long as this is not like a half hundred - # or more here. - self.assertLess(len(do.captured._jit_cache[0].prg.template), 28, "Very bad scheduling! Many unnecesary graph breaks!") - -class TestParseHosts(unittest.TestCase): - def assert_seq(self, result:LazySeq, host:str): - self.assertIsInstance(result, LazySeq) - for i in [0, 1, 5, 10]: self.assertEqual(result[i], (host, i)) - - @given(st.sampled_from(["", "localhost", "192.168.1.1:8080", "host"])) - def test_single_host_no_count(self, host:str): - self.assert_seq(parse_hosts(host), host) - - @given(host=st.sampled_from(["localhost", "host", "192.168.1.1:8080"]), count=st.integers(0, 10)) - def test_single_host_with_count(self, host:str, count:int): - self.assertEqual(parse_hosts(f"{host}*{count}"), [(host, i) for i in range(count)]) - - def test_multiple_hosts_with_counts_simple(self): - self.assertEqual(parse_hosts("host1*2,host2*3"), [("host1", i) for i in range(2)] + [("host2", i) for i in range(3)]) - - @given(st.lists(st.tuples(st.text(alphabet=string.ascii_letters + string.digits + ".-:"), st.integers(1, 16)), min_size=1)) - def test_multiple_hosts_with_counts_sampled(self, host_count_pairs): - hosts_str = ",".join(f"{host}*{count}" for host, count in host_count_pairs) - expected = [(host, i) for host, count in host_count_pairs for i in range(count)] - self.assertEqual(parse_hosts(hosts_str), expected) - - @given(st.sampled_from(["host1*2,host2", "a*1,b", "x*3,y*2,z"])) - def test_mixed_hosts_fails(self, hosts): - with self.assertRaises(AssertionError): parse_hosts(hosts) - - @given(st.sampled_from(["host*abc", "test*xyz", "a*1.5"])) - def test_invalid_count_fails(self, hosts): - with self.assertRaises(ValueError): parse_hosts(hosts) - - @given(st.sampled_from(["host*2*3", "a*1*2*3", "test*x*y"])) - def test_multiple_asterisks_fails(self, hosts): - with self.assertRaises(ValueError): parse_hosts(hosts) - -if __name__ == '__main__': - unittest.main() diff --git a/tinygrad/runtime/graph/remote.py b/tinygrad/runtime/graph/remote.py deleted file mode 100644 index 8a9c5516b7..0000000000 --- a/tinygrad/runtime/graph/remote.py +++ /dev/null @@ -1,113 +0,0 @@ -import time, itertools -from tinygrad.engine.jit import MultiGraphRunner -from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem -from tinygrad.device import Device, Compiled, Buffer -from tinygrad.runtime.ops_remote import RemoteDevice, RemoteConnection, RemoteRequest, GraphComputeItem, Transfer, GraphAlloc, GraphFree, GraphExec -from tinygrad.runtime.ops_remote import BatchTransfer, Event, Wait -from tinygrad.helpers import unwrap, flatten, dedup -from enum import Enum, auto -from dataclasses import replace -from collections import defaultdict -from typing import cast - -class StagingType(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto() # noqa: E702 - -def rd(dev:Compiled) -> RemoteDevice: return cast(RemoteDevice, dev) -def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_multi else dev -def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf) - -class RemoteGraph(MultiGraphRunner): - def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]): - super().__init__(jit_cache, rawbufs, var_vals) - devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache])) - c2d = {device.conn: device for device in devices} - self.handle_indexes = {map_rawbuf(rawbufs[i]): i for i in sorted(dedup(self.input_replace.values()))} - - self.template: list[RemoteRequest] = [] - - stagings: dict[RemoteDevice|RemoteConnection, list[GraphComputeItem|Transfer]] = defaultdict(list) - clobbered_buffers: set[Buffer] = set() - cur_staging_type: StagingType = StagingType.NONE - - def _flush(new_staging_type:StagingType, force_break:bool=False): - nonlocal cur_staging_type - if cur_staging_type == new_staging_type and not force_break: return - # Pre-sync - if cur_staging_type == StagingType.TRANSFER: - for sdev,ddev in itertools.permutations(c2d.values(), 2): - self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session)) - self.template.append(Wait(event, session=ddev.session)) - # Flush - for dev in devices: - dk = dev_key(dev) - staging = stagings[dk] - if not staging: continue - match cur_staging_type: - case StagingType.GRAPH: - bufs = tuple(map_rawbuf(rawbufs[i]) for i in sorted(dedup(self.input_replace.values())) if dev_key(rd(Device[rawbufs[i].device])) == dk) - dev.q(GraphAlloc(graph_num:=next(dev.graph_num), tuple(staging), tuple(bufs), var_vals)) - self.template.append(GraphExec(graph_num, bufs, var_vals, wait=False, session=dev.session)) - case StagingType.TRANSFER: - st = cast(list[Transfer], staging) - for host in dedup(t.dsession.host for t in st): - sbuffer_nums = [(unwrap(t.session), t.buffer_num) for t in st if t.dsession.host == host] - dbuffer_nums = [(t.dsession, t.dbuffer_num) for t in st if t.dsession.host == host] - self.template.append(BatchTransfer(sbuffer_nums, dbuffer_nums, session=dev.session)) - staging.clear() - # Post-sync - if cur_staging_type == StagingType.TRANSFER: - for sdev,ddev in itertools.permutations(c2d.values(), 2): - self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session)) - self.template.append(Wait(event, session=ddev.session)) - cur_staging_type = new_staging_type - clobbered_buffers.clear() - - for ji in jit_cache: - match ji.prg: - case CompiledRunner(): - _flush(StagingType.GRAPH) - gi = GraphComputeItem(ji.prg.dev.session, ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs), - tuple(ji.prg.p.vars), ji.fixedvars, tuple(ji.prg.p.ins), tuple(ji.prg.p.outs), - tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None, - tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None) - stagings[dev_key(ji.prg.dev)].append(gi) - case BufferXfer(): - dest, src = ji.bufs[0:2] - dest_dev, src_dev = cast(RemoteDevice, Device[unwrap(dest).device]), cast(RemoteDevice, Device[unwrap(src).device]) - assert dest is not None and src is not None, ji - ti = Transfer(session=src_dev.session, buffer_num=src._buf, dsession=dest_dev.session, dbuffer_num=dest._buf) - if dev_key(dest_dev) == dev_key(src_dev): - _flush(StagingType.GRAPH) - stagings[dev_key(src_dev)].append(ti) - elif dest_dev.conn == src_dev.conn: - _flush(StagingType.NONE) - self.template.append(ti) - else: - _flush(StagingType.TRANSFER, force_break=src in clobbered_buffers) - clobbered_buffers.add(dest) - stagings[dev_key(src_dev)].append(ti) - case _: raise NotImplementedError(ji.prg) - _flush(StagingType.NONE) - def __del__(self): - for req in self.template: - match req: - case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session)) - def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False): - if wait: st = time.perf_counter() - rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()} - for req in self.template: - match req: - case GraphExec(): - req = replace(req, bufs=tuple(rmap[buf] for buf in req.bufs), var_vals=var_vals, wait=wait) - case Transfer(): - if (req.session, req.buffer_num) in rmap: req = replace(req, buffer_num=rmap[(req.session, req.buffer_num)][1]) - if (req.dsession, req.dbuffer_num) in rmap: req = replace(req, dbuffer_num=rmap[(req.dsession, req.dbuffer_num)][1]) - case BatchTransfer(): - req = replace(req, sbuffer_nums=[rmap.get(b, b) for b in req.sbuffer_nums], dbuffer_nums=[rmap.get(b, b) for b in req.dbuffer_nums]) - case Event()|Wait(): - pass # event number can be reused - case _: raise NotImplementedError(req) - RemoteConnection(unwrap(req.session).host).q(req) - if wait: - RemoteConnection(unwrap(req.session).host).batch_submit() - return time.perf_counter() - st diff --git a/tinygrad/runtime/ops_remote.py b/tinygrad/runtime/ops_remote.py deleted file mode 100644 index 70eb0857ad..0000000000 --- a/tinygrad/runtime/ops_remote.py +++ /dev/null @@ -1,491 +0,0 @@ -# the REMOTE=1 device is a process boundary between the frontend/runtime -# normally tinygrad is frontend <-> middleware <-> runtime <-> hardware -# with REMOTE tinygrad is frontend <-> middleware <-> RemoteDevice ///HTTP/// remote_server <-> runtime <-> hardware -# this client and server can be on the same machine, same network, or just same internet -# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC - -from __future__ import annotations -from typing import Callable, Iterator, Any, cast -from collections import defaultdict -from dataclasses import dataclass, field, replace -import multiprocessing, threading, functools, itertools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref -import traceback, builtins -from tinygrad.renderer import Renderer, ProgramSpec -from tinygrad.dtype import DTYPES_DICT, dtypes -from tinygrad.uop.ops import UOp, Ops, Variable, sint -from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, LazySeq, Timing -from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class -from tinygrad.engine.realize import CompiledRunner, BufferXfer -from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec, CompilerSet, CompilerPair -from tinygrad.runtime.support.ib import IBCtx, IBConn, SGE - -# ***** API ***** - -@dataclass(frozen=True) -class SessionKey: host: str; idx: int; nonce: str # noqa: E702 - -@dataclass(frozen=True) -class RemoteRequest: session: SessionKey|None = field(default=None, kw_only=True) - -@dataclass(frozen=True) -class SessionFree(RemoteRequest): pass - -@dataclass(frozen=True) -class RemoteProperties: - real_device: str - renderer: tuple[str, str, tuple[Any, ...]] - offset_supported: bool - graph_supported: bool - graph_supports_multi: bool - ib_gid: bytes|None - -@dataclass(frozen=True) -class RemoteException: - exc: Exception - trace: str = "" - -@dataclass(frozen=True) -class GetProperties(RemoteRequest): pass - -@dataclass(frozen=True) -class Event(RemoteRequest): event_session: SessionKey; event: int # noqa: E702 - -@dataclass(frozen=True) -class Wait(RemoteRequest): event: int - -@dataclass(frozen=True) -class IBConnect(RemoteRequest): host: str; gid: bytes; qp_num: int # noqa: E702 - -@dataclass(frozen=True) -class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702 - -@dataclass(frozen=True) -class BufferOffset(RemoteRequest): buffer_num: int; size: int; offset: int; sbuffer_num: int # noqa: E702 - -@dataclass(frozen=True) -class BufferIOVAS(RemoteRequest): buffer_nums: list[tuple[SessionKey, int]] # noqa: E702 - -@dataclass(frozen=True) -class BufferFree(RemoteRequest): buffer_num: int # noqa: E702 - -@dataclass(frozen=True) -class CopyIn(RemoteRequest): buffer_num: int; datahash: str # noqa: E702 - -@dataclass(frozen=True) -class CopyOut(RemoteRequest): buffer_num: int - -@dataclass(frozen=True) -class Transfer(RemoteRequest): buffer_num: int; dsession: SessionKey; dbuffer_num: int # noqa: E702 - -@dataclass(frozen=True) -class BatchTransfer(RemoteRequest): - sbuffer_nums: list[tuple[SessionKey, int]] - dbuffer_nums: list[tuple[SessionKey, int]] - -@dataclass(frozen=True) -class ProgramAlloc(RemoteRequest): name: str; datahash: str # noqa: E702 - -@dataclass(frozen=True) -class ProgramFree(RemoteRequest): name: str; datahash: str # noqa: E702 - -@dataclass(frozen=True) -class ProgramExec(RemoteRequest): - name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702 - global_size: tuple[int, ...]|None; local_size: tuple[int, ...]|None; wait: bool # noqa: E702 - -@dataclass(frozen=True) -class GraphComputeItem: - session: SessionKey - name: str - datahash: str - bufs: tuple[int, ...] - vars: tuple[Variable, ...] - fixedvars: dict[str, int] - ins: tuple[int, ...] - outs: tuple[int, ...] - global_size: tuple[sint, ...]|None - local_size: tuple[sint, ...]|None - -@dataclass(frozen=True) -class GraphAlloc(RemoteRequest): - graph_num: int - jit_cache: tuple[GraphComputeItem|Transfer, ...] - bufs: tuple[tuple[SessionKey, int], ...] - var_vals: dict[str, int] - -@dataclass(frozen=True) -class GraphFree(RemoteRequest): - graph_num: int - -@dataclass(frozen=True) -class GraphExec(RemoteRequest): - graph_num: int - bufs: tuple[tuple[SessionKey, int], ...] - var_vals: dict[str, int] - wait: bool - -# for safe deserialization -eval_excs = [v for k,v in builtins.__dict__.items() if isinstance(v, type) and issubclass(v, Exception) and not k.endswith("Warning")] -eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, Event, Wait, BufferAlloc, BufferOffset, BufferIOVAS, - BufferFree, CopyIn, CopyOut, Transfer, BatchTransfer, IBConnect, ProgramAlloc, ProgramFree, ProgramExec, - GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes, RemoteException] + eval_excs} -attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}} -eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)), - ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)}, - ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}), - ast.Name: lambda x: eval_globals[x.id], ast.Attribute: lambda x: safe_getattr(safe_eval(x.value), x.attr)} -def safe_getattr(value, attr): - assert attr in attribute_whitelist.get(value, set()), f'getattr({value}, {repr(attr)}) is not whitelisted' - return getattr(value, attr) -def safe_eval(node): return eval_fxns[node.__class__](node) - -class BatchRequest: - def __init__(self): - self._q: list[RemoteRequest] = [] - self._h: dict[str, bytes] = {} - def h(self, d:bytes|memoryview) -> str: - datahash = hashlib.sha256(d).hexdigest() # NOTE: this is very slow, should use blake3 on gpu instead - if datahash not in self._h: - self._h[datahash] = bytes.fromhex(datahash)+struct.pack(" bytes: - self.h(repr(self._q).encode()) - return b''.join(self._h.values()) - def deserialize(self, dat:bytes) -> BatchRequest: - ptr = 0 - while ptr < len(dat): - datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack(" IBConn|None: - if self.ib_ctx is None: return None - await self.ib_lock.acquire() - conn = RemoteConnection(dsession.host) - if dsession.host not in self.ib_conns: - props = safe_eval(ast.parse(conn.q(GetProperties(session=dsession), wait=True), mode="eval").body) - if props.ib_gid is not None: - self.ib_conns[dsession.host] = ib_conn = IBConn(self.ib_ctx) - ibxc_ret = conn.q(IBConnect(ssession.host, ib_conn.gid, ib_conn.qp_num, session=dsession), wait=True) - ib_conn.connect(*struct.unpack('<16sQ', ibxc_ret)) - else: - self.ib_conns[dsession.host] = None - self.ib_lock.release() - return self.ib_conns[dsession.host] - - async def get_iovas(self, bufs:list[tuple[SessionKey, int]]) -> list[tuple[int, int, int]]: - await self.ib_lock.acquire() - if (rbufs:=[buf for buf in bufs if buf not in self.iova_cache]): - conn = RemoteConnection(rbufs[0][0].host) - resp = await conn.aq(BufferIOVAS(rbufs, session=rbufs[0][0]), wait=True) - self.iova_cache.update({rbuf: struct.unpack(' tuple[http.HTTPStatus, bytes]: - status, ret = http.HTTPStatus.OK, b"" - if path == "/batch" and method == "POST": - # TODO: streaming deserialize? - req = BatchRequest().deserialize(body) - # the cmds are always last (currently in datahash) - for c in req._q: - if DEBUG >= 1: print(c) - session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session).idx}"] - match c: - case SessionFree(): del self.sessions[unwrap(c.session)] - case GetProperties(): - cls, args = dev.renderer.__reduce__() - graph_cls = graph_class(Device[self.base_device]) - rp = RemoteProperties( - real_device=dev.device, renderer=(cls.__module__, cls.__name__, args), offset_supported=hasattr(dev.allocator, '_offset'), - graph_supported=graph_cls is not None, - graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner) and hasattr(dev.allocator, '_transfer'), - ib_gid=bytes(self.ib_ctx.gid_attr.raw) if self.ib_ctx is not None else None, - ) - ret = repr(rp).encode() - case Event(): - if c.session == c.event_session: - session.events[c.event].set() - else: - for d in Device._opened_devices: Device[d].synchronize() # wait for device*s* to finish executing previous stuff - # TODO: don't wait, just send - await RemoteConnection(c.event_session.host).aq(Event(c.event_session, c.event, session=c.event_session), wait=True) - case Wait(): - assert await session.events[c.event].wait() - del session.events[c.event] # do not leak memory - case IBConnect(): - self.ib_conns[c.host] = ibc = IBConn(unwrap(self.ib_ctx)) - ibc.connect(c.gid, c.qp_num) - ret = struct.pack('<16sQ', ibc.gid, ibc.qp_num) - case BufferAlloc(): - assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated" - session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True) - case BufferIOVAS(): - rets = [] - for buffer_session,buffer_num in c.buffer_nums: - iova, mr = unwrap(self.ib_ctx).reg(buf:=self.sessions[buffer_session].buffers[buffer_num]) - rets.append(struct.pack(" int: - self.dev.q(BufferAlloc(buffer_num:=next(self.dev.buffer_num), size, options)) - return buffer_num - # TODO: options should not be here in any Allocator - def _free(self, opaque:int, options): - try: self.dev.q(BufferFree(opaque)) - except (TypeError, AttributeError): pass - def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(src))) - def _copyout(self, dest:memoryview, src:int): - resp = self.dev.q(CopyOut(src), wait=True) - assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}" - dest[:] = resp - def _transfer(self, dest, src, sz, src_dev, dest_dev): - if dest_dev.conn != src_dev.conn: - dest_dev.q(Event(src_dev.session, start_event:=next(src_dev.event_num))) - src_dev.q(Wait(start_event)) - src_dev.q(Transfer(src, dest_dev.session, dest)) - if dest_dev.conn != src_dev.conn: - src_dev.q(Event(dest_dev.session, end_event:=next(dest_dev.event_num))) - dest_dev.q(Wait(end_event)) - if DEBUG >= 2: dest_dev.conn.batch_submit() - def _dyn_offset(self, opaque:int, size:int, offset:int) -> int: - self.dev.q(BufferOffset(buffer_num:=next(self.dev.buffer_num), size, offset, opaque)) - return buffer_num - -class RemoteProgram: - def __init__(self, dev:RemoteDevice, name:str, lib:bytes): - self.dev, self.name = dev, name - self.datahash = self.dev.conn.req.h(lib) - self.dev.q(ProgramAlloc(self.name, self.datahash)) - super().__init__() - weakref.finalize(self, self._fini, self.dev, self.name, self.datahash) - - @staticmethod - def _fini(dev:RemoteDevice, name:str, datahash:str): dev.q(ProgramFree(name, datahash)) - - def __call__(self, *bufs, global_size=None, local_size=None, vals:tuple[int, ...]=(), wait=False): - ret = self.dev.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait), wait=wait) - if wait: return float(ret) - -@functools.cache -class RemoteConnection: - q_lock = threading.Lock() - all: dict[RemoteConnection, None] = {} # dict instead of set for deterministic ordering - - def __init__(self, host:str): - if DEBUG >= 1: print(f"remote with host {host}") - while 1: - try: - self.conn = http.client.HTTPConnection(host, timeout=getenv("REMOTE_TIMEOUT", 300.0)) - self.conn.connect() - break - except Exception as e: - print(e) - time.sleep(0.1) - self.req: BatchRequest = BatchRequest() - RemoteConnection.all[self] = None - - def q(self, x:RemoteRequest, wait:bool=False): - with RemoteConnection.q_lock: - self.req.q(x) - if wait: return self.batch_submit(take_q=False) - - async def aq(self, x:RemoteRequest, wait:bool=False): return await asyncio.to_thread(self.q, x, wait=wait) - - def batch_submit(self, take_q:bool=True): - if take_q: RemoteConnection.q_lock.acquire() - conns = RemoteConnection.all.keys() - datas = {conn: conn.req.serialize() for conn in conns} - reqs, hashes, hash_datas = sum(len(c.req._q) for c in conns), sum(len(c.req._h) for c in conns), sum(len(data) for data in datas.values()) - ret, resps = None, [] - with Timing(f"*** send {reqs:-3d} requests {hashes:-3d} hashes with len {hash_datas/1024:.2f} kB in ", enabled=DEBUG>=3): - for conn,data in datas.items(): conn.conn.request("POST", "/batch", data) - for conn in datas.keys(): - resp = conn.conn.getresponse() - body = resp.read() - resps.append((conn, resp, body)) - conn.req = BatchRequest() - if take_q: RemoteConnection.q_lock.release() - for conn,resp,body in resps: - match resp.status: - case http.HTTPStatus.OK: pass - case http.HTTPStatus.INTERNAL_SERVER_ERROR: - exc_wrapper = safe_eval(ast.parse(body.decode(), mode="eval").body) - exc_wrapper.exc.add_note(exc_wrapper.trace) - raise exc_wrapper.exc - case code: raise RuntimeError(f"POST /batch failed with {code}: {body.decode()}") - if conn == self: ret = body - return ret - -def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]: - hosts = [(unwrap(h), int(c) if c is not None else c) for h,c in ((h.split("*", maxsplit=1)+[None,])[:2] for h in hs.split(","))] - if len(hosts) == 1 and hosts[0][1] is None: return LazySeq(lambda idx: (hosts[0][0], idx)) - return [(h, i) for h,c in hosts for i in range(unwrap(c))] - -class RemoteDevice(Compiled): - devices = parse_hosts(getenv("HOST", "")) - - def __init__(self, device:str): - host, idx = RemoteDevice.devices[int(device.split(":")[1]) if ":" in device else 0] - - # connection is shared between sessions on the same host - self.session: SessionKey = SessionKey(host or RemoteDevice.local_server(), idx, binascii.hexlify(os.urandom(0x10)).decode()) - self.conn: RemoteConnection = RemoteConnection(self.session.host) - - # state for the session - self.buffer_num: Iterator[int] = itertools.count(0) - self.graph_num: Iterator[int] = itertools.count(0) - self.event_num: Iterator[int] = itertools.count(0) - - self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body) - if DEBUG >= 1: print(f"remote has device {self.properties.real_device}") - # TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer - renderer = self.properties.renderer - if not renderer[0].startswith("tinygrad.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}") - renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure? - if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}") - - graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None - compilers = CompilerSet([CompilerPair(functools.partial(renderer_class, *renderer[2]), Compiler)]) - super().__init__(device, RemoteAllocator(self), compilers, functools.partial(RemoteProgram, self), graph, id(self.conn)) - self.renderer.device = device - - def finalize(self): - with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True) - - def q(self, x:RemoteRequest, wait:bool=False): return self.conn.q(replace(x, session=self.session), wait=wait) - - @functools.cache - @staticmethod - def local_server(): - multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start() - return "127.0.0.1:6667" - -if __name__ == "__main__": remote_server(getenv("PORT", 6667)) diff --git a/tinygrad/runtime/support/ib.py b/tinygrad/runtime/support/ib.py deleted file mode 100644 index b42ecba5ec..0000000000 --- a/tinygrad/runtime/support/ib.py +++ /dev/null @@ -1,173 +0,0 @@ -from __future__ import annotations -import resource, ctypes, weakref, functools, itertools -from tinygrad.runtime.autogen import ib -from typing import Iterator -from dataclasses import dataclass -from weakref import WeakKeyDictionary -from tinygrad.device import Buffer, DMACPURef, DMAFdRef -from tinygrad.helpers import getenv, round_up, DEBUG - -DEFAULT_PORT, DEFAULT_GID = getenv("DEFAULT_PORT", 1), getenv("DEFAULT_GID", 3) # DEFAULT_GID=0 for RXE -IOVA_ALIGN = resource.getpagesize() - -def checkz(x, ret=None): - if x != 0: raise RuntimeError(f'{x} != 0 (errno {ctypes.get_errno()})') - return ret - -@dataclass(frozen=True) -class SGE: - dst_iova: int - dst_key: int - src_iova: int - src_key: int - size: int - -class IBCtx: - def __init__(self, idx:int): - # Open the device (aka Host Channel Adapter in ib-speak) - devs = ib.ibv_get_device_list(ctypes.byref(ndevs:=ctypes.c_int32())) - if idx >= ndevs.value: raise IndexError(f"{idx} > {ndevs.value}") - self.ctx = ib.ibv_open_device(devs[idx]) - ib.ibv_free_device_list(devs) - - # HACK: remove this (and all usage of `ctx.contents.ops`) when clang2py can deal with `static inline` wrapper-functions - self.vctx = ctypes.cast(ctypes.addressof(self.ctx.contents) - ib.struct_verbs_context.context.offset, ctypes.POINTER(ib.struct_verbs_context)) - - # Get attributes. Something like port_attr.max_msg_sz sound like it might requre taking the min of host's and remote's attributes if they differ - self.device_attr = checkz(ib.ibv_query_device(self.ctx, ctypes.byref(da:=ib.struct_ibv_device_attr())), da) - self.port_attr = checkz(self.vctx.contents.query_port(self.ctx, DEFAULT_PORT, ctypes.byref(pa:=ib.struct_ibv_port_attr()), ctypes.sizeof(pa)), pa) - self.gid_attr = checkz(ib.ibv_query_gid(self.ctx, DEFAULT_PORT, DEFAULT_GID, ctypes.byref(ga:=ib.union_ibv_gid())), ga) - - # Allocate protection domain - self.pd = ib.ibv_alloc_pd(self.ctx) - self.next_iova: int = IOVA_ALIGN # don't start at zero (nullptr) - - # weakref(buf) => (iova, mr, mr_dealloc). mr_dealloc is kept here to avoid double freeing mrs that are deallocated in __del__ - self.mrs: WeakKeyDictionary[Buffer, tuple[int, ctypes._Pointer[ib.struct_ibv_mr], weakref.finalize]] = WeakKeyDictionary() - - # Default soft fd limit is 1024, which is not enough, set soft to hard (maximum allowed by the os) - IBCtx.rlimit_fix() - - def __del__(self): - # must deallocate all mrs in protection domain before deallocating the protection domain - if hasattr(self, "mrs"): [fin() for _,_,fin in self.mrs.values()] - if hasattr(self, "pd"): ib.ibv_dealloc_pd(self.pd) - if hasattr(self, "ctx"): ib.ibv_close_device(self.ctx) - - @functools.cache # run once - @staticmethod - def rlimit_fix(): - soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) - resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) - if DEBUG>=2: print(f"IB: Increased fd limit from {soft} to {hard}") - - def alloc_iova(self, size:int, required_offset:int): - iova = round_up(self.next_iova - required_offset, IOVA_ALIGN) + required_offset - self.next_iova = iova + size - return iova - - def reg(self, buf:Buffer) -> tuple[int, ctypes._Pointer[ib.struct_ibv_mr]]: - buf = buf.base - if buf not in self.mrs: - if buf.nbytes > self.device_attr.max_mr_size: raise RuntimeError(f"Buffer too big: {buf.nbytes:#x} > {self.device_attr.max_mr_size:#x}") - if len(self.mrs) >= self.device_attr.max_mr: raise RuntimeError(f"Out of memory region cap: {len(self.mrs)} >= {self.device_attr.max_mr}") - # Local read is implied (but still have to create the memory region, except for short sends/writes with IBV_SEND_INLINE that are inlined by cpu) - mr_flags = ib.IBV_ACCESS_LOCAL_WRITE | ib.IBV_ACCESS_REMOTE_READ | ib.IBV_ACCESS_REMOTE_WRITE - match (dmaref:=buf.as_dmaref()): - case DMACPURef(): - iova = self.alloc_iova(dmaref.size, dmaref.addr % IOVA_ALIGN) - mr = ib.ibv_reg_mr_iova2(self.pd, ctypes.c_void_p(dmaref.addr), dmaref.size, iova, mr_flags) - case DMAFdRef(): - iova = self.alloc_iova(dmaref.size, dmaref.offset % IOVA_ALIGN) - mr = ib.ibv_reg_dmabuf_mr(self.pd, dmaref.offset, dmaref.size, iova, dmaref.fd, mr_flags) - case _: raise RuntimeError(f"Unknown type of dma ref: {dmaref}") - if not mr: raise RuntimeError(f"Couldn't register memory region for {buf} {dmaref} (errno={ctypes.get_errno()})") - self.mrs[buf] = (iova, mr, weakref.finalize(buf, ib.ibv_dereg_mr, mr)) - return self.mrs[buf][0:2] - -class IBConn: - def __init__(self, ctx:IBCtx): - self.ctx = ctx - - # Create Completion Channel. It is a file descriptor that kernel sends notifications through, not a thing in infiniband spec, just linux-ism - self.comp_channel = ib.ibv_create_comp_channel(self.ctx.ctx) - # Create Completion Queue. When a Work Request with signaled flag is completed a Completion Queue Entry is pushed onto this queue - self.cq = ib.ibv_create_cq(self.ctx.ctx, _capacity:=256, _cq_context:=None, self.comp_channel, _comp_vector:=0) - self.pending_wrids: set[int] = set() - self.wrid_num: Iterator[int] = itertools.count(0) # wc_id is uint64, this will never overflow - - # Create Queue Pair. It's the closest thing to a socket in infiniband with QP num being the closest thing to a port, except it's allocated by hca - qp_init_attrs_cap = ib.struct_ibv_qp_cap(max_send_wr=1024, max_recv_wr=64, max_send_sge=8, max_recv_sge=8, max_inline_data=64) - qp_init_attrs = ib.struct_ibv_qp_init_attr(send_cq=self.cq, recv_cq=self.cq, cap=qp_init_attrs_cap, qp_type=ib.IBV_QPT_RC) # Reliable Connection - self.qp = ib.ibv_create_qp(self.ctx.pd, ctypes.byref(qp_init_attrs)) - self.qp_cap = qp_init_attrs.cap - - # The most important thing about QPs is their state, when a new QP is created it's in the RESET state, before it can be properly used it has to go - # through Init, Ready To Receive, Ready To Send. A good docs on QP state machine: https://www.rdmamojo.com/2012/05/05/qp-state-machine/ - - # INIT - qp_access_flags = ib.IBV_ACCESS_REMOTE_WRITE | ib.IBV_ACCESS_REMOTE_READ - qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_INIT, port_num=DEFAULT_PORT, qp_access_flags=qp_access_flags) - checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PORT | ib.IBV_QP_ACCESS_FLAGS | ib.IBV_QP_PKEY_INDEX)) - - self.gid, self.qp_num = bytes(self.ctx.gid_attr.raw), self.qp.contents.qp_num - - # Exchange GID and QP num with remote. At least in RoCEv2 gid can be guessed from remote's ip, QP num can't. - - def connect(self, remote_gid:bytes, remote_qp_num:int): - # RTR - qp_ah_attr_grh = ib.struct_ibv_global_route(hop_limit=1, dgid=ib.union_ibv_gid(raw=(ctypes.c_ubyte * 16)(*remote_gid)), sgid_index=DEFAULT_GID) - qp_ah_attr = ib.struct_ibv_ah_attr(is_global=1, port_num=DEFAULT_PORT, grh=qp_ah_attr_grh) - qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTR, path_mtu=ib.IBV_MTU_4096, dest_qp_num=remote_qp_num, rq_psn=0, max_dest_rd_atomic=1, - min_rnr_timer=12, ah_attr=qp_ah_attr) - checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PATH_MTU | ib.IBV_QP_DEST_QPN | ib.IBV_QP_RQ_PSN | \ - ib.IBV_QP_MAX_DEST_RD_ATOMIC | ib.IBV_QP_MIN_RNR_TIMER | ib.IBV_QP_AV)) - - # RTS - qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTS, timeout=14, retry_cnt=7, rnr_retry=7, sq_psn=0, max_rd_atomic=1) - checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_TIMEOUT | ib.IBV_QP_RETRY_CNT | ib.IBV_QP_RNR_RETRY | ib.IBV_QP_SQ_PSN | \ - ib.IBV_QP_MAX_QP_RD_ATOMIC)) - - def __del__(self): - self.wait_cq() # need to wait for **everything** to complete before it's safe to dealloc queues and stuff - ib.ibv_destroy_qp(self.qp) - ib.ibv_destroy_cq(self.cq) - ib.ibv_destroy_comp_channel(self.comp_channel) - - def next_wrid(self): - self.pending_wrids.add(wrid:=next(self.wrid_num)) - return wrid - - def wait_cq(self, wr_id: int|None=None): - while (wr_id in self.pending_wrids) if wr_id is not None else self.pending_wrids: - if self.ctx.ctx.contents.ops.poll_cq(self.cq, _num_entries:=1, ctypes.byref(wc:=ib.struct_ibv_wc())): - if wc.status != ib.IBV_WC_SUCCESS: - raise RuntimeError(f'Work Request completed with error: wr_id={wc.wr_id} status={ib.enum_ibv_wc_status.get(wc.status, wc.status)}') - self.pending_wrids.remove(wc.wr_id) - - def rdma_write(self, sgl:list[SGE]): - swr: ctypes._Pointer[ib.struct_ibv_send_wr]|None = None - swr_cnt, wr_id = 0, self.next_wrid() - def _post(): - nonlocal swr, swr_cnt, wr_id - if swr is not None: - # The swr can be freed when this returns, the memory that sge points to can be unmapped after work completion is retrieved from cq - checkz(self.ctx.ctx.contents.ops.post_send(self.qp, swr, ctypes.byref(_bad_wr:=ctypes.POINTER(ib.struct_ibv_send_wr)()))) - # TODO: async - self.wait_cq(wr_id) - swr, swr_cnt, wr_id = None, 0, self.next_wrid() - # Everything is in reverse for elegant chaining - for sg in reversed(sgl): - # Message size limit (max 2GB per ib spec, 1GB on tinybox mellanoxes) applies to both scatter-gather entries and entire wrs - for off in reversed(range(0, sg.size, self.ctx.port_attr.max_msg_sz)): - # Scatter-Gather Entry for local memory - sge = ctypes.pointer(ib.struct_ibv_sge(addr=sg.src_iova+off, length=min(sg.size-off, self.ctx.port_attr.max_msg_sz), lkey=sg.src_key)) - # RDMA struct for remote memory - wr = ib.struct_ibv_send_wr_wr(rdma=ib.struct_ibv_send_wr_wr_rdma(remote_addr=sg.dst_iova+off, rkey=sg.dst_key)) - # Signal (with chosen work request id) if it's the last wr (first in the loop since it's reversed) - wid, flags = (wr_id, ib.IBV_SEND_SIGNALED) if swr is None else (0, 0) - # Create Send Request - swr = ctypes.pointer(ib.struct_ibv_send_wr(opcode=ib.IBV_WR_RDMA_WRITE, sg_list=sge, num_sge=1, wr=wr, wr_id=wid, send_flags=flags, next=swr)) - # Flush if queue is being overrun - if (swr_cnt:=swr_cnt + 1) >= self.qp_cap.max_send_wr: _post() - _post()