mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
Remote scheduler changes (#11177)
@@ -32,7 +32,6 @@ class TestRemoteMultiHost(unittest.TestCase):
     # Verify that everything is in one big cross-host graph
     assert len(do.captured._jit_cache) == 1 and isinstance(do.captured._jit_cache[0].prg, RemoteGraph), repr(do.captured)
 
-  @unittest.expectedFailure # multihost-aware schedule is in separate pr
   @Context(JIT_BATCH_SIZE=2**32)
   def test_multihost_aware_schedule(self):
     @TinyJit
@@ -278,9 +278,9 @@ class Compiler:
 class Compiled:
   profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.
 
-  def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None):
+  def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None, group_id=None):
     self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
-    self.renderer = renderer or Renderer()
+    self.renderer, self.group_id = renderer or Renderer(), group_id
   def synchronize(self):
     """
     Synchronize all pending operations on the device.
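The Compiled change is pure plumbing: group_id is an opaque co-location key stored alongside the renderer, defaulting to None. Devices that share a group_id are assumed to sit behind the same host, so copies between them are cheap, while copies across group_ids are the expensive ones the scheduler heuristic below penalizes. A minimal sketch of those semantics (Dev and same_group are illustrative, not tinygrad API):

class Dev:
  def __init__(self, name: str, group_id=None): self.name, self.group_id = name, group_id

def same_group(a: Dev, b: Dev) -> bool:
  # the default None leaves ungrouped devices in one implicit group, matching
  # the all_same-over-group_ids check used by the scheduler heuristic below
  return a.group_id == b.group_id

h0, h1 = object(), object()  # any per-host token works; RemoteDevice below uses id(self.conn)
a, b, c = Dev("REMOTE:0", id(h0)), Dev("REMOTE:1", id(h0)), Dev("REMOTE:2", id(h1))
assert same_group(a, b) and not same_group(a, c)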
@@ -2,8 +2,8 @@ from typing import cast
 from dataclasses import dataclass, field
 from collections import deque, defaultdict
 from tinygrad.uop.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
-from tinygrad.device import Buffer, MultiBuffer
-from tinygrad.helpers import Metadata, unwrap, merge_dicts
+from tinygrad.device import Device, Buffer, MultiBuffer
+from tinygrad.helpers import Metadata, unwrap, all_same, merge_dicts
 
 # **** ScheduleItem return type
 
@@ -61,11 +61,22 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
 
   # linearize KERNEL UOps into ScheduleItems in BFS order
-  queue = deque(k for k,v in in_degree.items() if v == 0)
+
+  def _heuristic(k: UOp):
+    if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000
+    return 0
+
+  last_heuristic: int = 0
+  queues: defaultdict[int, deque[UOp]] = defaultdict(deque)
+  last_queue: deque[UOp] = deque()
+  for k,v in in_degree.items():
+    if v == 0: queues[_heuristic(k)].append(k)
+
   schedule: list[ScheduleItem] = []
   var_vals: dict[Variable, int] = {}
-  while queue:
-    k = queue.popleft()
+  while last_queue or any(queues.values()):
+    if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
+    k = last_queue.popleft()
     # unbind var_vals from the kernel
     local_var_vals: list[dict[Variable, int]] = []
     ast = graph_rewrite(k.arg.ast, pm_unbind, ctx=local_var_vals, name="unbind vars")
@@ -86,6 +97,6 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
     schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
     for x in children[k]:
       in_degree[x] -= 1
-      if in_degree[x] == 0: queue.append(x)
+      if in_degree[x] == 0: queues[_heuristic(x)].append(x)
 
   return schedule, var_vals
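The scheduler change replaces the single BFS deque with per-heuristic buckets: a ready kernel lands in the queue picked by _heuristic (1000 for a COPY whose source buffers live on devices with differing group_ids, 0 otherwise), and the drain loop sticks with one bucket until it empties, then hops to the non-empty bucket whose key is closest to the last one. The effect is that cross-host copies come out contiguously instead of interleaved with compute. A self-contained sketch of the same loop on toy string nodes (the node names, children map, and schedule() helper are illustrative, not tinygrad API):

from collections import deque, defaultdict

def schedule(children: dict[str, list[str]], heuristic) -> list[str]:
  # children maps each node to the nodes that depend on it
  in_degree: dict[str, int] = {n: 0 for n in children}
  for cs in children.values():
    for c in cs: in_degree[c] += 1
  queues: defaultdict[int, deque[str]] = defaultdict(deque)
  for n, d in in_degree.items():
    if d == 0: queues[heuristic(n)].append(n)
  last_h, last_queue, order = 0, deque(), []
  while last_queue or any(queues.values()):
    # stay in the current bucket; when it drains, hop to the nearest key
    if not last_queue:
      last_h, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0] - last_h))
    n = last_queue.popleft()
    order.append(n)
    for c in children[n]:
      in_degree[c] -= 1
      if in_degree[c] == 0: queues[heuristic(c)].append(c)
  return order

# "c1"/"c2" stand in for cross-host copies (heuristic 1000)
deps = {"a": ["c1", "x"], "x": ["c2", "y"], "c1": ["d"], "c2": ["d"], "y": ["d"], "d": []}
print(schedule(deps, lambda n: 1000 if n.startswith("c") else 0))
# -> ['a', 'x', 'y', 'c1', 'c2', 'd']; plain BFS would interleave them: a, c1, x, c2, y, d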
@@ -2,8 +2,30 @@ from typing import cast
 import functools, itertools, operator
 from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
 from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve
+from tinygrad.device import Device
 
 # *** allreduce implementation ***
+def handle_allreduce_multirank(buf:UOp, red:UOp) -> UOp|None:
+  if not isinstance(buf.device, tuple): return None
+
+  # Group buffers
+  groups: dict[int|None, list[UOp]] = {}
+  for i,dev in enumerate(buf.device):
+    groups.setdefault(Device[dev].group_id, []).append(buf.mselect(i))
+
+  # Skip if only one group or if every group has only one buffer
+  if len(groups) <= 1 or not any(len(g) > 1 for g in groups.values()): return None
+
+  # Reduce inside each group
+  inner = [UOp(Ops.MSTACK, buf.dtype, tuple(bufs)).allreduce(red.arg, (cast(str, bufs[0].device),)).mselect(0) for bufs in groups.values()]
+
+  # Allreduce across groups
+  outer = UOp(Ops.MSTACK, buf.dtype, tuple(inner)).allreduce(red.arg, tuple(buf.device for buf in inner))
+
+  # Broadcast back to all devices in the group
+  gid2bid = {Device[device].group_id: i for i,device in enumerate(outer.device)}
+  return outer.mselect(gid2bid[Device[red.device].group_id]).copy_to_device(red.device) if not isinstance(red.device, tuple) else \
+    UOp(Ops.MSTACK, buf.dtype, tuple(outer.mselect(gid2bid[Device[device].group_id]).copy_to_device(device) for device in red.device))
 
 def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   if not isinstance(buf.device, tuple): return None
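handle_allreduce_multirank rewrites one flat allreduce over devices that span several groups into three phases: reduce inside each group over the fast local links, allreduce a single representative buffer per group across the slow inter-group links, then broadcast the result back to every device. A plain-number sketch of the same dataflow, with hypothetical device names ("H0:0" etc.) standing in for the MSTACK/mselect machinery:

from collections import defaultdict

def hierarchical_allreduce(values: dict[str, float], group_of: dict[str, str]) -> dict[str, float]:
  groups: defaultdict[str, list[str]] = defaultdict(list)
  for dev in values: groups[group_of[dev]].append(dev)
  # 1) reduce inside each group (cheap intra-host transfers)
  partial = {g: sum(values[d] for d in devs) for g, devs in groups.items()}
  # 2) allreduce across groups: one buffer per host crosses the slow link
  total = sum(partial.values())
  # 3) broadcast the result back to every device in every group
  return {dev: total for dev in values}

vals = {"H0:0": 1.0, "H0:1": 2.0, "H1:0": 3.0, "H1:1": 4.0}
assert hierarchical_allreduce(vals, {d: d.split(":")[0] for d in vals}) == {d: 10.0 for d in vals}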
@@ -84,6 +106,7 @@ def mstack_early_shrink(view:UOp, ms:UOp):
   return ms.replace(src=tuple(ret))
 
 replace_allreduce = PatternMatcher([
+  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank),
   (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
   # BROADCAST: explicitly expand broadcast copies and combine with MSTACK
   (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
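Registration order matters in this PatternMatcher: both entries use the identical ALLREDUCE UPat, so handle_allreduce_multirank gets the first look and its None returns (single group, or every group holding just one buffer) fall through to handle_allreduce. A toy sketch of that ordered fall-through dispatch, with plain functions standing in for the UPat machinery:

def first_rewrite(op, handlers):
  # try handlers in registration order; None means "no match, keep going"
  for h in handlers:
    if (out := h(op)) is not None: return out
  return op

def multirank(op):  return f"grouped({op})" if op.endswith("@2hosts") else None
def singlehost(op): return f"ring({op})"

assert first_rewrite("ar@2hosts", [multirank, singlehost]) == "grouped(ar@2hosts)"
assert first_rewrite("ar", [multirank, singlehost]) == "ring(ar)"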
@@ -389,7 +389,7 @@ class RemoteDevice(Compiled):
     renderer_instance = renderer_class(*renderer[2])
     renderer_instance.device = device
     graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None
-    super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph)
+    super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph, id(self.conn))
 
   def finalize(self):
     with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)
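RemoteDevice passes id(self.conn) as its group_id, so every device multiplexed over one connection falls into the same group. A small sketch of why an object id works as the key, assuming devices on the same remote host really do share a single connection object:

class Conn: pass  # stand-in for the per-host connection

conn_h0, conn_h1 = Conn(), Conn()
# two devices on host 0, one on host 1
gid_a, gid_b, gid_c = id(conn_h0), id(conn_h0), id(conn_h1)
assert gid_a == gid_b and gid_a != gid_c  # same object -> same group, for the object's lifetime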