mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00
.github/workflows/benchmark.yml (vendored, 3 changes)
@@ -554,8 +554,6 @@ jobs:
       run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
     - name: Run full CIFAR training steps w 6 GPUS
       run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
-    - name: Run full CIFAR training steps w 6 GPUS (REMOTE)
-      run: time BENCHMARK_LOG=cifar_6gpu_remote REMOTE=1 REMOTEDEV=AMD DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu_remote.txt
     - uses: actions/upload-artifact@v4
       with:
         name: Speed (AMD Training)
@@ -567,7 +565,6 @@ jobs:
           train_cifar_wino.txt
           train_cifar_one_gpu.txt
           train_cifar_six_gpu.txt
-          train_cifar_six_gpu_remote.txt
     - name: Run process replay tests
       run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
.github/workflows/test.yml (vendored, 89 changes)
@@ -721,71 +721,6 @@ jobs:
     - name: Run process replay tests
       uses: ./.github/actions/process-replay

-  amdremote:
-    name: Linux (remote)
-    runs-on: ubuntu-22.04
-    timeout-minutes: 20
-    env:
-      REMOTE: 1
-    steps:
-    - name: Checkout Code
-      uses: actions/checkout@v4
-    - name: Setup Environment
-      uses: ./.github/actions/setup-tinygrad
-      with:
-        key: linux-remote
-        deps: testing_minimal
-        amd: 'true'
-        llvm: 'true'
-        opencl: 'true'
-    - name: Start remote server
-      run: |
-        start_server() {
-          systemd-run --user \
-            --unit="$1" \
-            --setenv=REMOTEDEV="$2" \
-            --setenv=MOCKGPU=1 \
-            --setenv=PYTHONPATH=. \
-            --setenv=PORT="$3" \
-            --working-directory="$(pwd)" \
-            python tinygrad/runtime/ops_remote.py
-        }
-
-        start_server "remote-server-amd-1" "AMD" 6667
-        start_server "remote-server-amd-2" "AMD" 6668
-        start_server "remote-server-gpu" "CL" 7667
-        start_server "remote-server-cpu" "CPU" 8667
-    - name: Check Device.DEFAULT and print some source
-      env:
-        HOST: 127.0.0.1:6667*6,127.0.0.1:6668*6
-      run: |
-        python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
-        python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'AMD', Device.default.properties.real_device"
-        DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
-    - name: Run REMOTE=1 Test (AMD)
-      env:
-        HOST: 127.0.0.1:6667*6,127.0.0.1:6668*6
-      run: |
-        python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_remote.py test/test_tensor_variable.py --durations 20
-    - name: Run REMOTE=1 Test (CL)
-      env:
-        HOST: 127.0.0.1:7667*6
-      run: |
-        python3 -m pytest test/test_tiny.py test/test_image_dtype.py test/test_jit.py --durations 20
-        IMAGE=2 python3 -m pytest test/test_tiny.py test/test_image_dtype.py
-    - name: Run REMOTE=1 Test (CPU)
-      env:
-        HOST: 127.0.0.1:8667*6
-      run: |
-        python3 -m pytest test/test_tiny.py test/test_jit.py test/test_multitensor.py --durations 20
-    - name: Show remote server logs
-      if: always()
-      run: |
-        journalctl --user -u remote-server-amd-1 --no-pager
-        journalctl --user -u remote-server-amd-2 --no-pager
-        journalctl --user -u remote-server-gpu --no-pager
-        journalctl --user -u remote-server-cpu --no-pager
-
   # ****** OSX Tests ******

   testmetal:
@@ -883,30 +818,6 @@ jobs:
     - name: Test ONNX Runner (WEBGPU)
       run: WEBGPU=1 python3 test/external/external_test_onnx_runner.py

-  osxremote:
-    name: MacOS (remote metal)
-    runs-on: macos-15
-    timeout-minutes: 10
-    env:
-      REMOTE: 1
-      REMOTEDEV: METAL
-    steps:
-    - name: Checkout Code
-      uses: actions/checkout@v4
-    - name: Setup Environment
-      uses: ./.github/actions/setup-tinygrad
-      with:
-        key: macos-remote
-        deps: testing_minimal
-    - name: Check Device.DEFAULT and print some source
-      run: |
-        python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
-        python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
-        DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
-    - name: Run REMOTE=1 Test
-      run: |
-        python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py

   osxtests:
     strategy:
       fail-fast: false
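Note: the HOST strings above (e.g. 127.0.0.1:6667*6) use the host*count syntax handled by parse_hosts in the ops_remote.py deleted below. A minimal sketch of its semantics as they worked before this removal:

from tinygrad.runtime.ops_remote import parse_hosts
# "host*count" entries expand to (host, idx) pairs, concatenated in order
assert parse_hosts("127.0.0.1:6667*2,127.0.0.1:6668*1") == \
  [("127.0.0.1:6667", 0), ("127.0.0.1:6667", 1), ("127.0.0.1:6668", 0)]
# a single bare host (no count) gives a LazySeq indexable at any idx
print(parse_hosts("127.0.0.1:6667")[5])  # -> ("127.0.0.1:6667", 5)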
test/external/external_test_example.py (vendored, 2 changes)
@@ -8,7 +8,7 @@ def multidevice_test(fxn):
   def ret(self):
     for device in Device._devices:
       # broken on OSX USB AMD, why?
-      if device in ["REMOTE", "DISK", "NPY", "FAKE", "DSP", "NULL"] or (OSX and device in ["AMD"]): continue
+      if device in ["DISK", "NPY", "FAKE", "DSP", "NULL"] or (OSX and device in ["AMD"]): continue
       if not CI: print(device)
       if device in exclude_devices:
         if not CI: print(f"WARNING: {device} test is excluded")

@@ -69,5 +69,4 @@ def needs_second_gpu(fn):
     return fn(self, *args, **kwargs)
   return wrapper

-# NOTE: This will open REMOTE if it's the default device
-REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device)
+REAL_DEV = Device.DEFAULT
test/test_remote.py (deleted)
@@ -1,101 +0,0 @@
import numpy as np, unittest, string
from hypothesis import given, strategies as st
from tinygrad import Device, Tensor, TinyJit, dtypes
from tinygrad.runtime.ops_remote import RemoteDevice, parse_hosts
from tinygrad.runtime.graph.remote import RemoteGraph
from tinygrad.helpers import LazySeq, all_same, Context

def multihost_env(devices):
  def same_hosts(devices): return all_same([h for h,_ in devices])
  return isinstance(devices, list) and len(devices) >= 12 and not same_hosts(devices[0:12]) and same_hosts(devices[0:6]) and same_hosts(devices[6:12])

@unittest.skipUnless(Device.DEFAULT == "REMOTE" and multihost_env(RemoteDevice.devices), "Requires special environment")
class TestRemoteMultiHost(unittest.TestCase):
  def test_multihost_transfer(self):
    a = Tensor.arange(0, 16, device='REMOTE:0').contiguous().realize()
    b = a.to('REMOTE:6').contiguous().realize()
    np.testing.assert_equal(b.numpy(), np.arange(0, 16))

  @Context(JIT_BATCH_SIZE=2**32)
  @unittest.skip("kernel must all be multibuffer")
  def test_multihost_matmul_jit_graph(self):
    @TinyJit
    def do(a:Tensor, b:Tensor): return (a @ b).contiguous().realize()

    ds = ('REMOTE:0', 'REMOTE:1', 'REMOTE:6', 'REMOTE:7')
    for _ in range(3):
      na, nb = np.random.rand(128, 128).astype(np.float32), np.random.rand(128, 128).astype(np.float32)
      a, b = Tensor(na).shard(ds, 0).contiguous().realize(), Tensor(nb).shard(ds, 0).contiguous().realize()
      nc = na @ nb
      c = do(a, b)
      np.testing.assert_allclose(nc, c.numpy(), rtol=3e-2, atol=1e-4) # tolerances from extra/gemm/simple_matmul.py

    # Verify that everything is in one big cross-host graph
    assert len(do.captured._jit_cache) == 1 and isinstance(do.captured._jit_cache[0].prg, RemoteGraph), repr(do.captured)

  @Context(JIT_BATCH_SIZE=2**32)
  @unittest.skip("assign target and input devices mismatch")
  def test_multihost_aware_schedule(self):
    @TinyJit
    def do(*ts:Tensor):
      acc = Tensor.zeros(1, dtype=dtypes.float32).contiguous().realize()
      for t in ts: acc += t.sum()
      return acc.realize()

    def do_np(*ts:np.ndarray):
      acc = np.zeros(1, np.float32)
      for t in ts: acc += t.sum()
      return acc

    ds = ('REMOTE:0', 'REMOTE:1', 'REMOTE:6', 'REMOTE:7')
    TS = 64
    for _ in range(3):
      inp_np = [np.random.rand(256).astype(np.float32) for _ in range(TS)]
      inp = [Tensor(inp).shard(ds, 0).contiguous().realize() for inp in inp_np]
      out_np = do_np(*inp_np)
      out = do(*inp)
      np.testing.assert_allclose(out_np, out.numpy(), rtol=3e-2, atol=1e-4)

    # Verify that everything is in one big cross-host graph and that the scheduling is reasonable
    assert len(do.captured._jit_cache) == 1 and isinstance(do.captured._jit_cache[0].prg, RemoteGraph), repr(do.captured)
    # At the time of writing: 2050 graph breaks without multihost-aware scheduling, 14 with it. The fail threshold is set to 28 so that unrelated
    # scheduling changes don't trip it. Maybe 2x is a bit pessimistic, but remote should perform just fine as long as this stays well under
    # fifty or so.
    self.assertLess(len(do.captured._jit_cache[0].prg.template), 28, "Very bad scheduling! Many unnecessary graph breaks!")

class TestParseHosts(unittest.TestCase):
  def assert_seq(self, result:LazySeq, host:str):
    self.assertIsInstance(result, LazySeq)
    for i in [0, 1, 5, 10]: self.assertEqual(result[i], (host, i))

  @given(st.sampled_from(["", "localhost", "192.168.1.1:8080", "host"]))
  def test_single_host_no_count(self, host:str):
    self.assert_seq(parse_hosts(host), host)

  @given(host=st.sampled_from(["localhost", "host", "192.168.1.1:8080"]), count=st.integers(0, 10))
  def test_single_host_with_count(self, host:str, count:int):
    self.assertEqual(parse_hosts(f"{host}*{count}"), [(host, i) for i in range(count)])

  def test_multiple_hosts_with_counts_simple(self):
    self.assertEqual(parse_hosts("host1*2,host2*3"), [("host1", i) for i in range(2)] + [("host2", i) for i in range(3)])

  @given(st.lists(st.tuples(st.text(alphabet=string.ascii_letters + string.digits + ".-:"), st.integers(1, 16)), min_size=1))
  def test_multiple_hosts_with_counts_sampled(self, host_count_pairs):
    hosts_str = ",".join(f"{host}*{count}" for host, count in host_count_pairs)
    expected = [(host, i) for host, count in host_count_pairs for i in range(count)]
    self.assertEqual(parse_hosts(hosts_str), expected)

  @given(st.sampled_from(["host1*2,host2", "a*1,b", "x*3,y*2,z"]))
  def test_mixed_hosts_fails(self, hosts):
    with self.assertRaises(AssertionError): parse_hosts(hosts)

  @given(st.sampled_from(["host*abc", "test*xyz", "a*1.5"]))
  def test_invalid_count_fails(self, hosts):
    with self.assertRaises(ValueError): parse_hosts(hosts)

  @given(st.sampled_from(["host*2*3", "a*1*2*3", "test*x*y"]))
  def test_multiple_asterisks_fails(self, hosts):
    with self.assertRaises(ValueError): parse_hosts(hosts)

if __name__ == '__main__':
  unittest.main()
tinygrad/runtime/graph/remote.py (deleted)
@@ -1,113 +0,0 @@
import time, itertools
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.runtime.ops_remote import RemoteDevice, RemoteConnection, RemoteRequest, GraphComputeItem, Transfer, GraphAlloc, GraphFree, GraphExec
from tinygrad.runtime.ops_remote import BatchTransfer, Event, Wait
from tinygrad.helpers import unwrap, flatten, dedup
from enum import Enum, auto
from dataclasses import replace
from collections import defaultdict
from typing import cast

class StagingType(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto() # noqa: E702

def rd(dev:Compiled) -> RemoteDevice: return cast(RemoteDevice, dev)
def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_multi else dev
def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)

class RemoteGraph(MultiGraphRunner):
  def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]):
    super().__init__(jit_cache, rawbufs, var_vals)
    devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
    c2d = {device.conn: device for device in devices}
    self.handle_indexes = {map_rawbuf(rawbufs[i]): i for i in sorted(dedup(self.input_replace.values()))}

    self.template: list[RemoteRequest] = []

    stagings: dict[RemoteDevice|RemoteConnection, list[GraphComputeItem|Transfer]] = defaultdict(list)
    clobbered_buffers: set[Buffer] = set()
    cur_staging_type: StagingType = StagingType.NONE

    def _flush(new_staging_type:StagingType, force_break:bool=False):
      nonlocal cur_staging_type
      if cur_staging_type == new_staging_type and not force_break: return
      # Pre-sync
      if cur_staging_type == StagingType.TRANSFER:
        for sdev,ddev in itertools.permutations(c2d.values(), 2):
          self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
          self.template.append(Wait(event, session=ddev.session))
      # Flush
      for dev in devices:
        dk = dev_key(dev)
        staging = stagings[dk]
        if not staging: continue
        match cur_staging_type:
          case StagingType.GRAPH:
            bufs = tuple(map_rawbuf(rawbufs[i]) for i in sorted(dedup(self.input_replace.values())) if dev_key(rd(Device[rawbufs[i].device])) == dk)
            dev.q(GraphAlloc(graph_num:=next(dev.graph_num), tuple(staging), tuple(bufs), var_vals))
            self.template.append(GraphExec(graph_num, bufs, var_vals, wait=False, session=dev.session))
          case StagingType.TRANSFER:
            st = cast(list[Transfer], staging)
            for host in dedup(t.dsession.host for t in st):
              sbuffer_nums = [(unwrap(t.session), t.buffer_num) for t in st if t.dsession.host == host]
              dbuffer_nums = [(t.dsession, t.dbuffer_num) for t in st if t.dsession.host == host]
              self.template.append(BatchTransfer(sbuffer_nums, dbuffer_nums, session=dev.session))
        staging.clear()
      # Post-sync
      if cur_staging_type == StagingType.TRANSFER:
        for sdev,ddev in itertools.permutations(c2d.values(), 2):
          self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
          self.template.append(Wait(event, session=ddev.session))
      cur_staging_type = new_staging_type
      clobbered_buffers.clear()

    for ji in jit_cache:
      match ji.prg:
        case CompiledRunner():
          _flush(StagingType.GRAPH)
          gi = GraphComputeItem(ji.prg.dev.session, ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs),
                                tuple(ji.prg.p.vars), ji.fixedvars, tuple(ji.prg.p.ins), tuple(ji.prg.p.outs),
                                tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
                                tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
          stagings[dev_key(ji.prg.dev)].append(gi)
        case BufferXfer():
          dest, src = ji.bufs[0:2]
          dest_dev, src_dev = cast(RemoteDevice, Device[unwrap(dest).device]), cast(RemoteDevice, Device[unwrap(src).device])
          assert dest is not None and src is not None, ji
          ti = Transfer(session=src_dev.session, buffer_num=src._buf, dsession=dest_dev.session, dbuffer_num=dest._buf)
          if dev_key(dest_dev) == dev_key(src_dev):
            _flush(StagingType.GRAPH)
            stagings[dev_key(src_dev)].append(ti)
          elif dest_dev.conn == src_dev.conn:
            _flush(StagingType.NONE)
            self.template.append(ti)
          else:
            _flush(StagingType.TRANSFER, force_break=src in clobbered_buffers)
            clobbered_buffers.add(dest)
            stagings[dev_key(src_dev)].append(ti)
        case _: raise NotImplementedError(ji.prg)
    _flush(StagingType.NONE)

  def __del__(self):
    for req in self.template:
      match req:
        case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))

  def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False):
    if wait: st = time.perf_counter()
    rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
    for req in self.template:
      match req:
        case GraphExec():
          req = replace(req, bufs=tuple(rmap[buf] for buf in req.bufs), var_vals=var_vals, wait=wait)
        case Transfer():
          if (req.session, req.buffer_num) in rmap: req = replace(req, buffer_num=rmap[(req.session, req.buffer_num)][1])
          if (req.dsession, req.dbuffer_num) in rmap: req = replace(req, dbuffer_num=rmap[(req.dsession, req.dbuffer_num)][1])
        case BatchTransfer():
          req = replace(req, sbuffer_nums=[rmap.get(b, b) for b in req.sbuffer_nums], dbuffer_nums=[rmap.get(b, b) for b in req.dbuffer_nums])
        case Event()|Wait():
          pass # event number can be reused
        case _: raise NotImplementedError(req)
      RemoteConnection(unwrap(req.session).host).q(req)
    if wait:
      RemoteConnection(unwrap(req.session).host).batch_submit()
      return time.perf_counter() - st
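For intuition on the _flush fences above: around every batch of cross-host transfers, the template emits an Event/Wait pair for each ordered pair of connections, so all hosts quiesce before and after the RDMA batch. An illustrative sketch (not literal output) of the template contents for two hosts A and B:

# pre-sync: every host signals every other, then each waits
#   Event(B.session, e0, session=A.session); Wait(e0, session=B.session)
#   Event(A.session, e1, session=B.session); Wait(e1, session=A.session)
# the staged cross-host copies, grouped by destination host
#   BatchTransfer(sbuffer_nums=[(A.session, 3)], dbuffer_nums=[(B.session, 7)], session=A.session)
# post-sync: the same fence again before execution continues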
tinygrad/runtime/ops_remote.py (deleted)
@@ -1,491 +0,0 @@
# the REMOTE=1 device is a process boundary between the frontend/runtime
# normally tinygrad is frontend <-> middleware <-> runtime <-> hardware
# with REMOTE tinygrad is frontend <-> middleware <-> RemoteDevice ///HTTP/// remote_server <-> runtime <-> hardware
# this client and server can be on the same machine, the same network, or just the same internet
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
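# A minimal round-trip sketch (mirrors the CI jobs removed above; assumes port 6667 is free):
#   server:  PORT=6667 REMOTEDEV=CPU python tinygrad/runtime/ops_remote.py
#   client:  REMOTE=1 HOST=127.0.0.1:6667 python -c "from tinygrad import Tensor; print((Tensor.ones(4)+1).numpy())"
# if HOST is unset, RemoteDevice.local_server() below forks a local server on 6667 automatically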

from __future__ import annotations
from typing import Callable, Iterator, Any, cast
from collections import defaultdict
from dataclasses import dataclass, field, replace
import multiprocessing, threading, functools, itertools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref
import traceback, builtins
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.dtype import DTYPES_DICT, dtypes
from tinygrad.uop.ops import UOp, Ops, Variable, sint
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, LazySeq, Timing
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class
from tinygrad.engine.realize import CompiledRunner, BufferXfer
from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec, CompilerSet, CompilerPair
from tinygrad.runtime.support.ib import IBCtx, IBConn, SGE

# ***** API *****

@dataclass(frozen=True)
class SessionKey: host: str; idx: int; nonce: str # noqa: E702

@dataclass(frozen=True)
class RemoteRequest: session: SessionKey|None = field(default=None, kw_only=True)

@dataclass(frozen=True)
class SessionFree(RemoteRequest): pass

@dataclass(frozen=True)
class RemoteProperties:
  real_device: str
  renderer: tuple[str, str, tuple[Any, ...]]
  offset_supported: bool
  graph_supported: bool
  graph_supports_multi: bool
  ib_gid: bytes|None

@dataclass(frozen=True)
class RemoteException:
  exc: Exception
  trace: str = ""

@dataclass(frozen=True)
class GetProperties(RemoteRequest): pass

@dataclass(frozen=True)
class Event(RemoteRequest): event_session: SessionKey; event: int # noqa: E702

@dataclass(frozen=True)
class Wait(RemoteRequest): event: int

@dataclass(frozen=True)
class IBConnect(RemoteRequest): host: str; gid: bytes; qp_num: int # noqa: E702

@dataclass(frozen=True)
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702

@dataclass(frozen=True)
class BufferOffset(RemoteRequest): buffer_num: int; size: int; offset: int; sbuffer_num: int # noqa: E702

@dataclass(frozen=True)
class BufferIOVAS(RemoteRequest): buffer_nums: list[tuple[SessionKey, int]] # noqa: E702

@dataclass(frozen=True)
class BufferFree(RemoteRequest): buffer_num: int # noqa: E702

@dataclass(frozen=True)
class CopyIn(RemoteRequest): buffer_num: int; datahash: str # noqa: E702

@dataclass(frozen=True)
class CopyOut(RemoteRequest): buffer_num: int

@dataclass(frozen=True)
class Transfer(RemoteRequest): buffer_num: int; dsession: SessionKey; dbuffer_num: int # noqa: E702

@dataclass(frozen=True)
class BatchTransfer(RemoteRequest):
  sbuffer_nums: list[tuple[SessionKey, int]]
  dbuffer_nums: list[tuple[SessionKey, int]]

@dataclass(frozen=True)
class ProgramAlloc(RemoteRequest): name: str; datahash: str # noqa: E702

@dataclass(frozen=True)
class ProgramFree(RemoteRequest): name: str; datahash: str # noqa: E702

@dataclass(frozen=True)
class ProgramExec(RemoteRequest):
  name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702
  global_size: tuple[int, ...]|None; local_size: tuple[int, ...]|None; wait: bool # noqa: E702

@dataclass(frozen=True)
class GraphComputeItem:
  session: SessionKey
  name: str
  datahash: str
  bufs: tuple[int, ...]
  vars: tuple[Variable, ...]
  fixedvars: dict[str, int]
  ins: tuple[int, ...]
  outs: tuple[int, ...]
  global_size: tuple[sint, ...]|None
  local_size: tuple[sint, ...]|None

@dataclass(frozen=True)
class GraphAlloc(RemoteRequest):
  graph_num: int
  jit_cache: tuple[GraphComputeItem|Transfer, ...]
  bufs: tuple[tuple[SessionKey, int], ...]
  var_vals: dict[str, int]

@dataclass(frozen=True)
class GraphFree(RemoteRequest):
  graph_num: int

@dataclass(frozen=True)
class GraphExec(RemoteRequest):
  graph_num: int
  bufs: tuple[tuple[SessionKey, int], ...]
  var_vals: dict[str, int]
  wait: bool

# for safe deserialization
eval_excs = [v for k,v in builtins.__dict__.items() if isinstance(v, type) and issubclass(v, Exception) and not k.endswith("Warning")]
eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, Event, Wait, BufferAlloc, BufferOffset, BufferIOVAS,
                                       BufferFree, CopyIn, CopyOut, Transfer, BatchTransfer, IBConnect, ProgramAlloc, ProgramFree, ProgramExec,
                                       GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes, RemoteException] + eval_excs}
attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
  ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
  ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
  ast.Name: lambda x: eval_globals[x.id], ast.Attribute: lambda x: safe_getattr(safe_eval(x.value), x.attr)}
def safe_getattr(value, attr):
  assert attr in attribute_whitelist.get(value, set()), f'getattr({value}, {repr(attr)}) is not whitelisted'
  return getattr(value, attr)
def safe_eval(node): return eval_fxns[node.__class__](node)
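# e.g. (sketch): requests survive a repr()/safe_eval round-trip, since the frozen dataclasses
# repr to constructor calls that resolve only through eval_globals:
#   safe_eval(ast.parse(repr([BufferFree(3)]), mode="eval").body) == [BufferFree(3)]
# anything else (non-whitelisted names or attributes) raises KeyError/AssertionError instead of executing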
class BatchRequest:
  def __init__(self):
    self._q: list[RemoteRequest] = []
    self._h: dict[str, bytes] = {}
  def h(self, d:bytes|memoryview) -> str:
    datahash = hashlib.sha256(d).hexdigest() # NOTE: this is very slow, should use blake3 on gpu instead
    if datahash not in self._h:
      self._h[datahash] = bytes.fromhex(datahash)+struct.pack("<Q", len(d))+bytes(d)
    return datahash
  def q(self, x:RemoteRequest): self._q.append(x)
  def serialize(self) -> bytes:
    self.h(repr(self._q).encode())
    return b''.join(self._h.values())
  def deserialize(self, dat:bytes) -> BatchRequest:
    ptr = 0
    while ptr < len(dat):
      datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
      self._h[datahash] = dat[ptr+0x28:ptr+0x28+datalen]
      ptr += 0x28+datalen
    self._q = safe_eval(ast.parse(self._h[datahash], mode="eval").body)
    return self

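# Wire format sketch: serialize() emits, per unique blob, the raw sha256 digest (0x20 bytes) +
# a little-endian u64 length (8 bytes) + the payload, all concatenated; the repr() of the request
# list is hashed last, so the final datahash seen by deserialize() is always the command list:
#   frame = hashlib.sha256(data).digest() + struct.pack("<Q", len(data)) + data
#   assert len(frame) == 0x28 + len(data)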
# ***** backend *****

@dataclass
class RemoteSession:
  programs: dict[tuple[str, str], Any] = field(default_factory=dict)
  graphs: dict[int, GraphRunner] = field(default_factory=dict)
  buffers: dict[int, Buffer] = field(default_factory=dict)
  events: defaultdict[int, asyncio.Event] = field(default_factory=functools.partial(defaultdict, asyncio.Event))

class RemoteHandler:
  def __init__(self, base_device: str):
    self.base_device = base_device
    self.sessions: defaultdict[SessionKey, RemoteSession] = defaultdict(RemoteSession)

    try: self.ib_ctx: IBCtx|None = IBCtx(getenv("IB_DEV", 0))
    except (RuntimeError, IndexError, AttributeError): self.ib_ctx = None
    self.ib_lock = asyncio.Lock()
    self.ib_conns: dict[str, IBConn|None] = {}
    self.iova_cache: dict[tuple[SessionKey, int], tuple[int, int, int]] = {}

  async def __call__(self, reader:asyncio.StreamReader, writer:asyncio.StreamWriter):
    while (req_hdr:=(await reader.readline()).decode().strip()):
      req_method, req_path, _ = req_hdr.split(' ')
      req_headers = {}
      while (hdr:=(await reader.readline()).decode().strip()):
        key, value = hdr.split(':', 1)
        req_headers[key.lower()] = value.strip()
      req_body = await reader.readexactly(int(req_headers.get("content-length", "0")))
      try: res_status, res_body = await self.handle(req_method, req_path, req_body)
      except Exception as e:
        res_status, res_body = http.HTTPStatus.INTERNAL_SERVER_ERROR, repr(RemoteException(e, traceback.format_exc())).encode()
        print(f"{traceback.format_exc()}", flush=True)
      writer.write(f"HTTP/1.1 {res_status.value} {res_status.phrase}\r\nContent-Length: {len(res_body)}\r\n\r\n".encode() + res_body)

  async def ib_connect(self, ssession:SessionKey, dsession:SessionKey) -> IBConn|None:
    if self.ib_ctx is None: return None
    await self.ib_lock.acquire()
    conn = RemoteConnection(dsession.host)
    if dsession.host not in self.ib_conns:
      props = safe_eval(ast.parse(conn.q(GetProperties(session=dsession), wait=True), mode="eval").body)
      if props.ib_gid is not None:
        self.ib_conns[dsession.host] = ib_conn = IBConn(self.ib_ctx)
        ibxc_ret = conn.q(IBConnect(ssession.host, ib_conn.gid, ib_conn.qp_num, session=dsession), wait=True)
        ib_conn.connect(*struct.unpack('<16sQ', ibxc_ret))
      else:
        self.ib_conns[dsession.host] = None
    self.ib_lock.release()
    return self.ib_conns[dsession.host]

  async def get_iovas(self, bufs:list[tuple[SessionKey, int]]) -> list[tuple[int, int, int]]:
    await self.ib_lock.acquire()
    if (rbufs:=[buf for buf in bufs if buf not in self.iova_cache]):
      conn = RemoteConnection(rbufs[0][0].host)
      resp = await conn.aq(BufferIOVAS(rbufs, session=rbufs[0][0]), wait=True)
      self.iova_cache.update({rbuf: struct.unpack('<QQQ', resp[i*24:(i+1)*24]) for i,rbuf in enumerate(rbufs)})
    self.ib_lock.release()
    return [self.iova_cache[buf] for buf in bufs]

  async def handle(self, method:str, path:str, body:bytes) -> tuple[http.HTTPStatus, bytes]:
    status, ret = http.HTTPStatus.OK, b""
    if path == "/batch" and method == "POST":
      # TODO: streaming deserialize?
      req = BatchRequest().deserialize(body)
      # the cmds are always the last blob (currently keyed by its datahash)
      for c in req._q:
        if DEBUG >= 1: print(c)
        session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session).idx}"]
        match c:
          case SessionFree(): del self.sessions[unwrap(c.session)]
          case GetProperties():
            cls, args = dev.renderer.__reduce__()
            graph_cls = graph_class(Device[self.base_device])
            rp = RemoteProperties(
              real_device=dev.device, renderer=(cls.__module__, cls.__name__, args), offset_supported=hasattr(dev.allocator, '_offset'),
              graph_supported=graph_cls is not None,
              graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner) and hasattr(dev.allocator, '_transfer'),
              ib_gid=bytes(self.ib_ctx.gid_attr.raw) if self.ib_ctx is not None else None,
            )
            ret = repr(rp).encode()
          case Event():
            if c.session == c.event_session:
              session.events[c.event].set()
            else:
              for d in Device._opened_devices: Device[d].synchronize() # wait for device*s* to finish executing previous stuff
              # TODO: don't wait, just send
              await RemoteConnection(c.event_session.host).aq(Event(c.event_session, c.event, session=c.event_session), wait=True)
          case Wait():
            assert await session.events[c.event].wait()
            del session.events[c.event] # do not leak memory
          case IBConnect():
            self.ib_conns[c.host] = ibc = IBConn(unwrap(self.ib_ctx))
            ibc.connect(c.gid, c.qp_num)
            ret = struct.pack('<16sQ', ibc.gid, ibc.qp_num)
          case BufferAlloc():
            assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
            session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
          case BufferIOVAS():
            rets = []
            for buffer_session,buffer_num in c.buffer_nums:
              iova, mr = unwrap(self.ib_ctx).reg(buf:=self.sessions[buffer_session].buffers[buffer_num])
              rets.append(struct.pack("<QQQ", iova, mr.contents.rkey, buf.nbytes))
            ret = b"".join(rets)
          case BufferOffset():
            assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already exists"
            session.buffers[c.buffer_num] = session.buffers[c.sbuffer_num].view(c.size, dtypes.uint8, c.offset).allocate()
          case BufferFree(): del session.buffers[c.buffer_num]
          case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
          case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
          case Transfer():
            if c.dsession.host == unwrap(c.session).host:
              dsession, ddev = self.sessions[c.dsession], Device[f"{self.base_device}:{unwrap(c.dsession).idx}"]
              dbuf, sbuf = dsession.buffers[c.dbuffer_num], session.buffers[c.buffer_num]
              if hasattr(ddev.allocator, '_transfer'):
                assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
                ddev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=ddev, src_dev=dev)
              else:
                sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
                dbuf.copyin(data)
            else:
              conn, ib_conn = RemoteConnection(c.dsession.host), await self.ib_connect(unwrap(c.session), c.dsession)
              sbuf = session.buffers[c.buffer_num]
              if ib_conn is not None:
                src_iova, src_mr = unwrap(self.ib_ctx).reg(sbuf)
                dst_iova, dst_key, dst_size = (await self.get_iovas([(c.dsession, c.dbuffer_num)]))[0]
                assert sbuf.nbytes == dst_size, f"{sbuf.nbytes} != {dst_size}"
                for d in Device._opened_devices: Device[d].synchronize()
                ib_conn.rdma_write([SGE(dst_iova, dst_key, src_iova, src_mr.contents.lkey, dst_size)])
              else:
                sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
                await conn.aq(CopyIn(c.dbuffer_num, conn.req.h(data), session=c.dsession), wait=True)
          case BatchTransfer():
            conn, ib_conn = RemoteConnection(c.dbuffer_nums[0][0].host), await self.ib_connect(c.sbuffer_nums[0][0], c.dbuffer_nums[0][0])
            if ib_conn is not None:
              sbufs = [unwrap(self.ib_ctx).reg(self.sessions[s].buffers[bi]) for s,bi in c.sbuffer_nums]
              dbufs = await self.get_iovas(c.dbuffer_nums)
              for d in Device._opened_devices: Device[d].synchronize()
              ib_conn.rdma_write([SGE(di, dk, si, sm.contents.lkey, ds) for (di,dk,ds),(si,sm) in zip(dbufs, sbufs)])
            else:
              for (sbuf_session,sbuf_num),(dbuf_session,dbuf_num) in zip(c.sbuffer_nums, c.dbuffer_nums):
                sbuf = self.sessions[sbuf_session].buffers[sbuf_num]
                sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
                await conn.aq(CopyIn(dbuf_num, conn.req.h(data), session=dbuf_session), wait=True)
          case ProgramAlloc():
            lib = dev.compiler.compile_cached(req._h[c.datahash].decode())
            session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib)
          case ProgramFree():
            key = (c.name, c.datahash)
            # WORKAROUND: should be unconditional once the protocol supports proper exception handling
            if key in session.programs: del session.programs[key]
          case ProgramExec():
            bufs = [session.buffers[x]._buf for x in c.bufs]
            extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
            r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
            if r is not None: ret = str(r).encode()
          case GraphAlloc():
            graph_fn: Callable = unwrap(dev.graph)
            def _parse_ji(gi: GraphComputeItem|Transfer):
              match gi:
                case GraphComputeItem():
                  prg = self.sessions[gi.session].programs[(gi.name, gi.datahash)]
                  ps = ProgramSpec(gi.name, '', f"{self.base_device}:{gi.session.idx}", UOp(Ops.NOOP),
                                   vars=list(gi.vars), ins=list(gi.ins), outs=list(gi.outs),
                                   global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None,
                                   local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None)
                  return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [self.sessions[gi.session].buffers[buf] for buf in gi.bufs],
                                  fixedvars=gi.fixedvars)
                case Transfer():
                  dbuf, sbuf = self.sessions[gi.dsession].buffers[gi.dbuffer_num], self.sessions[unwrap(gi.session)].buffers[gi.buffer_num]
                  assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
                  return ExecItem(BufferXfer(dbuf.nbytes, dbuf.device, sbuf.device), [dbuf, sbuf])
            assert c.graph_num not in session.graphs, f"graph {c.graph_num} already allocated"
            session.graphs[c.graph_num] = graph_fn(list(map(_parse_ji, c.jit_cache)), [self.sessions[s].buffers[i] for s,i in c.bufs], c.var_vals)
          case GraphFree(): del session.graphs[c.graph_num]
          case GraphExec():
            r = session.graphs[c.graph_num]([self.sessions[s].buffers[i] for s,i in c.bufs], c.var_vals, wait=c.wait)
            if r is not None: ret = str(r).encode()
    else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found"
    return status, ret

def remote_server(port:int):
  device = getenv("REMOTEDEV", next(Device.get_available_devices()) if Device.DEFAULT == "REMOTE" else Device.DEFAULT)
  async def _inner_async(port:int, device:str):
    print(f"start remote server on {port} with device {device}")
    await (await asyncio.start_server(RemoteHandler(device), host='', port=port)).serve_forever()
  asyncio.run(_inner_async(port, device))

# ***** frontend *****

class RemoteAllocator(Allocator['RemoteDevice']):
  def __init__(self, dev:RemoteDevice):
    if dev.properties.offset_supported: self._offset = self._dyn_offset
    super().__init__(dev)
  # TODO: ideally we shouldn't have to deal with images here
  def _alloc(self, size:int, options:BufferSpec) -> int:
    self.dev.q(BufferAlloc(buffer_num:=next(self.dev.buffer_num), size, options))
    return buffer_num
  # TODO: options should not be here in any Allocator
  def _free(self, opaque:int, options):
    try: self.dev.q(BufferFree(opaque))
    except (TypeError, AttributeError): pass
  def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(src)))
  def _copyout(self, dest:memoryview, src:int):
    resp = self.dev.q(CopyOut(src), wait=True)
    assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
    dest[:] = resp
  def _transfer(self, dest, src, sz, src_dev, dest_dev):
    if dest_dev.conn != src_dev.conn:
      dest_dev.q(Event(src_dev.session, start_event:=next(src_dev.event_num)))
      src_dev.q(Wait(start_event))
    src_dev.q(Transfer(src, dest_dev.session, dest))
    if dest_dev.conn != src_dev.conn:
      src_dev.q(Event(dest_dev.session, end_event:=next(dest_dev.event_num)))
      dest_dev.q(Wait(end_event))
    if DEBUG >= 2: dest_dev.conn.batch_submit()
  def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
    self.dev.q(BufferOffset(buffer_num:=next(self.dev.buffer_num), size, offset, opaque))
    return buffer_num

class RemoteProgram:
  def __init__(self, dev:RemoteDevice, name:str, lib:bytes):
    self.dev, self.name = dev, name
    self.datahash = self.dev.conn.req.h(lib)
    self.dev.q(ProgramAlloc(self.name, self.datahash))
    super().__init__()
    weakref.finalize(self, self._fini, self.dev, self.name, self.datahash)

  @staticmethod
  def _fini(dev:RemoteDevice, name:str, datahash:str): dev.q(ProgramFree(name, datahash))

  def __call__(self, *bufs, global_size=None, local_size=None, vals:tuple[int, ...]=(), wait=False):
    ret = self.dev.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait), wait=wait)
    if wait: return float(ret)

@functools.cache
class RemoteConnection:
  q_lock = threading.Lock()
  all: dict[RemoteConnection, None] = {} # dict instead of set for deterministic ordering

  def __init__(self, host:str):
    if DEBUG >= 1: print(f"remote with host {host}")
    while 1:
      try:
        self.conn = http.client.HTTPConnection(host, timeout=getenv("REMOTE_TIMEOUT", 300.0))
        self.conn.connect()
        break
      except Exception as e:
        print(e)
        time.sleep(0.1)
    self.req: BatchRequest = BatchRequest()
    RemoteConnection.all[self] = None

  def q(self, x:RemoteRequest, wait:bool=False):
    with RemoteConnection.q_lock:
      self.req.q(x)
      if wait: return self.batch_submit(take_q=False)

  async def aq(self, x:RemoteRequest, wait:bool=False): return await asyncio.to_thread(self.q, x, wait=wait)

  def batch_submit(self, take_q:bool=True):
    if take_q: RemoteConnection.q_lock.acquire()
    conns = RemoteConnection.all.keys()
    datas = {conn: conn.req.serialize() for conn in conns}
    reqs, hashes, hash_datas = sum(len(c.req._q) for c in conns), sum(len(c.req._h) for c in conns), sum(len(data) for data in datas.values())
    ret, resps = None, []
    with Timing(f"*** send {reqs:-3d} requests {hashes:-3d} hashes with len {hash_datas/1024:.2f} kB in ", enabled=DEBUG>=3):
      for conn,data in datas.items(): conn.conn.request("POST", "/batch", data)
      for conn in datas.keys():
        resp = conn.conn.getresponse()
        body = resp.read()
        resps.append((conn, resp, body))
        conn.req = BatchRequest()
    if take_q: RemoteConnection.q_lock.release()
    for conn,resp,body in resps:
      match resp.status:
        case http.HTTPStatus.OK: pass
        case http.HTTPStatus.INTERNAL_SERVER_ERROR:
          exc_wrapper = safe_eval(ast.parse(body.decode(), mode="eval").body)
          exc_wrapper.exc.add_note(exc_wrapper.trace)
          raise exc_wrapper.exc
        case code: raise RuntimeError(f"POST /batch failed with {code}: {body.decode()}")
      if conn == self: ret = body
    return ret

def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]:
  hosts = [(unwrap(h), int(c) if c is not None else c) for h,c in ((h.split("*", maxsplit=1)+[None,])[:2] for h in hs.split(","))]
  if len(hosts) == 1 and hosts[0][1] is None: return LazySeq(lambda idx: (hosts[0][0], idx))
  return [(h, i) for h,c in hosts for i in range(unwrap(c))]

class RemoteDevice(Compiled):
  devices = parse_hosts(getenv("HOST", ""))

  def __init__(self, device:str):
    host, idx = RemoteDevice.devices[int(device.split(":")[1]) if ":" in device else 0]

    # connection is shared between sessions on the same host
    self.session: SessionKey = SessionKey(host or RemoteDevice.local_server(), idx, binascii.hexlify(os.urandom(0x10)).decode())
    self.conn: RemoteConnection = RemoteConnection(self.session.host)

    # state for the session
    self.buffer_num: Iterator[int] = itertools.count(0)
    self.graph_num: Iterator[int] = itertools.count(0)
    self.event_num: Iterator[int] = itertools.count(0)

    self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
    if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
    # TODO: how do we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
    renderer = self.properties.renderer
    if not renderer[0].startswith("tinygrad.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
    renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
    if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")

    graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None
    compilers = CompilerSet([CompilerPair(functools.partial(renderer_class, *renderer[2]), Compiler)])
    super().__init__(device, RemoteAllocator(self), compilers, functools.partial(RemoteProgram, self), graph, id(self.conn))
    self.renderer.device = device

  def finalize(self):
    with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)

  def q(self, x:RemoteRequest, wait:bool=False): return self.conn.q(replace(x, session=self.session), wait=wait)

  @functools.cache
  @staticmethod
  def local_server():
    multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start()
    return "127.0.0.1:6667"

if __name__ == "__main__": remote_server(getenv("PORT", 6667))
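Putting the pieces together, the request lifecycle as implemented above: RemoteDevice.q() stamps its SessionKey onto a request and appends it to its connection's BatchRequest; batch_submit() POSTs every connection's accumulated batch to /batch; the server's RemoteHandler.handle() replays the commands against real Buffers and programs, and the HTTP response body carries back any CopyOut/GetProperties/timing result. A sketch, assuming a reachable server (or none, in which case a local one is forked):

from tinygrad import Device
from tinygrad.runtime.ops_remote import GetProperties
d = Device["REMOTE"]                    # opens a session (spawns a local server if HOST is unset)
raw = d.q(GetProperties(), wait=True)   # wait=True flushes this connection's batch and returns the response body
print(raw)                              # repr(RemoteProperties(...)), parsed with safe_eval in __init__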
tinygrad/runtime/support/ib.py (deleted)
@@ -1,173 +0,0 @@
from __future__ import annotations
import resource, ctypes, weakref, functools, itertools
from tinygrad.runtime.autogen import ib
from typing import Iterator
from dataclasses import dataclass
from weakref import WeakKeyDictionary
from tinygrad.device import Buffer, DMACPURef, DMAFdRef
from tinygrad.helpers import getenv, round_up, DEBUG

DEFAULT_PORT, DEFAULT_GID = getenv("DEFAULT_PORT", 1), getenv("DEFAULT_GID", 3) # DEFAULT_GID=0 for RXE
IOVA_ALIGN = resource.getpagesize()

def checkz(x, ret=None):
  if x != 0: raise RuntimeError(f'{x} != 0 (errno {ctypes.get_errno()})')
  return ret

@dataclass(frozen=True)
class SGE:
  dst_iova: int
  dst_key: int
  src_iova: int
  src_key: int
  size: int

class IBCtx:
  def __init__(self, idx:int):
    # Open the device (aka Host Channel Adapter in ib-speak)
    devs = ib.ibv_get_device_list(ctypes.byref(ndevs:=ctypes.c_int32()))
    if idx >= ndevs.value: raise IndexError(f"{idx} > {ndevs.value}")
    self.ctx = ib.ibv_open_device(devs[idx])
    ib.ibv_free_device_list(devs)

    # HACK: remove this (and all usage of `ctx.contents.ops`) when clang2py can deal with `static inline` wrapper-functions
    self.vctx = ctypes.cast(ctypes.addressof(self.ctx.contents) - ib.struct_verbs_context.context.offset, ctypes.POINTER(ib.struct_verbs_context))

    # Get attributes. Something like port_attr.max_msg_sz sounds like it might require taking the min of the host's and remote's attributes if they differ
    self.device_attr = checkz(ib.ibv_query_device(self.ctx, ctypes.byref(da:=ib.struct_ibv_device_attr())), da)
    self.port_attr = checkz(self.vctx.contents.query_port(self.ctx, DEFAULT_PORT, ctypes.byref(pa:=ib.struct_ibv_port_attr()), ctypes.sizeof(pa)), pa)
    self.gid_attr = checkz(ib.ibv_query_gid(self.ctx, DEFAULT_PORT, DEFAULT_GID, ctypes.byref(ga:=ib.union_ibv_gid())), ga)

    # Allocate protection domain
    self.pd = ib.ibv_alloc_pd(self.ctx)
    self.next_iova: int = IOVA_ALIGN # don't start at zero (nullptr)

    # weakref(buf) => (iova, mr, mr_dealloc). mr_dealloc is kept here to avoid double freeing mrs that are deallocated in __del__
    self.mrs: WeakKeyDictionary[Buffer, tuple[int, ctypes._Pointer[ib.struct_ibv_mr], weakref.finalize]] = WeakKeyDictionary()

    # Default soft fd limit is 1024, which is not enough, set soft to hard (maximum allowed by the os)
    IBCtx.rlimit_fix()

  def __del__(self):
    # must deallocate all mrs in the protection domain before deallocating the protection domain itself
    if hasattr(self, "mrs"): [fin() for _,_,fin in self.mrs.values()]
    if hasattr(self, "pd"): ib.ibv_dealloc_pd(self.pd)
    if hasattr(self, "ctx"): ib.ibv_close_device(self.ctx)

  @functools.cache # run once
  @staticmethod
  def rlimit_fix():
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
    if DEBUG>=2: print(f"IB: Increased fd limit from {soft} to {hard}")

  def alloc_iova(self, size:int, required_offset:int):
    iova = round_up(self.next_iova - required_offset, IOVA_ALIGN) + required_offset
    self.next_iova = iova + size
    return iova

  def reg(self, buf:Buffer) -> tuple[int, ctypes._Pointer[ib.struct_ibv_mr]]:
    buf = buf.base
    if buf not in self.mrs:
      if buf.nbytes > self.device_attr.max_mr_size: raise RuntimeError(f"Buffer too big: {buf.nbytes:#x} > {self.device_attr.max_mr_size:#x}")
      if len(self.mrs) >= self.device_attr.max_mr: raise RuntimeError(f"Out of memory region cap: {len(self.mrs)} >= {self.device_attr.max_mr}")
      # Local read is implied (but still have to create the memory region, except for short sends/writes with IBV_SEND_INLINE that are inlined by cpu)
      mr_flags = ib.IBV_ACCESS_LOCAL_WRITE | ib.IBV_ACCESS_REMOTE_READ | ib.IBV_ACCESS_REMOTE_WRITE
      match (dmaref:=buf.as_dmaref()):
        case DMACPURef():
          iova = self.alloc_iova(dmaref.size, dmaref.addr % IOVA_ALIGN)
          mr = ib.ibv_reg_mr_iova2(self.pd, ctypes.c_void_p(dmaref.addr), dmaref.size, iova, mr_flags)
        case DMAFdRef():
          iova = self.alloc_iova(dmaref.size, dmaref.offset % IOVA_ALIGN)
          mr = ib.ibv_reg_dmabuf_mr(self.pd, dmaref.offset, dmaref.size, iova, dmaref.fd, mr_flags)
        case _: raise RuntimeError(f"Unknown type of dma ref: {dmaref}")
      if not mr: raise RuntimeError(f"Couldn't register memory region for {buf} {dmaref} (errno={ctypes.get_errno()})")
      self.mrs[buf] = (iova, mr, weakref.finalize(buf, ib.ibv_dereg_mr, mr))
    return self.mrs[buf][0:2]

class IBConn:
  def __init__(self, ctx:IBCtx):
    self.ctx = ctx

    # Create Completion Channel. It is a file descriptor that the kernel sends notifications through; not a thing in the infiniband spec, just a linux-ism
    self.comp_channel = ib.ibv_create_comp_channel(self.ctx.ctx)
    # Create Completion Queue. When a Work Request with the signaled flag is completed, a Completion Queue Entry is pushed onto this queue
    self.cq = ib.ibv_create_cq(self.ctx.ctx, _capacity:=256, _cq_context:=None, self.comp_channel, _comp_vector:=0)
    self.pending_wrids: set[int] = set()
    self.wrid_num: Iterator[int] = itertools.count(0) # wc_id is uint64, this will never overflow

    # Create Queue Pair. It's the closest thing to a socket in infiniband, with the QP num being the closest thing to a port, except it's allocated by the hca
    qp_init_attrs_cap = ib.struct_ibv_qp_cap(max_send_wr=1024, max_recv_wr=64, max_send_sge=8, max_recv_sge=8, max_inline_data=64)
    qp_init_attrs = ib.struct_ibv_qp_init_attr(send_cq=self.cq, recv_cq=self.cq, cap=qp_init_attrs_cap, qp_type=ib.IBV_QPT_RC) # Reliable Connection
    self.qp = ib.ibv_create_qp(self.ctx.pd, ctypes.byref(qp_init_attrs))
    self.qp_cap = qp_init_attrs.cap

    # The most important thing about QPs is their state: a new QP starts in the RESET state, and before it can be used it has to go
    # through Init, Ready To Receive, Ready To Send. Good docs on the QP state machine: https://www.rdmamojo.com/2012/05/05/qp-state-machine/

    # INIT
    qp_access_flags = ib.IBV_ACCESS_REMOTE_WRITE | ib.IBV_ACCESS_REMOTE_READ
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_INIT, port_num=DEFAULT_PORT, qp_access_flags=qp_access_flags)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PORT | ib.IBV_QP_ACCESS_FLAGS | ib.IBV_QP_PKEY_INDEX))

    self.gid, self.qp_num = bytes(self.ctx.gid_attr.raw), self.qp.contents.qp_num

    # Exchange GID and QP num with the remote. At least in RoCEv2 the gid can be guessed from the remote's ip; the QP num can't.

  def connect(self, remote_gid:bytes, remote_qp_num:int):
    # RTR
    qp_ah_attr_grh = ib.struct_ibv_global_route(hop_limit=1, dgid=ib.union_ibv_gid(raw=(ctypes.c_ubyte * 16)(*remote_gid)), sgid_index=DEFAULT_GID)
    qp_ah_attr = ib.struct_ibv_ah_attr(is_global=1, port_num=DEFAULT_PORT, grh=qp_ah_attr_grh)
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTR, path_mtu=ib.IBV_MTU_4096, dest_qp_num=remote_qp_num, rq_psn=0, max_dest_rd_atomic=1,
                                min_rnr_timer=12, ah_attr=qp_ah_attr)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PATH_MTU | ib.IBV_QP_DEST_QPN | ib.IBV_QP_RQ_PSN | \
                                          ib.IBV_QP_MAX_DEST_RD_ATOMIC | ib.IBV_QP_MIN_RNR_TIMER | ib.IBV_QP_AV))

    # RTS
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTS, timeout=14, retry_cnt=7, rnr_retry=7, sq_psn=0, max_rd_atomic=1)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_TIMEOUT | ib.IBV_QP_RETRY_CNT | ib.IBV_QP_RNR_RETRY | ib.IBV_QP_SQ_PSN | \
                                          ib.IBV_QP_MAX_QP_RD_ATOMIC))

  def __del__(self):
    self.wait_cq() # need to wait for **everything** to complete before it's safe to dealloc queues and stuff
    ib.ibv_destroy_qp(self.qp)
    ib.ibv_destroy_cq(self.cq)
    ib.ibv_destroy_comp_channel(self.comp_channel)

  def next_wrid(self):
    self.pending_wrids.add(wrid:=next(self.wrid_num))
    return wrid

  def wait_cq(self, wr_id: int|None=None):
    while (wr_id in self.pending_wrids) if wr_id is not None else self.pending_wrids:
      if self.ctx.ctx.contents.ops.poll_cq(self.cq, _num_entries:=1, ctypes.byref(wc:=ib.struct_ibv_wc())):
        if wc.status != ib.IBV_WC_SUCCESS:
          raise RuntimeError(f'Work Request completed with error: wr_id={wc.wr_id} status={ib.enum_ibv_wc_status.get(wc.status, wc.status)}')
        self.pending_wrids.remove(wc.wr_id)

  def rdma_write(self, sgl:list[SGE]):
    swr: ctypes._Pointer[ib.struct_ibv_send_wr]|None = None
    swr_cnt, wr_id = 0, self.next_wrid()
    def _post():
      nonlocal swr, swr_cnt, wr_id
      if swr is not None:
        # The swr can be freed when this returns; the memory that sge points to can be unmapped after work completion is retrieved from cq
        checkz(self.ctx.ctx.contents.ops.post_send(self.qp, swr, ctypes.byref(_bad_wr:=ctypes.POINTER(ib.struct_ibv_send_wr)())))
        # TODO: async
        self.wait_cq(wr_id)
        swr, swr_cnt, wr_id = None, 0, self.next_wrid()
    # Everything is in reverse for elegant chaining
    for sg in reversed(sgl):
      # Message size limit (max 2GB per ib spec, 1GB on tinybox mellanoxes) applies to both scatter-gather entries and entire wrs
      for off in reversed(range(0, sg.size, self.ctx.port_attr.max_msg_sz)):
        # Scatter-Gather Entry for local memory
        sge = ctypes.pointer(ib.struct_ibv_sge(addr=sg.src_iova+off, length=min(sg.size-off, self.ctx.port_attr.max_msg_sz), lkey=sg.src_key))
        # RDMA struct for remote memory
        wr = ib.struct_ibv_send_wr_wr(rdma=ib.struct_ibv_send_wr_wr_rdma(remote_addr=sg.dst_iova+off, rkey=sg.dst_key))
        # Signal (with the chosen work request id) if it's the last wr (first in the loop since it's reversed)
        wid, flags = (wr_id, ib.IBV_SEND_SIGNALED) if swr is None else (0, 0)
        # Create Send Request
        swr = ctypes.pointer(ib.struct_ibv_send_wr(opcode=ib.IBV_WR_RDMA_WRITE, sg_list=sge, num_sge=1, wr=wr, wr_id=wid, send_flags=flags, next=swr))
        # Flush if queue is being overrun
        if (swr_cnt:=swr_cnt + 1) >= self.qp_cap.max_send_wr: _post()
    _post()
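For intuition on how ops_remote.py drove this file: the sender registers its source buffer to get an iova/lkey, asks the receiver for the destination's iova/rkey/size via BufferIOVAS, then issues a one-sided RDMA write. A sketch, assuming an IB/RoCE device is present; remote_gid, remote_qp_num, dst_iova, and dst_key are placeholders obtained out-of-band over the HTTP side channel (IBConnect/BufferIOVAS):

ctx = IBCtx(0)
conn = IBConn(ctx)
conn.connect(remote_gid, remote_qp_num)   # values exchanged via IBConnect
src_iova, src_mr = ctx.reg(src_buf)       # src_buf: a tinygrad Buffer with a dma ref
conn.rdma_write([SGE(dst_iova, dst_key, src_iova, src_mr.contents.lkey, src_buf.nbytes)])
# rdma_write posts the send work requests and waits on the completion queue before returning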