mstack replaces scheduler complexity (#10654)

* mstack replaces scheduler complexity

* leave that one

* contiguous

* work

* upd

* minimal failing test

* simpler

* attention is broken

* fix transformer

* failing tests

* real fix for llama

* kv cache test

* jit multi assign test

* better tests

* comment

* fix jit issue

* traverse after buf_uop

Author: George Hotz
Date: 2025-06-06 11:31:41 -07:00
Committed by: GitHub
Parent: 7f0f97aa76
Commit: bf4ffc054c
5 changed files with 30 additions and 18 deletions

@@ -90,7 +90,8 @@ sym = symbolic_simple+PatternMatcher([
   # double ASSIGN to same target is one ASSIGN
   (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))))), lambda x,t: t.assign(x.contiguous())),
   # ASSIGN to unrealized replaces the UOp
-  (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))), lambda x,t: x.contiguous() if t.base.op not in {Ops.BUFFER, Ops.BUFFER_VIEW} else None),
+  (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))), lambda x,t: x.contiguous() if t.base.op not in {Ops.BUFFER, Ops.BUFFER_VIEW} and
+   not (t.base.op is Ops.MSTACK and all(x.op is Ops.BUFFER for x in t.base.src)) else None),
   # put CAST to smaller dtype before EXPAND
   (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
    if cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None),
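
The new condition keeps ASSIGN intact when the target's base is an MSTACK whose sources are all BUFFERs, i.e. an already-realized multi-device tensor. Previously only BUFFER/BUFFER_VIEW targets survived, so an assignment to a sharded tensor (the KV-cache case from the commit bullets) was silently rewritten into a plain contiguous copy. A minimal sketch of the user-level behavior this preserves, assuming two shards of the CPU device:

from tinygrad import Tensor

# a realized sharded tensor lowers to MSTACK(BUFFER, BUFFER); assigning to it
# must stay an ASSIGN, an in-place update of the per-device buffers
cache = Tensor.zeros(4, 8).shard(("CPU:0", "CPU:1"), axis=0).contiguous().realize()
cache.assign(cache + 1).realize()
print(cache.numpy().sum())  # 32.0
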
@@ -258,6 +259,9 @@ create_kernels = PatternMatcher([
   (UPat(Ops.KERNEL, name="x"), append_to_kernel),
   # push RESHAPE through MSELECT
   (UPat(Ops.MSELECT, src=(UPat(Ops.RESHAPE, name="r"),), name="ms"), lambda ms,r: r.src[0].mselect(ms.arg).reshape(r.arg)),
+  # push RESHAPE through MSTACK
+  (UPat(Ops.MSTACK, src=UPat(Ops.RESHAPE), name="ms"),
+   lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
 ])
 
 # **** swizzler
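
The MSTACK rule mirrors the MSELECT rule directly above it: when every per-device branch is the same RESHAPE, the reshape hoists above the stack so MSTACK sits directly on its sources. A toy model of the rewrite (a hypothetical Node class standing in for tinygrad's UOp):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str                # "BUFFER", "RESHAPE", "MSELECT", "MSTACK", "COPY", ...
  src: tuple = ()
  arg: object = None     # for RESHAPE: the target shape

def push_reshape_through_mstack(ms: Node) -> Node | None:
  if ms.op != "MSTACK" or not all(s.op == "RESHAPE" for s in ms.src): return None
  if not all(s.arg == ms.src[0].arg for s in ms.src): return None  # every branch must reshape alike
  return Node("RESHAPE", (Node("MSTACK", tuple(s.src[0] for s in ms.src)),), ms.src[0].arg)

ms = Node("MSTACK", (Node("RESHAPE", (Node("BUFFER"),), (2, 3)),
                     Node("RESHAPE", (Node("BUFFER"),), (2, 3))))
print(push_reshape_through_mstack(ms))  # a RESHAPE of MSTACK(BUFFER, BUFFER) with arg=(2, 3)
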
@@ -409,9 +413,10 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
   # replace buffer with define_global + add load/store last
   bufs = []
   for s in k.src:
+    s = s.buf_uop
     # traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
     while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
-    bufs.append(s.buf_uop)
+    bufs.append(s)
   ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
   if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
     raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
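The fix here is ordering, per the last commit bullet: take buf_uop first, then peel MSELECT/MSTACK wrappers to reach a base BUFFER, rather than traversing first and calling buf_uop on the result, which breaks when the buffer uop itself is an MSELECT or MSTACK. Reusing the toy Node class from the sketch above:

def base_buffer(s: Node) -> Node:
  # matches the HACK in the diff: always follow the 0 branch of MSTACK
  while s.op in {"MSELECT", "MSTACK"}: s = s.src[0]
  return s

wrapped = Node("MSELECT", (Node("MSTACK", (Node("BUFFER"), Node("BUFFER"))),), 0)
assert base_buffer(wrapped).op == "BUFFER"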

@@ -63,10 +63,25 @@ def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
   st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
   return base.mselect(ms.arg).view(st)
 
+def mstack_reorder_view(ms:UOp):
+  args = [x.arg for x in ms.src]
+  assert all_same(args) and len([x for x in args[0].vars() if x.arg[0] == '_device_num']) == 0
+  return UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).view(args[0])
+
 replace_allreduce = PatternMatcher([
   (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
+  # BROADCAST: explicitly expand broadcast copies and combine with MSTACK
+  (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
+   UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),
+  # COPY_TO_ONE: if copying from multidevice to one, MSELECT the first (TODO: a little from each?)
+  (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
+   x.mselect(0).copy_to_device(c.device) if isinstance(c.device, str) and isinstance(x.device, tuple) else None),
+  # MSELECT on MSTACK is replaced with nothing
+  (UPat(Ops.MSELECT, src=(UPat(Ops.MSTACK, name="mstack"),), name="ms"), lambda mstack, ms: mstack.src[ms.arg]),
   # MSELECT must select a base, if there are views apply them after selecting the base
   (UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), mselect_reorder_view),
+  # move view through MSTACK
+  (UPat(Ops.MSTACK, src=UPat(Ops.VIEW), name="ms"), mstack_reorder_view),
 ])
 
 # ***** multi functions *****
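
Broadcast and gather now happen as graph rewrites rather than scheduler special cases: a COPY from one device to a tuple of devices expands into one copy per destination combined under MSTACK, and a COPY from a tuple of devices to a single one selects shard 0 and copies that. A plain-Python sketch of the device dispatch, again with the toy Node class:

def rewrite_copy(x: Node, x_device, dest_device) -> Node | None:
  if isinstance(dest_device, tuple) and isinstance(x_device, str):
    # BROADCAST: one COPY per destination device, combined with MSTACK
    return Node("MSTACK", tuple(Node("COPY", (x,), d) for d in dest_device))
  if isinstance(dest_device, str) and isinstance(x_device, tuple):
    # COPY_TO_ONE: MSELECT the first shard, then a single-device COPY
    return Node("COPY", (Node("MSELECT", (x,), 0),), dest_device)
  return None

x = Node("BUFFER")
print(rewrite_copy(x, "CPU:0", ("GPU:0", "GPU:1")))  # an MSTACK of two per-device COPYs
print(rewrite_copy(x, ("GPU:0", "GPU:1"), "CPU:0"))  # a COPY of MSELECT(x, 0)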

@@ -77,21 +77,10 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
     ubufs = tuple(s.buf_uop.buffer for s in k.src)
     if any(isinstance(x, MultiBuffer) for x in ubufs):
-      if ast.op is Ops.COPY and (isinstance(ubufs[0], Buffer) or isinstance(ubufs[1], Buffer)):
-        if isinstance(ubufs[1], MultiBuffer) and isinstance(ubufs[0], Buffer): # src is multiple buffers, none selected
-          # COPY ANY -> ONE. Currently we just select the first
-          schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1].bufs[0]), k.arg.metadata))
-        elif isinstance(ubufs[0], MultiBuffer) and isinstance(ubufs[1], Buffer):
-          # COPY ONE -> ALL (BROADCAST)
-          for b in ubufs[0].bufs: schedule.append(ScheduleItem(ast, (b, ubufs[1]), k.arg.metadata))
-        else:
-          raise RuntimeError("unsupported copy type")
-      else:
-        assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
-        # ALL -> ALL
-        dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
-        for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
-          schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
+      assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
+      dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
+      for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
+        schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
     else:
       # ONE -> ONE
       schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
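
With the COPY branches deleted, the only MultiBuffer case left in the scheduler is ALL -> ALL: zip the per-device buffers of every MultiBuffer and emit one ScheduleItem per device, binding the _device_num variable to the device index when the AST uses it. A toy expansion of that loop (illustrative tuples, not real Buffers):

multibufs = [("a_dev0", "a_dev1"), ("b_dev0", "b_dev1")]  # two MultiBuffers, two devices
schedule = [(i, bufs) for i, bufs in enumerate(zip(*multibufs))]
print(schedule)  # [(0, ('a_dev0', 'b_dev0')), (1, ('a_dev1', 'b_dev1'))]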

@@ -557,7 +557,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     buffers[self] = ret
     return ret
   @property
-  def realized(self) -> Optional[Buffer|MultiBuffer]: return self.buffer if self.op is Ops.BUFFER and self.buffer.is_allocated() else None
+  def realized(self) -> Optional[Buffer|MultiBuffer]:
+    # NOTE: this is used by the JIT to determine which inputs we capture
+    return self.buffer if self.op in {Ops.BUFFER, Ops.MSTACK} and self.buffer.is_allocated() else None
   @property
   def is_realized(self) -> bool:
     return all(x.base.realized is not None for x in self.base.src) if self.base.op is Ops.MULTI else self.base.realized is not None
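
Counting an allocated MSTACK as realized is what lets the JIT capture sharded inputs: with only Ops.BUFFER accepted, a realized multi-device tensor reported None here and was not treated as a capturable input. A minimal usage sketch, assuming two CPU shards (the exact capture behavior is an assumption drawn from the NOTE above):

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2).contiguous().realize()

x = Tensor.arange(8).shard(("CPU:0", "CPU:1"), axis=0).realize()
for _ in range(3): step(x)  # the sharded input is captured and replayed by the JIT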

@@ -39,6 +39,7 @@ buffer_spec = PatternMatcher([
    lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
   (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
    lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
+  (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
   # allow VIEW here. TODO: what views specifically are allowed? does this mess with gradient?
   (UPat(Ops.VIEW), lambda: True),
 ])
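
The added spec rule legalizes a BUFFER_VIEW whose source is an MSTACK of BUFFERs, the shape a view of a realized multi-device tensor now takes. In toy form, the check the verifier now performs (same hypothetical Node class as earlier):

def buffer_view_src_ok(bv: Node) -> bool:
  # BUFFER_VIEW may sit on a BUFFER or, newly, on an MSTACK of BUFFERs
  (s,) = bv.src
  return s.op == "BUFFER" or (s.op == "MSTACK" and all(x.op == "BUFFER" for x in s.src))

assert buffer_view_src_ok(Node("BUFFER_VIEW", (Node("MSTACK", (Node("BUFFER"), Node("BUFFER"))),)))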