mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
mselect doesn't work
This commit is contained in:
@@ -389,12 +389,12 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
||||
# replace assign sources with a view of the target buffer
|
||||
parents_rep: dict[UOp, UOp] = {}
|
||||
for s in k.src:
|
||||
if s.op is Ops.MSELECT:
|
||||
assert s.src[0].op is Ops.ASSIGN
|
||||
s = s.src[0]
|
||||
if s.op is Ops.ASSIGN:
|
||||
for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
|
||||
parents_rep[s] = s.buf_uop
|
||||
if s.op is Ops.MSELECT:
|
||||
for out in s.src[0].src[1].arg.ast.src: parents_rep[out] = s.src[0].buf_uop.view(unwrap(out.st))
|
||||
parents_rep[s] = s.buf_uop
|
||||
ast = k.arg.ast.substitute(parents_rep, name="replace realized")
|
||||
# push views to edges
|
||||
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
|
||||
|
||||
@@ -140,7 +140,14 @@ def copy_multi(multi:UOp, device:UOp):
|
||||
dnum = UOp.variable("_device_num", 0, len(multi.device)-1)
|
||||
padded = multi.src[0].pad(tuple((0,0) if a != multi.axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(multi.shape))))
|
||||
ret = padded.allreduce(Ops.ADD)
|
||||
return ret if isinstance(device.arg, str) else ret.multi(axis=None)
|
||||
if isinstance(device.arg, str):
|
||||
# copying to single device
|
||||
if device.arg in multi.device: return ret.mselect(multi.device.index(device.arg))
|
||||
return ret.mselect(0).copy_to_device(device)
|
||||
else:
|
||||
# copying to multi device
|
||||
if device.arg == multi.device: return ret.multi(axis=None)
|
||||
return ret.mselect(0).copy_to_device(device).multi(axis=None)
|
||||
|
||||
def assign_multi(dest:UOp, src:UOp):
|
||||
if dest.axis != src.axis: raise RuntimeError(f"axis must match in assign {dest.axis} != {src.axis}")
|
||||
|
||||
@@ -63,7 +63,8 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
||||
base = k.src[1].buf_uop.buffer
|
||||
assert isinstance(base, Buffer), "base can't be MultiBuffer"
|
||||
buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
ubufs = tuple(s.buf_uop.buffer for s in k.src)
|
||||
ubufs = tuple(s.buf_uop.buffer.bufs[s.arg] if s.op is Ops.MSELECT else s.buf_uop.buffer for s in k.src)
|
||||
print(ubufs)
|
||||
if any(isinstance(x, MultiBuffer) for x in ubufs):
|
||||
if ast.op is Ops.COPY:
|
||||
if isinstance(ubufs[0], MultiBuffer) and isinstance(ubufs[1], MultiBuffer) and ast.arg is None:
|
||||
|
||||
@@ -539,7 +539,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@property
|
||||
def buf_uop(self) -> UOp:
|
||||
if self.op is Ops.BUFFER: return self
|
||||
if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
|
||||
if self.op is Ops.MSELECT: return self.src[0].buf_uop
|
||||
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
|
||||
return self.src[0].base
|
||||
@property
|
||||
|
||||
@@ -47,10 +47,13 @@ def validate_kernel(k:UOp):
|
||||
|
||||
assign_spec = PatternMatcher([
|
||||
# KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN)), name="k"), validate_kernel),
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT)), name="k"), validate_kernel),
|
||||
|
||||
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
||||
|
||||
# MSELECT picks a multi buffer
|
||||
(UPat(Ops.MSELECT, name="m"), lambda m: m.arg >= 0 and m.arg < len(m.src[0].device))
|
||||
])
|
||||
|
||||
# *** this is the spec of a Tensor in UOp ***
|
||||
@@ -79,7 +82,6 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
|
||||
# COPY/ALLREDUCE
|
||||
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda copy,x: copy.dtype == x.dtype),
|
||||
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"),)), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
|
||||
(UPat(Ops.MSELECT, name="m"), lambda m: m.arg >= 0 and m.arg < len(m.src[0].device))
|
||||
])
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
Reference in New Issue
Block a user