diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index e36c056ef8..0b8814577d 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -109,7 +109,7 @@ def cat_after_store(cat:UOp, data:UOp, sto:UOp): for s in cat.src: ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:])) offset += s.dtype.count - return UOp(Ops.NOOP, src=tuple(ret)) + return UOp.group(*ret) def gep_on_store(gep:UOp, st:UOp, sto:UOp): # NOTE: we need to invert the gep here, but it may be an expanding gep @@ -179,7 +179,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): # if it wasn't split, we return None. otherwise we CAT them if len(ret) <= 1: return None - return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp(Ops.NOOP, src=tuple(ret)) + return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret) def image_fixup(ls:UOp): # normal image load or store, with the CAST from expand_index diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index b00b7d0784..ee6c0d387a 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -143,7 +143,7 @@ class CStyleLanguage(Renderer): c: defaultdict[str, int] = defaultdict(int) name = "test" for u in uops: - if u.op is Ops.NOOP: continue + if u.op in {Ops.NOOP, Ops.GROUP}: continue if u.op is Ops.AFTER: r[u] = r[u.src[0]] continue diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 04fa35c4d9..ce053157cb 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -168,7 +168,7 @@ class LLVMRenderer(Renderer): name = "test" for u in uops: - if u.op is Ops.NOOP: continue + if u.op in {Ops.NOOP, Ops.GROUP}: continue if u.op is Ops.AFTER: r[u] = r[u.src[0]] continue diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index cb7779458d..116004fa05 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -173,7 +173,7 @@ class NIRRenderer(Renderer): self.param_idx, ranges = 0, [] for u in uops: - if u.op == Ops.NOOP or u.op == Ops.INDEX: pass + if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass elif u.op is Ops.AFTER: self.r[u] = self.r[u.src[0]] elif u.op == Ops.SINK: diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 5a68aa632d..6882b736ab 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -183,7 +183,7 @@ class PTXRenderer(Renderer): name = "test" for u in uops: - if u.op is Ops.NOOP: continue + if u.op in {Ops.NOOP, Ops.GROUP}: continue if u.op is Ops.AFTER: self.r[u] = self.r[u.src[0]] continue diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 342068302f..0815596021 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -52,7 +52,7 @@ class PythonProgram: loop_ends: dict[int, int] = {} while i < len(self.uops): uop, dtype, idp, arg = self.uops[i] - void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE} + void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE} inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops] dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops] if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp) @@ -60,7 +60,7 @@ class PythonProgram: loop_ends[idp[1]] = i i = idp[1] continue - if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP): + if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP): # in the python emulator, the warp is always in sync i += 1 continue diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index f6328ffea1..bcdf6c8bd6 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -250,7 +250,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return tuple(1 if i in axis_arg else s for i,s in enumerate(ps)) # elementwise ops keep the shape the same. all inputs with shape must match - if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.SINK, Ops.ALLREDUCE}): + if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE}): # TODO: remove this hack for 3 op assign input_shapes = [x._shape for x in (self.src[:2] if self.op is Ops.ASSIGN else self.src) if x._shape is not None] if len(input_shapes) == 0: return None @@ -1233,6 +1233,7 @@ sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqr pm_pyrender = PatternMatcher([ (UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")), (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")), + (UPat(Ops.END, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.end({', '.join([y.arg for y in x.src[1:]])})")), (UPat(Ops.CAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.cast({x.dtype})")), (UPat(Ops.BITCAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.bitcast({x.dtype})")), (UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"), diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 2045e590a4..bf29a5c29f 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -147,7 +147,7 @@ program_spec = PatternMatcher([ (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local (UPat(Ops.BARRIER, dtypes.void), lambda: True), # BARRIERs can also happen at the end of loops - (UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True), + (UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True), ])+shared_spec # ***** UOp spec in kernel graph ***** @@ -177,6 +177,9 @@ full_spec = PatternMatcher([ # any END (UPat(Ops.END), lambda: True), + # NOOP in the full spec + (UPat(Ops.NOOP), lambda: True), + # Invalid must have type Index (UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index), # where on index in rhs position is fine