remove move_mask from the devectorizer (#9511)

* remove move_mask from the devectorizer

* add (wrong) ptx

* reason

* enable index addition in PTX, we won't have the INDEX anyways

* space
Author: qazal
Date:   2025-03-20 11:53:12 +08:00 (committed by GitHub)
parent 9302738263
commit 0b20f91ce7
10 changed files with 30 additions and 25 deletions
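
The net effect of the change: the gate (the load/store mask) now stays as the third source of the INDEX op all the way to rendering, instead of being hoisted onto the LOAD/STORE by the devectorizer's move_mask. A minimal sketch of the two graph shapes, built with the same constructors the tests below use (shape only, names illustrative, not run through the full pipeline):

from tinygrad import dtypes
from tinygrad.ops import UOp, Ops

buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
idx, gate, alt = UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 0)

# before: move_mask rewrote to LOAD(INDEX(buf, idx), alt, gate)
old_shape = UOp.load(UOp(Ops.INDEX, buf.dtype, (buf, idx)), alt, gate, dtype=dtypes.int)
# after: the gate stays on the INDEX and LOAD carries only the alt value
new_shape = UOp.load(UOp(Ops.INDEX, buf.dtype, (buf, idx, gate)), alt, dtype=dtypes.int)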


@@ -53,6 +53,7 @@ class TestPTXFailures(unittest.TestCase):
     ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])
 
+  @unittest.skip("INDEX can only have a gate ALU parent, not an IF")
   def test_gated_store_with_if(self):
     a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
     gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)


@@ -46,7 +46,7 @@ class TestHelpers(unittest.TestCase):
 
 class TestValidIdxSimplification(unittest.TestCase):
   def check(self, load, sidx, svalid):
     load = full_graph_rewrite(load.sink()).src[0]
-    idx, valid = load.src[0].src[1], load.src[2]
+    idx, valid = load.src[0].src[1], load.src[0].src[2]
     self.assertEqual(idx.render(simplify=False), sidx)
     self.assertEqual(valid.render(simplify=False), svalid)
@@ -133,7 +133,7 @@ class TestImageSimplification(unittest.TestCase):
     idx0, idx1 = idx.src[0], idx.src[1]
     self.assertEqual(idx0.render(simplify=False), sidx0)
     self.assertEqual(idx1.render(simplify=False), sidx1)
-    if svalid is not None: self.assertEqual(load.src[2].render(simplify=False), svalid)
+    if svalid is not None: self.assertEqual(load.src[0].src[2].render(simplify=False), svalid)
 
   def test_idx_gt_c(self):
     # (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid


@@ -249,11 +249,6 @@ load_store_indexing = PatternMatcher([
     UPat.var("val"))), delete_redundant_gates),
 ])
 
-def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
-  # this moves the mask from the indexing to the load/store op for rendering
-  nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
-  return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
-
 pm_render = PatternMatcher([
   # for rendering, we use explicit VECTORIZE
   (UPat(Ops.CONST, name='c'),
@@ -261,12 +256,11 @@ pm_render = PatternMatcher([
   (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
   (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
   (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
-  # move masks of loads/stores
-  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask")))
-    .or_casted("cast"),), allow_any_len=True, name="x"), move_mask),
+  # give any loads that are masked an alt value
+  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), name="x"), lambda x: x.replace(src=x.src+(x.const_like(0),))),
   # gate any stores that aren't gated with ifs
-  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
-   lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store"),
+   lambda store,idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
 ])
 
 # *** uop graph ***
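
With move_mask gone, pm_render only has to append an alt value to masked loads and wrap ungated stores in an IF whose condition is read out of the INDEX. A sketch of what the new store-gating rule produces (shape only, assuming the UOp constructors used elsewhere in this diff):

from tinygrad import dtypes
from tinygrad.ops import UOp, Ops

buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
idx, gate = UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True)
st = UOp.store(UOp(Ops.INDEX, buf.dtype, (buf, idx, gate)), UOp.const(dtypes.int, 1))
# STORE(bidx, val) -> STORE(bidx, val, IF(gate)), with the gate read from bidx.src[2]
gated = UOp(Ops.STORE, src=st.src+(UOp(Ops.IF, src=(st.src[0].src[2],)),))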


@@ -735,7 +735,7 @@ class UPat(MathTrait):
   @staticmethod
   def any(*src): return UPatAny(src=src)
-  def or_casted(self, name:str|None=None): return UPat.any(self, UPat(Ops.CAST, name=name, src=(self,)))
+  def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
   @staticmethod
   @functools.lru_cache(None)
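
This one-line fix is what lets the renderer patterns in this diff capture a name through an optional CAST: previously or_casted(name) attached the name only to the CAST branch, so a match of the bare (uncasted) alternative left the name unbound. A usage sketch of the pattern shape the C-style renderer below relies on:

from tinygrad.ops import UPat, Ops

# binds "bidx" whether the INDEX is used directly or sits behind a CAST
pat = UPat(Ops.INDEX).or_casted("bidx")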


@@ -43,9 +43,10 @@ base_rewrite = PatternMatcher([
   # default const render
   (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
   # new load/store
-  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
    lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
-  (UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
+  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted('bidx'), UPat.var("var"))),
+   lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
   (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
   (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
   # alu/gep
@@ -235,7 +236,7 @@ class OpenCLRenderer(CStyleLanguage):
   string_rewrite = PatternMatcher([
     (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
     # load/store image (OpenCL)
-    (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
+    (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
      lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
     (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
      lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
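
The gated-load rule now digs the gate out of the (possibly casted) INDEX instead of a third LOAD source. A hypothetical standalone helper equivalent to that pattern (ctx is assumed to map UOps to their rendered strings, as in the renderer):

from tinygrad.ops import UOp, Ops

def render_gated_load(ctx, load:UOp) -> str:
  bidx, var = load.src[0], load.src[1]
  index = bidx.src[0] if bidx.op is Ops.CAST else bidx  # peel the optional CAST
  gate = index.src[2]                                   # the gate is INDEX's third source
  return f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"       # e.g. (gate?*(data0+idx):alt)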


@@ -56,7 +56,8 @@ base_rewrite = PatternMatcher([
   # memory load/store
   (UPat(Ops.INDEX, name="x"), lambda ctx,x:
    f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
-  (UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask:
+  (UPat(Ops.LOAD, src=(UPat.or_casted(name='idx', self=UPat(src=(UPat(), UPat(), UPat.var('mask')))), UPat.var('alt')), name="x"),
+   lambda ctx,x,idx,alt,mask:
    f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
    f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
    f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
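
Rendered, the masked-load pattern above expands into a small branch diamond per load. Roughly, for a load named %v (register names hypothetical; the hunk is cut off before the exit block, which presumably merges %v_yes with the alt value):

#   br label %v_entry
# v_entry:
#   br i1 %mask, label %v_load, label %v_exit
# v_load:
#   %v_yes = load <ty>, ptr %idx
# ... (%v_exit then selects %v_yes or %alt)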


@@ -49,6 +49,11 @@ ptx_matcher = PatternMatcher([
   # load/store use pointer arithmetic, and the cast does nothing
   (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
   (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
+  # move mask from INDEX to the load/store to enable pointer arithmetic
+  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("gate"))), UPat.var("alt"))),
+   lambda buf,idx,gate,alt: UOp(Ops.LOAD, alt.dtype, (buf.index(idx), alt, gate))),
+  (UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat())), UPat.var("val"), UPat.var("gate"))),
+   lambda buf,idx,val,gate: UOp.store(buf.index(idx), val, gate)),
   # ptx shr and shl instructions require y to be uint
   (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
   (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
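
PTX is the exception to keeping the gate on INDEX: this backend lowers INDEX to plain int64 pointer arithmetic (first rule in the hunk above), where a third source has no meaning, so these rules hoist the gate back onto the LOAD/STORE just before rendering. A shape-only sketch of the LOAD case:

from tinygrad import dtypes
from tinygrad.ops import UOp, Ops

buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
idx, gate, alt = UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 0)

src = UOp.load(UOp(Ops.INDEX, buf.dtype, (buf, idx, gate)), alt, dtype=dtypes.int)
# what the rule above produces: gate back on the LOAD, INDEX reduced to <buf, idx>
dst = UOp(Ops.LOAD, alt.dtype, (buf.index(idx), alt, gate))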


@@ -20,7 +20,7 @@ def packed_store(bidx:UOp, var:UOp):
 def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
   div_idx = bidx.src[1]//(4//dtype.itemsize)
   shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
-  if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=dtypes.uint32, arg=root.arg)
+  if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2])), var, dtype=dtypes.uint32, arg=root.arg)
   else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=dtypes.uint32, arg=root.arg)
   val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
   return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
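
The div/shift arithmetic in packed_load is easiest to check with concrete numbers. A plain-Python rerun of div_idx and shift_am for a 1-byte dtype, where four elements pack into each 32-bit word:

itemsize = 1  # e.g. dtypes.char
for i in range(5):
  div_idx = i // (4 // itemsize)                   # which u32 word holds element i
  shift_am = (i % (4 // itemsize)) * 8 * itemsize  # bit offset of element i in that word
  print(i, div_idx, shift_am)  # 0->(0,0), 1->(0,8), 2->(0,16), 3->(0,24), 4->(1,0)
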
@@ -31,7 +31,7 @@ wgsl_matcher = PatternMatcher([
   (UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
    lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
   (UPat(Ops.LOAD, name="l", src=(UPat.var("b"),)), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype) else None),
-  (UPat(Ops.LOAD, name="l", src=(UPat.var("b"), UPat.cvar("c"), UPat())),
+  (UPat(Ops.LOAD, name="l", src=(UPat.var("b"), UPat.cvar("c"))),
    lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
   (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None),
   # TODO: why is this needed, and only for this MUL order
@@ -64,13 +64,13 @@ class WGSLRenderer(CStyleLanguage):
     (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
       if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
     (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
-    (UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("g")),lambda ctx,b,v,g: f"select({ctx[v]}, {ctx.render_load(ctx[b], b.dtype)}, {ctx[g]})"),
+    (UPat.load(UPat.var("b"), UPat.cvar("v")),lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"),
     (UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
     (UPat.store(UPat.var("b"), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
      # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
      f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
      else f"{ctx[b]} = {ctx[v]};"),
-    (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
+    (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx")), allow_any_len=True),
      lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
     # fix nan check: 'a != a -> is_nan()'
     (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"(min({ctx[a]}, 1.0) == 1.0 && max({ctx[a]}, -1.0) == -1.0)"),


@@ -17,7 +17,7 @@ def _load(m, i):
   return m[i]
 
 def load(inp, j=0):
-  if len(inp) == 3: return [_load(m, x+j if x is not None else None) if gate else default for (m,x),default,gate in zip(*inp)]
+  if len(inp) == 2: return [_load(m, x+j if x is not None else None) if gate else default for (m,x,gate),default in zip(*inp)]
   return [_load(m, x+j if x is not None else None) for m,x in inp[0]]
 
 def _store(m, i, v):
@@ -80,13 +80,14 @@ class PythonProgram:
         elif uop is Ops.DEFINE_ACC:
           ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
         elif uop is Ops.INDEX:
-          ret = []
+          ret:list = []
           if isinstance(dtp[0], ImageDType):
             for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
               if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append((m, None))
               else: ret.append((m, ox*4 + oy*dtp[0].shape[1]*4))
           else:
             for m,o in zip(inp[0], inp[1]): ret.append((m,o))
+          if len(inp) == 3: ret = [(m,o,g) for (m,o),g in zip(ret, inp[2])] # set the gate last
           ul[i] = ret
         elif uop is Ops.CAST and isinstance(dtype, PtrDType):
           ul[i] = inp[0]
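
So a gated INDEX in the Python emulator now yields (memory, offset, gate) triples per lane, and the load() change above dispatches on the input count: two inputs mean the index tuples carry their own gate. A simplified sketch of that consuming side (it drops the None/out-of-bounds handling the real code keeps):

def load(inp, j=0):
  # gated: inp == ([(mem, off, gate), ...], [default, ...])
  if len(inp) == 2: return [m[x+j] if gate else d for (m,x,gate),d in zip(*inp)]
  # ungated: inp == ([(mem, off), ...],)
  return [m[x+j] for m,x in inp[0]]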


@@ -74,12 +74,14 @@ spec = PatternMatcher([
   # **** new style load/store ****
   # INDEX is used in new style load/store
   # INDEX takes a <buf, alu, gate?>
   (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
   (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
-  # LOAD takes a <bufidx, alt?, gate?, barrier?>
+  # LOAD takes a <bufidx, alt?, barrier?>
   (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
   (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
-  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat.var("alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat.var("alt")), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
   # STORE takes a <bufidx, val, gate?>
   (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
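
Taken together, these are the LOAD shapes the spec now admits (summary sketch; the STORE rules continue past the end of this hunk):

# LOAD(INDEX(buf, idx))              -- plain load
# LOAD(INDEX(buf, idx), IF|BARRIER)  -- load ordered after an if/barrier
# LOAD(INDEX(buf, idx, gate), alt)   -- gated load; LOAD dtype must equal alt dtype
# (a CAST may wrap the INDEX in each case)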