mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use Ops.GROUP instead of Ops.NOOP for merging stores (#12912)
* use Ops.GROUP instead of Ops.NOOP for merging stores * fs noop
This commit is contained in:
@@ -109,7 +109,7 @@ def cat_after_store(cat:UOp, data:UOp, sto:UOp):
|
||||
for s in cat.src:
|
||||
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
|
||||
offset += s.dtype.count
|
||||
return UOp(Ops.NOOP, src=tuple(ret))
|
||||
return UOp.group(*ret)
|
||||
|
||||
def gep_on_store(gep:UOp, st:UOp, sto:UOp):
|
||||
# NOTE: we need to invert the gep here, but it may be an expanding gep
|
||||
@@ -179,7 +179,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
||||
|
||||
# if it wasn't split, we return None. otherwise we CAT them
|
||||
if len(ret) <= 1: return None
|
||||
return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp(Ops.NOOP, src=tuple(ret))
|
||||
return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
|
||||
|
||||
def image_fixup(ls:UOp):
|
||||
# normal image load or store, with the CAST from expand_index
|
||||
|
||||
@@ -143,7 +143,7 @@ class CStyleLanguage(Renderer):
|
||||
c: defaultdict[str, int] = defaultdict(int)
|
||||
name = "test"
|
||||
for u in uops:
|
||||
if u.op is Ops.NOOP: continue
|
||||
if u.op in {Ops.NOOP, Ops.GROUP}: continue
|
||||
if u.op is Ops.AFTER:
|
||||
r[u] = r[u.src[0]]
|
||||
continue
|
||||
|
||||
@@ -168,7 +168,7 @@ class LLVMRenderer(Renderer):
|
||||
|
||||
name = "test"
|
||||
for u in uops:
|
||||
if u.op is Ops.NOOP: continue
|
||||
if u.op in {Ops.NOOP, Ops.GROUP}: continue
|
||||
if u.op is Ops.AFTER:
|
||||
r[u] = r[u.src[0]]
|
||||
continue
|
||||
|
||||
@@ -173,7 +173,7 @@ class NIRRenderer(Renderer):
|
||||
self.param_idx, ranges = 0, []
|
||||
|
||||
for u in uops:
|
||||
if u.op == Ops.NOOP or u.op == Ops.INDEX: pass
|
||||
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
|
||||
elif u.op is Ops.AFTER:
|
||||
self.r[u] = self.r[u.src[0]]
|
||||
elif u.op == Ops.SINK:
|
||||
|
||||
@@ -183,7 +183,7 @@ class PTXRenderer(Renderer):
|
||||
|
||||
name = "test"
|
||||
for u in uops:
|
||||
if u.op is Ops.NOOP: continue
|
||||
if u.op in {Ops.NOOP, Ops.GROUP}: continue
|
||||
if u.op is Ops.AFTER:
|
||||
self.r[u] = self.r[u.src[0]]
|
||||
continue
|
||||
|
||||
@@ -52,7 +52,7 @@ class PythonProgram:
|
||||
loop_ends: dict[int, int] = {}
|
||||
while i < len(self.uops):
|
||||
uop, dtype, idp, arg = self.uops[i]
|
||||
void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
|
||||
void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
|
||||
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
|
||||
@@ -60,7 +60,7 @@ class PythonProgram:
|
||||
loop_ends[idp[1]] = i
|
||||
i = idp[1]
|
||||
continue
|
||||
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP):
|
||||
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP):
|
||||
# in the python emulator, the warp is always in sync
|
||||
i += 1
|
||||
continue
|
||||
|
||||
@@ -250,7 +250,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
|
||||
|
||||
# elementwise ops keep the shape the same. all inputs with shape must match
|
||||
if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.SINK, Ops.ALLREDUCE}):
|
||||
if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE}):
|
||||
# TODO: remove this hack for 3 op assign
|
||||
input_shapes = [x._shape for x in (self.src[:2] if self.op is Ops.ASSIGN else self.src) if x._shape is not None]
|
||||
if len(input_shapes) == 0: return None
|
||||
@@ -1233,6 +1233,7 @@ sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqr
|
||||
pm_pyrender = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")),
|
||||
(UPat(Ops.END, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.end({', '.join([y.arg for y in x.src[1:]])})")),
|
||||
(UPat(Ops.CAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.cast({x.dtype})")),
|
||||
(UPat(Ops.BITCAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.bitcast({x.dtype})")),
|
||||
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"),
|
||||
|
||||
@@ -147,7 +147,7 @@ program_spec = PatternMatcher([
|
||||
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
|
||||
(UPat(Ops.BARRIER, dtypes.void), lambda: True), # BARRIERs can also happen at the end of loops
|
||||
|
||||
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
|
||||
(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
|
||||
])+shared_spec
|
||||
|
||||
# ***** UOp spec in kernel graph *****
|
||||
@@ -177,6 +177,9 @@ full_spec = PatternMatcher([
|
||||
# any END
|
||||
(UPat(Ops.END), lambda: True),
|
||||
|
||||
# NOOP in the full spec
|
||||
(UPat(Ops.NOOP), lambda: True),
|
||||
|
||||
# Invalid must have type Index
|
||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
|
||||
# where on index in rhs position is fine
|
||||
|
||||
Reference in New Issue
Block a user