mselect doesn't work

This commit is contained in:
George Hotz
2025-05-16 14:45:48 -07:00
parent bb2e430ac3
commit c78b77bd7d
5 changed files with 18 additions and 8 deletions

View File

@@ -389,12 +389,12 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
# replace assign sources with a view of the target buffer
parents_rep: dict[UOp, UOp] = {}
for s in k.src:
if s.op is Ops.MSELECT:
assert s.src[0].op is Ops.ASSIGN
s = s.src[0]
if s.op is Ops.ASSIGN:
for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
parents_rep[s] = s.buf_uop
if s.op is Ops.MSELECT:
for out in s.src[0].src[1].arg.ast.src: parents_rep[out] = s.src[0].buf_uop.view(unwrap(out.st))
parents_rep[s] = s.buf_uop
ast = k.arg.ast.substitute(parents_rep, name="replace realized")
# push views to edges
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")

View File

@@ -140,7 +140,14 @@ def copy_multi(multi:UOp, device:UOp):
dnum = UOp.variable("_device_num", 0, len(multi.device)-1)
padded = multi.src[0].pad(tuple((0,0) if a != multi.axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(multi.shape))))
ret = padded.allreduce(Ops.ADD)
return ret if isinstance(device.arg, str) else ret.multi(axis=None)
if isinstance(device.arg, str):
# copying to single device
if device.arg in multi.device: return ret.mselect(multi.device.index(device.arg))
return ret.mselect(0).copy_to_device(device)
else:
# copying to multi device
if device.arg == multi.device: return ret.multi(axis=None)
return ret.mselect(0).copy_to_device(device).multi(axis=None)
def assign_multi(dest:UOp, src:UOp):
if dest.axis != src.axis: raise RuntimeError(f"axis must match in assign {dest.axis} != {src.axis}")

View File

@@ -63,7 +63,8 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
base = k.src[1].buf_uop.buffer
assert isinstance(base, Buffer), "base can't be MultiBuffer"
buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
ubufs = tuple(s.buf_uop.buffer for s in k.src)
ubufs = tuple(s.buf_uop.buffer.bufs[s.arg] if s.op is Ops.MSELECT else s.buf_uop.buffer for s in k.src)
print(ubufs)
if any(isinstance(x, MultiBuffer) for x in ubufs):
if ast.op is Ops.COPY:
if isinstance(ubufs[0], MultiBuffer) and isinstance(ubufs[1], MultiBuffer) and ast.arg is None:

View File

@@ -539,7 +539,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self
if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
if self.op is Ops.MSELECT: return self.src[0].buf_uop
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
return self.src[0].base
@property

View File

@@ -47,10 +47,13 @@ def validate_kernel(k:UOp):
assign_spec = PatternMatcher([
# KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN)), name="k"), validate_kernel),
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT)), name="k"), validate_kernel),
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
# MSELECT picks a single buffer out of a multi-device buffer; its arg must be a valid index into the source's devices
(UPat(Ops.MSELECT, name="m"), lambda m: m.arg >= 0 and m.arg < len(m.src[0].device))
])
# *** this is the spec of a Tensor in UOp ***
@@ -79,7 +82,6 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
# COPY/ALLREDUCE
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda copy,x: copy.dtype == x.dtype),
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"),)), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
(UPat(Ops.MSELECT, name="m"), lambda m: m.arg >= 0 and m.arg < len(m.src[0].device))
])
# ***** uop type spec *****