use Ops.GROUP instead of Ops.NOOP for merging stores (#12912)

* use Ops.GROUP instead of Ops.NOOP for merging stores

* fs noop
This commit is contained in:
George Hotz
2025-10-25 12:26:12 +08:00
committed by GitHub
parent b4f6a2c7a3
commit 6415e3e8a7
8 changed files with 14 additions and 10 deletions

View File

@@ -109,7 +109,7 @@ def cat_after_store(cat:UOp, data:UOp, sto:UOp):
for s in cat.src:
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
offset += s.dtype.count
-return UOp(Ops.NOOP, src=tuple(ret))
+return UOp.group(*ret)
def gep_on_store(gep:UOp, st:UOp, sto:UOp):
# NOTE: we need to invert the gep here, but it may be an expanding gep
@@ -179,7 +179,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# if it wasn't split, we return None. otherwise we CAT them
if len(ret) <= 1: return None
-return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp(Ops.NOOP, src=tuple(ret))
+return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
def image_fixup(ls:UOp):
# normal image load or store, with the CAST from expand_index

View File

@@ -143,7 +143,7 @@ class CStyleLanguage(Renderer):
c: defaultdict[str, int] = defaultdict(int)
name = "test"
for u in uops:
-if u.op is Ops.NOOP: continue
+if u.op in {Ops.NOOP, Ops.GROUP}: continue
if u.op is Ops.AFTER:
r[u] = r[u.src[0]]
continue

View File

@@ -168,7 +168,7 @@ class LLVMRenderer(Renderer):
name = "test"
for u in uops:
-if u.op is Ops.NOOP: continue
+if u.op in {Ops.NOOP, Ops.GROUP}: continue
if u.op is Ops.AFTER:
r[u] = r[u.src[0]]
continue

View File

@@ -173,7 +173,7 @@ class NIRRenderer(Renderer):
self.param_idx, ranges = 0, []
for u in uops:
-if u.op == Ops.NOOP or u.op == Ops.INDEX: pass
+if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
elif u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]]
elif u.op == Ops.SINK:

View File

@@ -183,7 +183,7 @@ class PTXRenderer(Renderer):
name = "test"
for u in uops:
-if u.op is Ops.NOOP: continue
+if u.op in {Ops.NOOP, Ops.GROUP}: continue
if u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]]
continue

View File

@@ -52,7 +52,7 @@ class PythonProgram:
loop_ends: dict[int, int] = {}
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
-void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
+void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
@@ -60,7 +60,7 @@ class PythonProgram:
loop_ends[idp[1]] = i
i = idp[1]
continue
-if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP):
+if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP):
# in the python emulator, the warp is always in sync
i += 1
continue

View File

@@ -250,7 +250,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
# elementwise ops keep the shape the same. all inputs with shape must match
-if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.SINK, Ops.ALLREDUCE}):
+if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE}):
# TODO: remove this hack for 3 op assign
input_shapes = [x._shape for x in (self.src[:2] if self.op is Ops.ASSIGN else self.src) if x._shape is not None]
if len(input_shapes) == 0: return None
@@ -1233,6 +1233,7 @@ sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqr
pm_pyrender = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")),
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")),
(UPat(Ops.END, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.end({', '.join([y.arg for y in x.src[1:]])})")),
(UPat(Ops.CAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.cast({x.dtype})")),
(UPat(Ops.BITCAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.bitcast({x.dtype})")),
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"),

View File

@@ -147,7 +147,7 @@ program_spec = PatternMatcher([
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
(UPat(Ops.BARRIER, dtypes.void), lambda: True), # BARRIERs can also happen at the end of loops
-(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
+(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
])+shared_spec
# ***** UOp spec in kernel graph *****
@@ -177,6 +177,9 @@ full_spec = PatternMatcher([
# any END
(UPat(Ops.END), lambda: True),
+# NOOP in the full spec
+(UPat(Ops.NOOP), lambda: True),
# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
# where on index in rhs position is fine