remove REMOTE=1 (#13722)

* remove REMOTE=1

* leave ibverbs
George Hotz
2025-12-16 15:58:10 -04:00
committed by GitHub
parent 4d8d821f56
commit 4b741e893f
8 changed files with 2 additions and 973 deletions

View File

@@ -554,8 +554,6 @@ jobs:
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
- name: Run full CIFAR training steps w 6 GPUS
run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
- name: Run full CIFAR training steps w 6 GPUS (REMOTE)
run: time BENCHMARK_LOG=cifar_6gpu_remote REMOTE=1 REMOTEDEV=AMD DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu_remote.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (AMD Training)
@@ -567,7 +565,6 @@ jobs:
train_cifar_wino.txt
train_cifar_one_gpu.txt
train_cifar_six_gpu.txt
train_cifar_six_gpu_remote.txt
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
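For reference, the removed benchmark step above ran the same training script with the remote device substituted for direct AMD. A minimal sketch of what GPUS=6 sharding looked like against the pre-removal REMOTE backend (the REMOTE:i device names follow the deleted runtime and no longer resolve after this commit):

# pre-removal sketch: BS=1536 split across six REMOTE devices, as the deleted GPUS=6 step did
from tinygrad import Tensor
ds = tuple(f"REMOTE:{i}" for i in range(6))
x = Tensor.rand(1536, 3, 32, 32).shard(ds, axis=0)  # batch axis split across the six devices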

View File

@@ -721,71 +721,6 @@ jobs:
- name: Run process replay tests
uses: ./.github/actions/process-replay
amdremote:
name: Linux (remote)
runs-on: ubuntu-22.04
timeout-minutes: 20
env:
REMOTE: 1
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: linux-remote
deps: testing_minimal
amd: 'true'
llvm: 'true'
opencl: 'true'
- name: Start remote server
run: |
start_server() {
systemd-run --user \
--unit="$1" \
--setenv=REMOTEDEV="$2" \
--setenv=MOCKGPU=1 \
--setenv=PYTHONPATH=. \
--setenv=PORT="$3" \
--working-directory="$(pwd)" \
python tinygrad/runtime/ops_remote.py
}
start_server "remote-server-amd-1" "AMD" 6667
start_server "remote-server-amd-2" "AMD" 6668
start_server "remote-server-gpu" "CL" 7667
start_server "remote-server-cpu" "CPU" 8667
- name: Check Device.DEFAULT and print some source
env:
HOST: 127.0.0.1:6667*6,127.0.0.1:6668*6
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'AMD', Device.default.properties.real_device"
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
- name: Run REMOTE=1 Test (AMD)
env:
HOST: 127.0.0.1:6667*6,127.0.0.1:6668*6
run: |
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_remote.py test/test_tensor_variable.py --durations 20
- name: Run REMOTE=1 Test (CL)
env:
HOST: 127.0.0.1:7667*6
run: |
python3 -m pytest test/test_tiny.py test/test_image_dtype.py test/test_jit.py --durations 20
IMAGE=2 python3 -m pytest test/test_tiny.py test/test_image_dtype.py
- name: Run REMOTE=1 Test (CPU)
env:
HOST: 127.0.0.1:8667*6
run: |
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_multitensor.py --durations 20
- name: Show remote server logs
if: always()
run: |
journalctl --user -u remote-server-amd-1 --no-pager
journalctl --user -u remote-server-amd-2 --no-pager
journalctl --user -u remote-server-gpu --no-pager
journalctl --user -u remote-server-cpu --no-pager
# ****** OSX Tests ******
testmetal:
@@ -883,30 +818,6 @@ jobs:
- name: Test ONNX Runner (WEBGPU)
run: WEBGPU=1 python3 test/external/external_test_onnx_runner.py
osxremote:
name: MacOS (remote metal)
runs-on: macos-15
timeout-minutes: 10
env:
REMOTE: 1
REMOTEDEV: METAL
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: macos-remote
deps: testing_minimal
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
- name: Run REMOTE=1 Test
run: |
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
osxtests:
strategy:
fail-fast: false

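The deleted jobs above exercised the remote stack end to end: each systemd unit ran tinygrad/runtime/ops_remote.py with PORT and REMOTEDEV set, and the client selected it through the REMOTE and HOST environment variables. A hedged sketch of that round trip, assuming pre-removal tinygrad and a server already listening on 127.0.0.1:6667 (e.g. PORT=6667 REMOTEDEV=CPU python tinygrad/runtime/ops_remote.py):

# sketch of the removed CI checks; REMOTE and HOST must be set before tinygrad is imported
import os
os.environ["REMOTE"], os.environ["HOST"] = "1", "127.0.0.1:6667*6"  # host*count, as in the jobs above

from tinygrad import Device, Tensor
assert Device.DEFAULT == "REMOTE", Device.DEFAULT
print(Device.default.properties.real_device)        # the backend behind the server, e.g. "CPU"
print((Tensor([1, 2]) + Tensor([3, 4])).tolist())   # runs through the HTTP boundary -> [4, 6]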
View File

@@ -8,7 +8,7 @@ def multidevice_test(fxn):
def ret(self):
for device in Device._devices:
# broken on OSX USB AMD, why?
if device in ["REMOTE", "DISK", "NPY", "FAKE", "DSP", "NULL"] or (OSX and device in ["AMD"]): continue
if device in ["DISK", "NPY", "FAKE", "DSP", "NULL"] or (OSX and device in ["AMD"]): continue
if not CI: print(device)
if device in exclude_devices:
if not CI: print(f"WARNING: {device} test is excluded")

View File

@@ -69,5 +69,4 @@ def needs_second_gpu(fn):
return fn(self, *args, **kwargs)
return wrapper
# NOTE: This will open REMOTE if it's the default device
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device)
REAL_DEV = Device.DEFAULT

View File

@@ -1,101 +0,0 @@
import numpy as np, unittest, string
from hypothesis import given, strategies as st
from tinygrad import Device, Tensor, TinyJit, dtypes
from tinygrad.runtime.ops_remote import RemoteDevice, parse_hosts
from tinygrad.runtime.graph.remote import RemoteGraph
from tinygrad.helpers import LazySeq, all_same, Context
def multihost_env(devices):
def same_hosts(devices): return all_same([h for h,_ in devices])
return isinstance(devices, list) and len(devices) >= 12 and not same_hosts(devices[0:12]) and same_hosts(devices[0:6]) and same_hosts(devices[6:12])
@unittest.skipUnless(Device.DEFAULT == "REMOTE" and multihost_env(RemoteDevice.devices), "Requires special environment")
class TestRemoteMultiHost(unittest.TestCase):
def test_multihost_transfer(self):
a = Tensor.arange(0, 16, device='REMOTE:0').contiguous().realize()
b = a.to('REMOTE:6').contiguous().realize()
np.testing.assert_equal(b.numpy(), np.arange(0, 16))
@Context(JIT_BATCH_SIZE=2**32)
@unittest.skip("kernel must all be multibuffer")
def test_multihost_matmul_jit_graph(self):
@TinyJit
def do(a:Tensor, b:Tensor): return (a @ b).contiguous().realize()
ds = ('REMOTE:0', 'REMOTE:1', 'REMOTE:6', 'REMOTE:7')
for _ in range(3):
na, nb = np.random.rand(128, 128).astype(np.float32), np.random.rand(128, 128).astype(np.float32)
a, b = Tensor(na).shard(ds, 0).contiguous().realize(), Tensor(nb).shard(ds, 0).contiguous().realize()
nc = na @ nb
c = do(a, b)
np.testing.assert_allclose(nc, c.numpy(), rtol=3e-2, atol=1e-4) # tolerances from extra/gemm/simple_matmul.py
# Verify that everything is in one big cross-host graph
assert len(do.captured._jit_cache) == 1 and isinstance(do.captured._jit_cache[0].prg, RemoteGraph), repr(do.captured)
@Context(JIT_BATCH_SIZE=2**32)
@unittest.skip("assign target and input devices mismatch")
def test_multihost_aware_schedule(self):
@TinyJit
def do(*ts:Tensor):
acc = Tensor.zeros(1, dtype=dtypes.float32).contiguous().realize()
for t in ts: acc += t.sum()
return acc.realize()
def do_np(*ts:np.ndarray):
acc = np.zeros(1, np.float32)
for t in ts: acc += t.sum()
return acc
ds = ('REMOTE:0', 'REMOTE:1', 'REMOTE:6', 'REMOTE:7')
TS = 64
for _ in range(3):
inp_np = [np.random.rand(256).astype(np.float32) for _ in range(TS)]
inp = [Tensor(inp).shard(ds, 0).contiguous().realize() for inp in inp_np]
out_np = do_np(*inp_np)
out = do(*inp)
np.testing.assert_allclose(out_np, out.numpy(), rtol=3e-2, atol=1e-4)
# Verify that everything is in one big cross-host graph and that the scheduling is reasonable
assert len(do.captured._jit_cache) == 1 and isinstance(do.captured._jit_cache[0].prg, RemoteGraph), repr(do.captured)
# At the time of writing: 2050 graph breaks without multihost-aware scheduling, 14 with it. The fail threshold is set to 28 so unrelated
# scheduling changes don't trip it. Maybe 2x is a bit too pessimistic, but remote should perform just fine as long as this doesn't reach
# fifty or more.
self.assertLess(len(do.captured._jit_cache[0].prg.template), 28, "Very bad scheduling! Many unnecessary graph breaks!")
class TestParseHosts(unittest.TestCase):
def assert_seq(self, result:LazySeq, host:str):
self.assertIsInstance(result, LazySeq)
for i in [0, 1, 5, 10]: self.assertEqual(result[i], (host, i))
@given(st.sampled_from(["", "localhost", "192.168.1.1:8080", "host"]))
def test_single_host_no_count(self, host:str):
self.assert_seq(parse_hosts(host), host)
@given(host=st.sampled_from(["localhost", "host", "192.168.1.1:8080"]), count=st.integers(0, 10))
def test_single_host_with_count(self, host:str, count:int):
self.assertEqual(parse_hosts(f"{host}*{count}"), [(host, i) for i in range(count)])
def test_multiple_hosts_with_counts_simple(self):
self.assertEqual(parse_hosts("host1*2,host2*3"), [("host1", i) for i in range(2)] + [("host2", i) for i in range(3)])
@given(st.lists(st.tuples(st.text(alphabet=string.ascii_letters + string.digits + ".-:"), st.integers(1, 16)), min_size=1))
def test_multiple_hosts_with_counts_sampled(self, host_count_pairs):
hosts_str = ",".join(f"{host}*{count}" for host, count in host_count_pairs)
expected = [(host, i) for host, count in host_count_pairs for i in range(count)]
self.assertEqual(parse_hosts(hosts_str), expected)
@given(st.sampled_from(["host1*2,host2", "a*1,b", "x*3,y*2,z"]))
def test_mixed_hosts_fails(self, hosts):
with self.assertRaises(AssertionError): parse_hosts(hosts)
@given(st.sampled_from(["host*abc", "test*xyz", "a*1.5"]))
def test_invalid_count_fails(self, hosts):
with self.assertRaises(ValueError): parse_hosts(hosts)
@given(st.sampled_from(["host*2*3", "a*1*2*3", "test*x*y"]))
def test_multiple_asterisks_fails(self, hosts):
with self.assertRaises(ValueError): parse_hosts(hosts)
if __name__ == '__main__':
unittest.main()
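The TestParseHosts cases above pin down the HOST grammar. A usage sketch mirroring those tests against the pre-removal module (parse_hosts is deleted by this commit):

# usage sketch of the deleted parse_hosts; the expected values come straight from the tests above
from tinygrad.runtime.ops_remote import parse_hosts

# "host*count" entries expand in order...
assert parse_hosts("host1*2,host2*3") == [("host1", 0), ("host1", 1),
                                          ("host2", 0), ("host2", 1), ("host2", 2)]
# ...while a single bare host yields a LazySeq: (host, idx) for any idx
seq = parse_hosts("localhost")
assert seq[5] == ("localhost", 5)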

View File

@@ -1,113 +0,0 @@
import time, itertools
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.runtime.ops_remote import RemoteDevice, RemoteConnection, RemoteRequest, GraphComputeItem, Transfer, GraphAlloc, GraphFree, GraphExec
from tinygrad.runtime.ops_remote import BatchTransfer, Event, Wait
from tinygrad.helpers import unwrap, flatten, dedup
from enum import Enum, auto
from dataclasses import replace
from collections import defaultdict
from typing import cast
class StagingType(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto() # noqa: E702
def rd(dev:Compiled) -> RemoteDevice: return cast(RemoteDevice, dev)
def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_multi else dev
def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)
class RemoteGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, rawbufs, var_vals)
devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
c2d = {device.conn: device for device in devices}
self.handle_indexes = {map_rawbuf(rawbufs[i]): i for i in sorted(dedup(self.input_replace.values()))}
self.template: list[RemoteRequest] = []
stagings: dict[RemoteDevice|RemoteConnection, list[GraphComputeItem|Transfer]] = defaultdict(list)
clobbered_buffers: set[Buffer] = set()
cur_staging_type: StagingType = StagingType.NONE
def _flush(new_staging_type:StagingType, force_break:bool=False):
nonlocal cur_staging_type
if cur_staging_type == new_staging_type and not force_break: return
# Pre-sync
if cur_staging_type == StagingType.TRANSFER:
for sdev,ddev in itertools.permutations(c2d.values(), 2):
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
self.template.append(Wait(event, session=ddev.session))
# Flush
for dev in devices:
dk = dev_key(dev)
staging = stagings[dk]
if not staging: continue
match cur_staging_type:
case StagingType.GRAPH:
bufs = tuple(map_rawbuf(rawbufs[i]) for i in sorted(dedup(self.input_replace.values())) if dev_key(rd(Device[rawbufs[i].device])) == dk)
dev.q(GraphAlloc(graph_num:=next(dev.graph_num), tuple(staging), tuple(bufs), var_vals))
self.template.append(GraphExec(graph_num, bufs, var_vals, wait=False, session=dev.session))
case StagingType.TRANSFER:
st = cast(list[Transfer], staging)
for host in dedup(t.dsession.host for t in st):
sbuffer_nums = [(unwrap(t.session), t.buffer_num) for t in st if t.dsession.host == host]
dbuffer_nums = [(t.dsession, t.dbuffer_num) for t in st if t.dsession.host == host]
self.template.append(BatchTransfer(sbuffer_nums, dbuffer_nums, session=dev.session))
staging.clear()
# Post-sync
if cur_staging_type == StagingType.TRANSFER:
for sdev,ddev in itertools.permutations(c2d.values(), 2):
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
self.template.append(Wait(event, session=ddev.session))
cur_staging_type = new_staging_type
clobbered_buffers.clear()
for ji in jit_cache:
match ji.prg:
case CompiledRunner():
_flush(StagingType.GRAPH)
gi = GraphComputeItem(ji.prg.dev.session, ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs),
tuple(ji.prg.p.vars), ji.fixedvars, tuple(ji.prg.p.ins), tuple(ji.prg.p.outs),
tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
stagings[dev_key(ji.prg.dev)].append(gi)
case BufferXfer():
dest, src = ji.bufs[0:2]
dest_dev, src_dev = cast(RemoteDevice, Device[unwrap(dest).device]), cast(RemoteDevice, Device[unwrap(src).device])
assert dest is not None and src is not None, ji
ti = Transfer(session=src_dev.session, buffer_num=src._buf, dsession=dest_dev.session, dbuffer_num=dest._buf)
if dev_key(dest_dev) == dev_key(src_dev):
_flush(StagingType.GRAPH)
stagings[dev_key(src_dev)].append(ti)
elif dest_dev.conn == src_dev.conn:
_flush(StagingType.NONE)
self.template.append(ti)
else:
_flush(StagingType.TRANSFER, force_break=src in clobbered_buffers)
clobbered_buffers.add(dest)
stagings[dev_key(src_dev)].append(ti)
case _: raise NotImplementedError(ji.prg)
_flush(StagingType.NONE)
def __del__(self):
for req in self.template:
match req:
case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False):
if wait: st = time.perf_counter()
rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
for req in self.template:
match req:
case GraphExec():
req = replace(req, bufs=tuple(rmap[buf] for buf in req.bufs), var_vals=var_vals, wait=wait)
case Transfer():
if (req.session, req.buffer_num) in rmap: req = replace(req, buffer_num=rmap[(req.session, req.buffer_num)][1])
if (req.dsession, req.dbuffer_num) in rmap: req = replace(req, dbuffer_num=rmap[(req.dsession, req.dbuffer_num)][1])
case BatchTransfer():
req = replace(req, sbuffer_nums=[rmap.get(b, b) for b in req.sbuffer_nums], dbuffer_nums=[rmap.get(b, b) for b in req.dbuffer_nums])
case Event()|Wait():
pass # event number can be reused
case _: raise NotImplementedError(req)
RemoteConnection(unwrap(req.session).host).q(req)
if wait:
RemoteConnection(unwrap(req.session).host).batch_submit()
return time.perf_counter() - st
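The deleted RemoteGraph above batches jit items per device/connection and flushes whenever the staging kind changes. A stripped-down sketch of just that flush pattern (illustrative names; the real code also emits Event/Wait sync and GraphAlloc, which this omits):

# stripped-down sketch of the staging/flush pattern above; Kind and the tuples are illustrative
from collections import defaultdict
from enum import Enum, auto

class Kind(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto()  # noqa: E702

template, stagings, cur = [], defaultdict(list), Kind.NONE

def flush(new: Kind, force_break: bool = False):
    global cur
    if cur == new and not force_break: return  # keep batching items of the same kind
    for dev, items in stagings.items():
        if items: template.append((cur, dev, tuple(items))); items.clear()
    cur = new

# compute items batch into one graph; a kind change breaks the batch
flush(Kind.GRAPH);    stagings["dev0"].append("kernel_a")
flush(Kind.GRAPH);    stagings["dev0"].append("kernel_b")   # same kind: no flush
flush(Kind.TRANSFER); stagings["dev0"].append("xfer")       # kind change: graph flushed
flush(Kind.NONE)
assert [k for k, *_ in template] == [Kind.GRAPH, Kind.TRANSFER]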

View File

@@ -1,491 +0,0 @@
# the REMOTE=1 device is a process boundary between the frontend/runtime
# normally tinygrad is frontend <-> middleware <-> runtime <-> hardware
# with REMOTE tinygrad is frontend <-> middleware <-> RemoteDevice ///HTTP/// remote_server <-> runtime <-> hardware
# this client and server can be on the same machine, the same network, or just the same internet
# it should be a secure boundary (for example: no use of pickle); HTTP is used for RPC
from __future__ import annotations
from typing import Callable, Iterator, Any, cast
from collections import defaultdict
from dataclasses import dataclass, field, replace
import multiprocessing, threading, functools, itertools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref
import traceback, builtins
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.dtype import DTYPES_DICT, dtypes
from tinygrad.uop.ops import UOp, Ops, Variable, sint
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, LazySeq, Timing
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class
from tinygrad.engine.realize import CompiledRunner, BufferXfer
from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec, CompilerSet, CompilerPair
from tinygrad.runtime.support.ib import IBCtx, IBConn, SGE
# ***** API *****
@dataclass(frozen=True)
class SessionKey: host: str; idx: int; nonce: str # noqa: E702
@dataclass(frozen=True)
class RemoteRequest: session: SessionKey|None = field(default=None, kw_only=True)
@dataclass(frozen=True)
class SessionFree(RemoteRequest): pass
@dataclass(frozen=True)
class RemoteProperties:
real_device: str
renderer: tuple[str, str, tuple[Any, ...]]
offset_supported: bool
graph_supported: bool
graph_supports_multi: bool
ib_gid: bytes|None
@dataclass(frozen=True)
class RemoteException:
exc: Exception
trace: str = ""
@dataclass(frozen=True)
class GetProperties(RemoteRequest): pass
@dataclass(frozen=True)
class Event(RemoteRequest): event_session: SessionKey; event: int # noqa: E702
@dataclass(frozen=True)
class Wait(RemoteRequest): event: int
@dataclass(frozen=True)
class IBConnect(RemoteRequest): host: str; gid: bytes; qp_num: int # noqa: E702
@dataclass(frozen=True)
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
@dataclass(frozen=True)
class BufferOffset(RemoteRequest): buffer_num: int; size: int; offset: int; sbuffer_num: int # noqa: E702
@dataclass(frozen=True)
class BufferIOVAS(RemoteRequest): buffer_nums: list[tuple[SessionKey, int]] # noqa: E702
@dataclass(frozen=True)
class BufferFree(RemoteRequest): buffer_num: int # noqa: E702
@dataclass(frozen=True)
class CopyIn(RemoteRequest): buffer_num: int; datahash: str # noqa: E702
@dataclass(frozen=True)
class CopyOut(RemoteRequest): buffer_num: int
@dataclass(frozen=True)
class Transfer(RemoteRequest): buffer_num: int; dsession: SessionKey; dbuffer_num: int # noqa: E702
@dataclass(frozen=True)
class BatchTransfer(RemoteRequest):
sbuffer_nums: list[tuple[SessionKey, int]]
dbuffer_nums: list[tuple[SessionKey, int]]
@dataclass(frozen=True)
class ProgramAlloc(RemoteRequest): name: str; datahash: str # noqa: E702
@dataclass(frozen=True)
class ProgramFree(RemoteRequest): name: str; datahash: str # noqa: E702
@dataclass(frozen=True)
class ProgramExec(RemoteRequest):
name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702
global_size: tuple[int, ...]|None; local_size: tuple[int, ...]|None; wait: bool # noqa: E702
@dataclass(frozen=True)
class GraphComputeItem:
session: SessionKey
name: str
datahash: str
bufs: tuple[int, ...]
vars: tuple[Variable, ...]
fixedvars: dict[str, int]
ins: tuple[int, ...]
outs: tuple[int, ...]
global_size: tuple[sint, ...]|None
local_size: tuple[sint, ...]|None
@dataclass(frozen=True)
class GraphAlloc(RemoteRequest):
graph_num: int
jit_cache: tuple[GraphComputeItem|Transfer, ...]
bufs: tuple[tuple[SessionKey, int], ...]
var_vals: dict[str, int]
@dataclass(frozen=True)
class GraphFree(RemoteRequest):
graph_num: int
@dataclass(frozen=True)
class GraphExec(RemoteRequest):
graph_num: int
bufs: tuple[tuple[SessionKey, int], ...]
var_vals: dict[str, int]
wait: bool
# for safe deserialization
eval_excs = [v for k,v in builtins.__dict__.items() if isinstance(v, type) and issubclass(v, Exception) and not k.endswith("Warning")]
eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, Event, Wait, BufferAlloc, BufferOffset, BufferIOVAS,
BufferFree, CopyIn, CopyOut, Transfer, BatchTransfer, IBConnect, ProgramAlloc, ProgramFree, ProgramExec,
GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes, RemoteException] + eval_excs}
attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
ast.Name: lambda x: eval_globals[x.id], ast.Attribute: lambda x: safe_getattr(safe_eval(x.value), x.attr)}
def safe_getattr(value, attr):
assert attr in attribute_whitelist.get(value, set()), f'getattr({value}, {repr(attr)}) is not whitelisted'
return getattr(value, attr)
def safe_eval(node): return eval_fxns[node.__class__](node)
class BatchRequest:
def __init__(self):
self._q: list[RemoteRequest] = []
self._h: dict[str, bytes] = {}
def h(self, d:bytes|memoryview) -> str:
datahash = hashlib.sha256(d).hexdigest() # NOTE: this is very slow, should use blake3 on gpu instead
if datahash not in self._h:
self._h[datahash] = bytes.fromhex(datahash)+struct.pack("<Q", len(d))+bytes(d)
return datahash
def q(self, x:RemoteRequest): self._q.append(x)
def serialize(self) -> bytes:
self.h(repr(self._q).encode())
return b''.join(self._h.values())
def deserialize(self, dat:bytes) -> BatchRequest:
ptr = 0
while ptr < len(dat):
datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
self._h[datahash] = dat[ptr+0x28:ptr+0x28+datalen]
ptr += 0x28+datalen
self._q = safe_eval(ast.parse(self._h[datahash], mode="eval").body)
return self
# ***** backend *****
@dataclass
class RemoteSession:
programs: dict[tuple[str, str], Any] = field(default_factory=dict)
graphs: dict[int, GraphRunner] = field(default_factory=dict)
buffers: dict[int, Buffer] = field(default_factory=dict)
events: defaultdict[int, asyncio.Event] = field(default_factory=functools.partial(defaultdict, asyncio.Event))
class RemoteHandler:
def __init__(self, base_device: str):
self.base_device = base_device
self.sessions: defaultdict[SessionKey, RemoteSession] = defaultdict(RemoteSession)
try: self.ib_ctx: IBCtx|None = IBCtx(getenv("IB_DEV", 0))
except (RuntimeError, IndexError, AttributeError): self.ib_ctx = None
self.ib_lock = asyncio.Lock()
self.ib_conns: dict[str, IBConn|None] = {}
self.iova_cache: dict[tuple[SessionKey, int], tuple[int, int, int]] = {}
async def __call__(self, reader:asyncio.StreamReader, writer:asyncio.StreamWriter):
while (req_hdr:=(await reader.readline()).decode().strip()):
req_method, req_path, _ = req_hdr.split(' ')
req_headers = {}
while (hdr:=(await reader.readline()).decode().strip()):
key, value = hdr.split(':', 1)
req_headers[key.lower()] = value.strip()
req_body = await reader.readexactly(int(req_headers.get("content-length", "0")))
try: res_status, res_body = await self.handle(req_method, req_path, req_body)
except Exception as e:
res_status, res_body = http.HTTPStatus.INTERNAL_SERVER_ERROR, repr(RemoteException(e, traceback.format_exc())).encode()
print(f"{traceback.format_exc()}", flush=True)
writer.write(f"HTTP/1.1 {res_status.value} {res_status.phrase}\r\nContent-Length: {len(res_body)}\r\n\r\n".encode() + res_body)
async def ib_connect(self, ssession:SessionKey, dsession:SessionKey) -> IBConn|None:
if self.ib_ctx is None: return None
await self.ib_lock.acquire()
conn = RemoteConnection(dsession.host)
if dsession.host not in self.ib_conns:
props = safe_eval(ast.parse(conn.q(GetProperties(session=dsession), wait=True), mode="eval").body)
if props.ib_gid is not None:
self.ib_conns[dsession.host] = ib_conn = IBConn(self.ib_ctx)
ibxc_ret = conn.q(IBConnect(ssession.host, ib_conn.gid, ib_conn.qp_num, session=dsession), wait=True)
ib_conn.connect(*struct.unpack('<16sQ', ibxc_ret))
else:
self.ib_conns[dsession.host] = None
self.ib_lock.release()
return self.ib_conns[dsession.host]
async def get_iovas(self, bufs:list[tuple[SessionKey, int]]) -> list[tuple[int, int, int]]:
await self.ib_lock.acquire()
if (rbufs:=[buf for buf in bufs if buf not in self.iova_cache]):
conn = RemoteConnection(rbufs[0][0].host)
resp = await conn.aq(BufferIOVAS(rbufs, session=rbufs[0][0]), wait=True)
self.iova_cache.update({rbuf: struct.unpack('<QQQ', resp[i*24:(i+1)*24]) for i,rbuf in enumerate(rbufs)})
self.ib_lock.release()
return [self.iova_cache[buf] for buf in bufs]
async def handle(self, method:str, path:str, body:bytes) -> tuple[http.HTTPStatus, bytes]:
status, ret = http.HTTPStatus.OK, b""
if path == "/batch" and method == "POST":
# TODO: streaming deserialize?
req = BatchRequest().deserialize(body)
# the cmds are always last (currently in datahash)
for c in req._q:
if DEBUG >= 1: print(c)
session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session).idx}"]
match c:
case SessionFree(): del self.sessions[unwrap(c.session)]
case GetProperties():
cls, args = dev.renderer.__reduce__()
graph_cls = graph_class(Device[self.base_device])
rp = RemoteProperties(
real_device=dev.device, renderer=(cls.__module__, cls.__name__, args), offset_supported=hasattr(dev.allocator, '_offset'),
graph_supported=graph_cls is not None,
graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner) and hasattr(dev.allocator, '_transfer'),
ib_gid=bytes(self.ib_ctx.gid_attr.raw) if self.ib_ctx is not None else None,
)
ret = repr(rp).encode()
case Event():
if c.session == c.event_session:
session.events[c.event].set()
else:
for d in Device._opened_devices: Device[d].synchronize() # wait for device*s* to finish executing previous stuff
# TODO: don't wait, just send
await RemoteConnection(c.event_session.host).aq(Event(c.event_session, c.event, session=c.event_session), wait=True)
case Wait():
assert await session.events[c.event].wait()
del session.events[c.event] # do not leak memory
case IBConnect():
self.ib_conns[c.host] = ibc = IBConn(unwrap(self.ib_ctx))
ibc.connect(c.gid, c.qp_num)
ret = struct.pack('<16sQ', ibc.gid, ibc.qp_num)
case BufferAlloc():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
case BufferIOVAS():
rets = []
for buffer_session,buffer_num in c.buffer_nums:
iova, mr = unwrap(self.ib_ctx).reg(buf:=self.sessions[buffer_session].buffers[buffer_num])
rets.append(struct.pack("<QQQ", iova, mr.contents.rkey, buf.nbytes))
ret = b"".join(rets)
case BufferOffset():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already exists"
session.buffers[c.buffer_num] = session.buffers[c.sbuffer_num].view(c.size, dtypes.uint8, c.offset).allocate()
case BufferFree(): del session.buffers[c.buffer_num]
case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
case Transfer():
if c.dsession.host == unwrap(c.session).host:
dsession, ddev = self.sessions[c.dsession], Device[f"{self.base_device}:{unwrap(c.dsession).idx}"]
dbuf, sbuf = dsession.buffers[c.dbuffer_num], session.buffers[c.buffer_num]
if hasattr(ddev.allocator, '_transfer'):
assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
ddev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=ddev, src_dev=dev)
else:
sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
dbuf.copyin(data)
else:
conn, ib_conn = RemoteConnection(c.dsession.host), await self.ib_connect(unwrap(c.session), c.dsession)
sbuf = session.buffers[c.buffer_num]
if ib_conn is not None:
src_iova, src_mr = unwrap(self.ib_ctx).reg(sbuf)
dst_iova, dst_key, dst_size = (await self.get_iovas([(c.dsession, c.dbuffer_num)]))[0]
assert sbuf.nbytes == dst_size, f"{sbuf.nbytes} != {dst_size}"
for d in Device._opened_devices: Device[d].synchronize()
ib_conn.rdma_write([SGE(dst_iova, dst_key, src_iova, src_mr.contents.lkey, dst_size)])
else:
sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
await conn.aq(CopyIn(c.dbuffer_num, conn.req.h(data), session=c.dsession), wait=True)
case BatchTransfer():
conn, ib_conn = RemoteConnection(c.dbuffer_nums[0][0].host), await self.ib_connect(c.sbuffer_nums[0][0], c.dbuffer_nums[0][0])
if ib_conn is not None:
sbufs = [unwrap(self.ib_ctx).reg(self.sessions[s].buffers[bi]) for s,bi in c.sbuffer_nums]
dbufs = await self.get_iovas(c.dbuffer_nums)
for d in Device._opened_devices: Device[d].synchronize()
ib_conn.rdma_write([SGE(di, dk, si, sm.contents.lkey, ds) for (di,dk,ds),(si,sm) in zip(dbufs, sbufs)])
else:
for (sbuf_session,sbuf_num),(dbuf_session,dbuf_num) in zip(c.sbuffer_nums, c.dbuffer_nums):
sbuf = self.sessions[sbuf_session].buffers[sbuf_num]
sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
await conn.aq(CopyIn(dbuf_num, conn.req.h(data), session=dbuf_session), wait=True)
case ProgramAlloc():
lib = dev.compiler.compile_cached(req._h[c.datahash].decode())
session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib)
case ProgramFree():
key = (c.name, c.datahash)
# WORKAROUND: should be unconditional once the protocol supports proper exception handling
if key in session.programs: del session.programs[key]
case ProgramExec():
bufs = [session.buffers[x]._buf for x in c.bufs]
extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
if r is not None: ret = str(r).encode()
case GraphAlloc():
graph_fn: Callable = unwrap(dev.graph)
def _parse_ji(gi: GraphComputeItem|Transfer):
match gi:
case GraphComputeItem():
prg = self.sessions[gi.session].programs[(gi.name, gi.datahash)]
ps = ProgramSpec(gi.name, '', f"{self.base_device}:{gi.session.idx}", UOp(Ops.NOOP),
vars=list(gi.vars), ins=list(gi.ins), outs=list(gi.outs),
global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None,
local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None)
return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [self.sessions[gi.session].buffers[buf] for buf in gi.bufs],
fixedvars=gi.fixedvars)
case Transfer():
dbuf, sbuf = self.sessions[gi.dsession].buffers[gi.dbuffer_num], self.sessions[unwrap(gi.session)].buffers[gi.buffer_num]
assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
return ExecItem(BufferXfer(dbuf.nbytes, dbuf.device, sbuf.device), [dbuf, sbuf])
assert c.graph_num not in session.graphs, f"graph {c.graph_num} already allocated"
session.graphs[c.graph_num] = graph_fn(list(map(_parse_ji, c.jit_cache)), [self.sessions[s].buffers[i] for s,i in c.bufs], c.var_vals)
case GraphFree(): del session.graphs[c.graph_num]
case GraphExec():
r = session.graphs[c.graph_num]([self.sessions[s].buffers[i] for s,i in c.bufs], c.var_vals, wait=c.wait)
if r is not None: ret = str(r).encode()
else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found"
return status, ret
def remote_server(port:int):
device = getenv("REMOTEDEV", next(Device.get_available_devices()) if Device.DEFAULT == "REMOTE" else Device.DEFAULT)
async def _inner_async(port:int, device:str):
print(f"start remote server on {port} with device {device}")
await (await asyncio.start_server(RemoteHandler(device), host='', port=port)).serve_forever()
asyncio.run(_inner_async(port, device))
# ***** frontend *****
class RemoteAllocator(Allocator['RemoteDevice']):
def __init__(self, dev:RemoteDevice):
if dev.properties.offset_supported: self._offset = self._dyn_offset
super().__init__(dev)
# TODO: ideally we shouldn't have to deal with images here
def _alloc(self, size:int, options:BufferSpec) -> int:
self.dev.q(BufferAlloc(buffer_num:=next(self.dev.buffer_num), size, options))
return buffer_num
# TODO: options should not be here in any Allocator
def _free(self, opaque:int, options):
try: self.dev.q(BufferFree(opaque))
except (TypeError, AttributeError): pass
def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(src)))
def _copyout(self, dest:memoryview, src:int):
resp = self.dev.q(CopyOut(src), wait=True)
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
dest[:] = resp
def _transfer(self, dest, src, sz, src_dev, dest_dev):
if dest_dev.conn != src_dev.conn:
dest_dev.q(Event(src_dev.session, start_event:=next(src_dev.event_num)))
src_dev.q(Wait(start_event))
src_dev.q(Transfer(src, dest_dev.session, dest))
if dest_dev.conn != src_dev.conn:
src_dev.q(Event(dest_dev.session, end_event:=next(dest_dev.event_num)))
dest_dev.q(Wait(end_event))
if DEBUG >= 2: dest_dev.conn.batch_submit()
def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
self.dev.q(BufferOffset(buffer_num:=next(self.dev.buffer_num), size, offset, opaque))
return buffer_num
class RemoteProgram:
def __init__(self, dev:RemoteDevice, name:str, lib:bytes):
self.dev, self.name = dev, name
self.datahash = self.dev.conn.req.h(lib)
self.dev.q(ProgramAlloc(self.name, self.datahash))
super().__init__()
weakref.finalize(self, self._fini, self.dev, self.name, self.datahash)
@staticmethod
def _fini(dev:RemoteDevice, name:str, datahash:str): dev.q(ProgramFree(name, datahash))
def __call__(self, *bufs, global_size=None, local_size=None, vals:tuple[int, ...]=(), wait=False):
ret = self.dev.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait), wait=wait)
if wait: return float(ret)
@functools.cache
class RemoteConnection:
q_lock = threading.Lock()
all: dict[RemoteConnection, None] = {} # dict instead of set for deterministic ordering
def __init__(self, host:str):
if DEBUG >= 1: print(f"remote with host {host}")
while 1:
try:
self.conn = http.client.HTTPConnection(host, timeout=getenv("REMOTE_TIMEOUT", 300.0))
self.conn.connect()
break
except Exception as e:
print(e)
time.sleep(0.1)
self.req: BatchRequest = BatchRequest()
RemoteConnection.all[self] = None
def q(self, x:RemoteRequest, wait:bool=False):
with RemoteConnection.q_lock:
self.req.q(x)
if wait: return self.batch_submit(take_q=False)
async def aq(self, x:RemoteRequest, wait:bool=False): return await asyncio.to_thread(self.q, x, wait=wait)
def batch_submit(self, take_q:bool=True):
if take_q: RemoteConnection.q_lock.acquire()
conns = RemoteConnection.all.keys()
datas = {conn: conn.req.serialize() for conn in conns}
reqs, hashes, hash_datas = sum(len(c.req._q) for c in conns), sum(len(c.req._h) for c in conns), sum(len(data) for data in datas.values())
ret, resps = None, []
with Timing(f"*** send {reqs:-3d} requests {hashes:-3d} hashes with len {hash_datas/1024:.2f} kB in ", enabled=DEBUG>=3):
for conn,data in datas.items(): conn.conn.request("POST", "/batch", data)
for conn in datas.keys():
resp = conn.conn.getresponse()
body = resp.read()
resps.append((conn, resp, body))
conn.req = BatchRequest()
if take_q: RemoteConnection.q_lock.release()
for conn,resp,body in resps:
match resp.status:
case http.HTTPStatus.OK: pass
case http.HTTPStatus.INTERNAL_SERVER_ERROR:
exc_wrapper = safe_eval(ast.parse(body.decode(), mode="eval").body)
exc_wrapper.exc.add_note(exc_wrapper.trace)
raise exc_wrapper.exc
case code: raise RuntimeError(f"POST /batch failed with {code}: {body.decode()}")
if conn == self: ret = body
return ret
def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]:
hosts = [(unwrap(h), int(c) if c is not None else c) for h,c in ((h.split("*", maxsplit=1)+[None,])[:2] for h in hs.split(","))]
if len(hosts) == 1 and hosts[0][1] is None: return LazySeq(lambda idx: (hosts[0][0], idx))
return [(h, i) for h,c in hosts for i in range(unwrap(c))]
class RemoteDevice(Compiled):
devices = parse_hosts(getenv("HOST", ""))
def __init__(self, device:str):
host, idx = RemoteDevice.devices[int(device.split(":")[1]) if ":" in device else 0]
# connection is shared between sessions on the same host
self.session: SessionKey = SessionKey(host or RemoteDevice.local_server(), idx, binascii.hexlify(os.urandom(0x10)).decode())
self.conn: RemoteConnection = RemoteConnection(self.session.host)
# state for the session
self.buffer_num: Iterator[int] = itertools.count(0)
self.graph_num: Iterator[int] = itertools.count(0)
self.event_num: Iterator[int] = itertools.count(0)
self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
# TODO: how do we get BEAM cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
renderer = self.properties.renderer
if not renderer[0].startswith("tinygrad.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None
compilers = CompilerSet([CompilerPair(functools.partial(renderer_class, *renderer[2]), Compiler)])
super().__init__(device, RemoteAllocator(self), compilers, functools.partial(RemoteProgram, self), graph, id(self.conn))
self.renderer.device = device
def finalize(self):
with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)
def q(self, x:RemoteRequest, wait:bool=False): return self.conn.q(replace(x, session=self.session), wait=wait)
@functools.cache
@staticmethod
def local_server():
multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start()
return "127.0.0.1:6667"
if __name__ == "__main__": remote_server(getenv("PORT", 6667))

View File

@@ -1,173 +0,0 @@
from __future__ import annotations
import resource, ctypes, weakref, functools, itertools
from tinygrad.runtime.autogen import ib
from typing import Iterator
from dataclasses import dataclass
from weakref import WeakKeyDictionary
from tinygrad.device import Buffer, DMACPURef, DMAFdRef
from tinygrad.helpers import getenv, round_up, DEBUG
DEFAULT_PORT, DEFAULT_GID = getenv("DEFAULT_PORT", 1), getenv("DEFAULT_GID", 3) # DEFAULT_GID=0 for RXE
IOVA_ALIGN = resource.getpagesize()
def checkz(x, ret=None):
if x != 0: raise RuntimeError(f'{x} != 0 (errno {ctypes.get_errno()})')
return ret
@dataclass(frozen=True)
class SGE:
dst_iova: int
dst_key: int
src_iova: int
src_key: int
size: int
class IBCtx:
def __init__(self, idx:int):
# Open the device (aka Host Channel Adapter in ib-speak)
devs = ib.ibv_get_device_list(ctypes.byref(ndevs:=ctypes.c_int32()))
if idx >= ndevs.value: raise IndexError(f"{idx} > {ndevs.value}")
self.ctx = ib.ibv_open_device(devs[idx])
ib.ibv_free_device_list(devs)
# HACK: remove this (and all usage of `ctx.contents.ops`) when clang2py can deal with `static inline` wrapper-functions
self.vctx = ctypes.cast(ctypes.addressof(self.ctx.contents) - ib.struct_verbs_context.context.offset, ctypes.POINTER(ib.struct_verbs_context))
# Get attributes. Something like port_attr.max_msg_sz sounds like it might require taking the min of the host's and remote's attributes if they differ
self.device_attr = checkz(ib.ibv_query_device(self.ctx, ctypes.byref(da:=ib.struct_ibv_device_attr())), da)
self.port_attr = checkz(self.vctx.contents.query_port(self.ctx, DEFAULT_PORT, ctypes.byref(pa:=ib.struct_ibv_port_attr()), ctypes.sizeof(pa)), pa)
self.gid_attr = checkz(ib.ibv_query_gid(self.ctx, DEFAULT_PORT, DEFAULT_GID, ctypes.byref(ga:=ib.union_ibv_gid())), ga)
# Allocate protection domain
self.pd = ib.ibv_alloc_pd(self.ctx)
self.next_iova: int = IOVA_ALIGN # don't start at zero (nullptr)
# weakref(buf) => (iova, mr, mr_dealloc). mr_dealloc is kept here to avoid double freeing mrs that are deallocated in __del__
self.mrs: WeakKeyDictionary[Buffer, tuple[int, ctypes._Pointer[ib.struct_ibv_mr], weakref.finalize]] = WeakKeyDictionary()
# Default soft fd limit is 1024, which is not enough, set soft to hard (maximum allowed by the os)
IBCtx.rlimit_fix()
def __del__(self):
# must deallocate all mrs in protection domain before deallocating the protection domain
if hasattr(self, "mrs"): [fin() for _,_,fin in self.mrs.values()]
if hasattr(self, "pd"): ib.ibv_dealloc_pd(self.pd)
if hasattr(self, "ctx"): ib.ibv_close_device(self.ctx)
@functools.cache # run once
@staticmethod
def rlimit_fix():
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
if DEBUG>=2: print(f"IB: Increased fd limit from {soft} to {hard}")
def alloc_iova(self, size:int, required_offset:int):
iova = round_up(self.next_iova - required_offset, IOVA_ALIGN) + required_offset
self.next_iova = iova + size
return iova
def reg(self, buf:Buffer) -> tuple[int, ctypes._Pointer[ib.struct_ibv_mr]]:
buf = buf.base
if buf not in self.mrs:
if buf.nbytes > self.device_attr.max_mr_size: raise RuntimeError(f"Buffer too big: {buf.nbytes:#x} > {self.device_attr.max_mr_size:#x}")
if len(self.mrs) >= self.device_attr.max_mr: raise RuntimeError(f"Out of memory region cap: {len(self.mrs)} >= {self.device_attr.max_mr}")
# Local read is implied (but still have to create the memory region, except for short sends/writes with IBV_SEND_INLINE that are inlined by cpu)
mr_flags = ib.IBV_ACCESS_LOCAL_WRITE | ib.IBV_ACCESS_REMOTE_READ | ib.IBV_ACCESS_REMOTE_WRITE
match (dmaref:=buf.as_dmaref()):
case DMACPURef():
iova = self.alloc_iova(dmaref.size, dmaref.addr % IOVA_ALIGN)
mr = ib.ibv_reg_mr_iova2(self.pd, ctypes.c_void_p(dmaref.addr), dmaref.size, iova, mr_flags)
case DMAFdRef():
iova = self.alloc_iova(dmaref.size, dmaref.offset % IOVA_ALIGN)
mr = ib.ibv_reg_dmabuf_mr(self.pd, dmaref.offset, dmaref.size, iova, dmaref.fd, mr_flags)
case _: raise RuntimeError(f"Unknown type of dma ref: {dmaref}")
if not mr: raise RuntimeError(f"Couldn't register memory region for {buf} {dmaref} (errno={ctypes.get_errno()})")
self.mrs[buf] = (iova, mr, weakref.finalize(buf, ib.ibv_dereg_mr, mr))
return self.mrs[buf][0:2]
class IBConn:
def __init__(self, ctx:IBCtx):
self.ctx = ctx
# Create Completion Channel. It is a file descriptor that kernel sends notifications through, not a thing in infiniband spec, just linux-ism
self.comp_channel = ib.ibv_create_comp_channel(self.ctx.ctx)
# Create Completion Queue. When a Work Request with signaled flag is completed a Completion Queue Entry is pushed onto this queue
self.cq = ib.ibv_create_cq(self.ctx.ctx, _capacity:=256, _cq_context:=None, self.comp_channel, _comp_vector:=0)
self.pending_wrids: set[int] = set()
self.wrid_num: Iterator[int] = itertools.count(0) # wc_id is uint64, this will never overflow
# Create Queue Pair. It's the closest thing to a socket in infiniband, with the QP num being the closest thing to a port, except it's allocated by the HCA
qp_init_attrs_cap = ib.struct_ibv_qp_cap(max_send_wr=1024, max_recv_wr=64, max_send_sge=8, max_recv_sge=8, max_inline_data=64)
qp_init_attrs = ib.struct_ibv_qp_init_attr(send_cq=self.cq, recv_cq=self.cq, cap=qp_init_attrs_cap, qp_type=ib.IBV_QPT_RC) # Reliable Connection
self.qp = ib.ibv_create_qp(self.ctx.pd, ctypes.byref(qp_init_attrs))
self.qp_cap = qp_init_attrs.cap
# The most important thing about QPs is their state, when a new QP is created it's in the RESET state, before it can be properly used it has to go
# through Init, Ready To Receive, Ready To Send. A good docs on QP state machine: https://www.rdmamojo.com/2012/05/05/qp-state-machine/
# INIT
qp_access_flags = ib.IBV_ACCESS_REMOTE_WRITE | ib.IBV_ACCESS_REMOTE_READ
qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_INIT, port_num=DEFAULT_PORT, qp_access_flags=qp_access_flags)
checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PORT | ib.IBV_QP_ACCESS_FLAGS | ib.IBV_QP_PKEY_INDEX))
self.gid, self.qp_num = bytes(self.ctx.gid_attr.raw), self.qp.contents.qp_num
# Exchange GID and QP num with remote. At least in RoCEv2 gid can be guessed from remote's ip, QP num can't.
def connect(self, remote_gid:bytes, remote_qp_num:int):
# RTR
qp_ah_attr_grh = ib.struct_ibv_global_route(hop_limit=1, dgid=ib.union_ibv_gid(raw=(ctypes.c_ubyte * 16)(*remote_gid)), sgid_index=DEFAULT_GID)
qp_ah_attr = ib.struct_ibv_ah_attr(is_global=1, port_num=DEFAULT_PORT, grh=qp_ah_attr_grh)
qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTR, path_mtu=ib.IBV_MTU_4096, dest_qp_num=remote_qp_num, rq_psn=0, max_dest_rd_atomic=1,
min_rnr_timer=12, ah_attr=qp_ah_attr)
checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PATH_MTU | ib.IBV_QP_DEST_QPN | ib.IBV_QP_RQ_PSN | \
ib.IBV_QP_MAX_DEST_RD_ATOMIC | ib.IBV_QP_MIN_RNR_TIMER | ib.IBV_QP_AV))
# RTS
qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTS, timeout=14, retry_cnt=7, rnr_retry=7, sq_psn=0, max_rd_atomic=1)
checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_TIMEOUT | ib.IBV_QP_RETRY_CNT | ib.IBV_QP_RNR_RETRY | ib.IBV_QP_SQ_PSN | \
ib.IBV_QP_MAX_QP_RD_ATOMIC))
def __del__(self):
self.wait_cq() # need to wait for **everything** to complete before it's safe to dealloc queues and stuff
ib.ibv_destroy_qp(self.qp)
ib.ibv_destroy_cq(self.cq)
ib.ibv_destroy_comp_channel(self.comp_channel)
def next_wrid(self):
self.pending_wrids.add(wrid:=next(self.wrid_num))
return wrid
def wait_cq(self, wr_id: int|None=None):
while (wr_id in self.pending_wrids) if wr_id is not None else self.pending_wrids:
if self.ctx.ctx.contents.ops.poll_cq(self.cq, _num_entries:=1, ctypes.byref(wc:=ib.struct_ibv_wc())):
if wc.status != ib.IBV_WC_SUCCESS:
raise RuntimeError(f'Work Request completed with error: wr_id={wc.wr_id} status={ib.enum_ibv_wc_status.get(wc.status, wc.status)}')
self.pending_wrids.remove(wc.wr_id)
def rdma_write(self, sgl:list[SGE]):
swr: ctypes._Pointer[ib.struct_ibv_send_wr]|None = None
swr_cnt, wr_id = 0, self.next_wrid()
def _post():
nonlocal swr, swr_cnt, wr_id
if swr is not None:
# The swr can be freed when this returns, the memory that sge points to can be unmapped after work completion is retrieved from cq
checkz(self.ctx.ctx.contents.ops.post_send(self.qp, swr, ctypes.byref(_bad_wr:=ctypes.POINTER(ib.struct_ibv_send_wr)())))
# TODO: async
self.wait_cq(wr_id)
swr, swr_cnt, wr_id = None, 0, self.next_wrid()
# Everything is in reverse for elegant chaining
for sg in reversed(sgl):
# Message size limit (max 2GB per ib spec, 1GB on tinybox mellanoxes) applies to both scatter-gather entries and entire wrs
for off in reversed(range(0, sg.size, self.ctx.port_attr.max_msg_sz)):
# Scatter-Gather Entry for local memory
sge = ctypes.pointer(ib.struct_ibv_sge(addr=sg.src_iova+off, length=min(sg.size-off, self.ctx.port_attr.max_msg_sz), lkey=sg.src_key))
# RDMA struct for remote memory
wr = ib.struct_ibv_send_wr_wr(rdma=ib.struct_ibv_send_wr_wr_rdma(remote_addr=sg.dst_iova+off, rkey=sg.dst_key))
# Signal (with chosen work request id) if it's the last wr (first in the loop since it's reversed)
wid, flags = (wr_id, ib.IBV_SEND_SIGNALED) if swr is None else (0, 0)
# Create Send Request
swr = ctypes.pointer(ib.struct_ibv_send_wr(opcode=ib.IBV_WR_RDMA_WRITE, sg_list=sge, num_sge=1, wr=wr, wr_id=wid, send_flags=flags, next=swr))
# Flush if queue is being overrun
if (swr_cnt:=swr_cnt + 1) >= self.qp_cap.max_send_wr: _post()
_post()
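rdma_write above splits each scatter-gather entry into max_msg_sz-sized pieces and chains the work requests in reverse so only the final one is signaled. The chunking arithmetic in isolation (pure Python, no verbs calls):

# chunking arithmetic from rdma_write above, isolated; max_msg_sz is the per-message cap
def chunks(size: int, max_msg_sz: int) -> list[tuple[int, int]]:
    # (offset, length) pieces covering [0, size), each at most max_msg_sz long
    return [(off, min(size - off, max_msg_sz)) for off in range(0, size, max_msg_sz)]

assert chunks(5, 2) == [(0, 2), (2, 2), (4, 1)]
# e.g. a 2 GiB + 1 byte region against the 1 GiB cap mentioned above
assert chunks(2**31 + 1, 2**30) == [(0, 2**30), (2**30, 2**30), (2**31, 1)]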