Remote scheduler changes (#11177)

Author: uuuvn
Date: 2025-07-21 21:29:44 +05:00
Committed by: GitHub
Parent: e368628736
Commit: 178dbf3f66

5 changed files with 43 additions and 10 deletions

@@ -32,7 +32,6 @@ class TestRemoteMultiHost(unittest.TestCase):
     # Verify that everything is in one big cross-host graph
     assert len(do.captured._jit_cache) == 1 and isinstance(do.captured._jit_cache[0].prg, RemoteGraph), repr(do.captured)
 
-  @unittest.expectedFailure # multihost-aware schedule is in separate pr
   @Context(JIT_BATCH_SIZE=2**32)
   def test_multihost_aware_schedule(self):
     @TinyJit

@@ -278,9 +278,9 @@ class Compiler:
 
 class Compiled:
   profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.
-  def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None):
+  def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None, group_id=None):
     self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
-    self.renderer = renderer or Renderer()
+    self.renderer, self.group_id = renderer or Renderer(), group_id
   def synchronize(self):
     """
     Synchronize all pending operations on the device.
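Note on the change above: group_id gives every Compiled device an optional host-affinity tag; devices constructed with the same value are treated as co-located. A toy sketch of the idea (ToyDevice and host_key are hypothetical; in this commit only RemoteDevice passes a real value, id(self.conn)):

from tinygrad.device import Compiled

# hypothetical backend: devices built with the same host_key share a group_id
class ToyDevice(Compiled):
  def __init__(self, device:str, host_key:int):
    # renderer/compiler fall back to Renderer()/Compiler() when None is passed
    super().__init__(device, allocator=None, renderer=None, compiler=None, runtime=None, graph=None, group_id=host_key)

a, b, c = ToyDevice("TOY:0", 1), ToyDevice("TOY:1", 1), ToyDevice("TOY:2", 2)
assert a.group_id == b.group_id and a.group_id != c.group_id  # only the cross-host pair differs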

@@ -2,8 +2,8 @@ from typing import cast
 from dataclasses import dataclass, field
 from collections import deque, defaultdict
 from tinygrad.uop.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
-from tinygrad.device import Buffer, MultiBuffer
-from tinygrad.helpers import Metadata, unwrap, merge_dicts
+from tinygrad.device import Device, Buffer, MultiBuffer
+from tinygrad.helpers import Metadata, unwrap, all_same, merge_dicts
 
 # **** ScheduleItem return type
@@ -61,11 +61,22 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
 
   # linearize KERNEL UOps into ScheduleItems in BFS order
-  queue = deque(k for k,v in in_degree.items() if v == 0)
+  def _heuristic(k: UOp):
+    if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000
+    return 0
+
+  last_heuristic: int = 0
+  queues: defaultdict[int, deque[UOp]] = defaultdict(deque)
+  last_queue: deque[UOp] = deque()
+  for k,v in in_degree.items():
+    if v == 0: queues[_heuristic(k)].append(k)
+
   schedule: list[ScheduleItem] = []
   var_vals: dict[Variable, int] = {}
-  while queue:
-    k = queue.popleft()
+  while last_queue or any(queues.values()):
+    if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
+    k = last_queue.popleft()
     # unbind var_vals from the kernel
     local_var_vals: list[dict[Variable, int]] = []
     ast = graph_rewrite(k.arg.ast, pm_unbind, ctx=local_var_vals, name="unbind vars")
@@ -86,6 +97,6 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
     schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
     for x in children[k]:
       in_degree[x] -= 1
-      if in_degree[x] == 0: queue.append(x)
+      if in_degree[x] == 0: queues[_heuristic(x)].append(x)
   return schedule, var_vals
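
The scheduler now keeps one BFS deque per heuristic score instead of a single queue: ordinary kernels score 0, while copies whose source buffers live under different group_ids score 1000. When the current deque drains, the loop hops to the non-empty deque whose score is closest to the last one, so cross-host copies are batched together instead of being interleaved with compute. A standalone sketch of just that selection rule (toy string items in place of UOps):

from collections import deque, defaultdict

queues: defaultdict[int, deque[str]] = defaultdict(deque)
for prio, item in [(0, "kernel_a"), (1000, "copy_x"), (0, "kernel_b"), (1000, "copy_y")]:
  queues[prio].append(item)

order: list[str] = []
last_heuristic, last_queue = 0, deque()
while last_queue or any(queues.values()):
  # hop to the non-empty queue whose score is nearest the one just drained
  if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
  order.append(last_queue.popleft())

print(order)  # ['kernel_a', 'kernel_b', 'copy_x', 'copy_y'] -- copies batched at the end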

@@ -2,8 +2,30 @@ from typing import cast
 import functools, itertools, operator
 from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
 from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve
+from tinygrad.device import Device
 
 # *** allreduce implementation ***
+def handle_allreduce_multirank(buf:UOp, red:UOp) -> UOp|None:
+  if not isinstance(buf.device, tuple): return None
+
+  # Group buffers
+  groups: dict[int|None, list[UOp]] = {}
+  for i,dev in enumerate(buf.device):
+    groups.setdefault(Device[dev].group_id, []).append(buf.mselect(i))
+
+  # Skip if only one group or if every group has only one buffer
+  if len(groups) <= 1 or not any(len(g) > 1 for g in groups.values()): return None
+
+  # Reduce inside each group
+  inner = [UOp(Ops.MSTACK, buf.dtype, tuple(bufs)).allreduce(red.arg, (cast(str, bufs[0].device),)).mselect(0) for bufs in groups.values()]
+
+  # Allreduce across groups
+  outer = UOp(Ops.MSTACK, buf.dtype, tuple(inner)).allreduce(red.arg, tuple(buf.device for buf in inner))
+
+  # Broadcast back to all devices in the group
+  gid2bid = {Device[device].group_id: i for i,device in enumerate(outer.device)}
+  return outer.mselect(gid2bid[Device[red.device].group_id]).copy_to_device(red.device) if not isinstance(red.device, tuple) else \
+    UOp(Ops.MSTACK, buf.dtype, tuple(outer.mselect(gid2bid[Device[device].group_id]).copy_to_device(device) for device in red.device))
 
 def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   if not isinstance(buf.device, tuple): return None
@@ -84,6 +106,7 @@ def mstack_early_shrink(view:UOp, ms:UOp):
   return ms.replace(src=tuple(ret))
 
 replace_allreduce = PatternMatcher([
+  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank),
   (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
   # BROADCAST: explicitly expand broadcast copies and combine with MSTACK
   (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
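
handle_allreduce_multirank rewrites a flat allreduce into a hierarchical one whenever buffers span more than one host group: reduce inside each group first, allreduce one representative buffer per group across hosts, then broadcast the result back within each group, so only one buffer per host crosses the slow inter-host link. Registering it ahead of handle_allreduce in the PatternMatcher gives it first chance to match; returning None falls through to the single-host path. A toy sketch of the same three phases over plain numbers (device names, group assignments, and values are illustrative):

import functools, operator

# two hosts with two devices each
group_of = {"REMOTE:0": 0, "REMOTE:1": 0, "REMOTE:2": 1, "REMOTE:3": 1}
values = {"REMOTE:0": 1.0, "REMOTE:1": 2.0, "REMOTE:2": 3.0, "REMOTE:3": 4.0}

# phase 1: reduce inside each host group (fast intra-host links)
groups: dict[int, list[str]] = {}
for dev, gid in group_of.items(): groups.setdefault(gid, []).append(dev)
inner = {gid: functools.reduce(operator.add, (values[d] for d in devs)) for gid, devs in groups.items()}

# phase 2: allreduce across groups -- one value per host crosses the network
total = functools.reduce(operator.add, inner.values())

# phase 3: broadcast the total back to every device in every group
result = {dev: total for dev in group_of}
assert all(v == 10.0 for v in result.values())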

@@ -389,7 +389,7 @@ class RemoteDevice(Compiled):
     renderer_instance = renderer_class(*renderer[2])
     renderer_instance.device = device
     graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None
-    super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph)
+    super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph, id(self.conn))
 
   def finalize(self):
     with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)
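
RemoteDevice keys its group by id(self.conn), so every device multiplexed over one host connection shares a group_id while devices behind different connections do not. A minimal illustration of that keying (FakeConn is hypothetical):

class FakeConn: pass  # stands in for the per-host connection object

host_a, host_b = FakeConn(), FakeConn()
conn_of = {"REMOTE:0": host_a, "REMOTE:1": host_a, "REMOTE:2": host_b, "REMOTE:3": host_b}
group_ids = {dev: id(conn) for dev, conn in conn_of.items()}
assert group_ids["REMOTE:0"] == group_ids["REMOTE:1"]  # same connection, same group
assert group_ids["REMOTE:0"] != group_ids["REMOTE:2"]  # different hosts, different groups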