removed unused named arg in rules [pr] (#15414)

This commit is contained in:
chenyu
2026-03-22 09:25:46 -04:00
committed by GitHub
parent 2363bceb47
commit bcc08307da
8 changed files with 27 additions and 27 deletions

View File

@@ -286,12 +286,12 @@ pm_render = PatternMatcher([
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
# Where after gated load becomes alt value
# NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint)
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
UPat.var("a")), lambda c,idx,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c").logical_not()).or_casted(),),
allow_any_len=True, name="l").or_casted()), lambda c,idx,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c").logical_not()).or_casted(),),
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***

View File

@@ -127,7 +127,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
si_lowerer = PatternMatcher([
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
(UPat(Ops.COPY), lambda ctx: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))),
(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="cf"), lambda ctx,cf: EncDec(cf, ctx[0].nbytes, ctx[0].device)),

View File

@@ -487,8 +487,8 @@ class AMDHIPRenderer(CStyleLanguage):
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]},"
f" {fp8_index(x.src[0].dtype)}, {fp8_index(x.src[0].dtype)}, 0, 0, 0, 0)" if x.arg[1][2] == 128 else None),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]}, 0, 0, 0)"),
(UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",),
lambda ctx,x,y: f"f32_to_fp8({ctx[x.src[0]]}, {fp8_index(x.dtype)})"),
(UPat(Ops.CAST, dtypes.fp8s, (UPat(dtype=dtypes.float),), name="x",),
lambda ctx,x: f"f32_to_fp8({ctx[x.src[0]]}, {fp8_index(x.dtype)})"),
(UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",),
lambda ctx,x,y: f"__builtin_amdgcn_cvt_f32_{('fp8', 'bf8')[fp8_index(y.dtype)]}((unsigned int){ctx[x.src[0]]}, 0)"),
]) + base_rewrite

View File

@@ -224,17 +224,17 @@ class AMDLLVMRenderer(LLVMRenderer):
(UPat(tuple(llvm_intrinsics), name="x"),
lambda ctx, x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
(UPat(Ops.BARRIER), lambda ctx: barrier),
(UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",), lambda ctx,x,y:
(UPat(Ops.CAST, dtypes.fp8s, (UPat(dtype=dtypes.float),), name="x",), lambda ctx,x:
f" {ctx[x]} = call i8 @f32_to_fp8({ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i1 {'1' if x.dtype == dtypes.fp8e5m2 else '0'})"),
(UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",), lambda ctx,x,y:
f" {ctx[x.src[0]]}_i32 = zext i8 {ctx[x.src[0]]} to i32\n"
f" {ctx[x]} = call float @llvm.amdgcn.cvt.f32.{'bf8' if y.dtype == dtypes.fp8e5m2 else 'fp8'}(i32 {ctx[x.src[0]]}_i32, i32 0)"),
]) + base_rewrite
extra_matcher = LLVMRenderer.extra_matcher + create_non_native_float_pats(dtypes.fp8s) + PatternMatcher([
(UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
(UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
(UPat(Ops.CAST, dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
lambda y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
(UPat(Ops.CAST, dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
lambda y: UOp(Ops.VECTORIZE, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
# amd llvm intrinsics llvm.log2/llvm.exp2 don't support double
(UPat(Ops.LOG2, dtype=dtypes.double, src=(UPat.var("d"),)), xlog2),
(UPat(Ops.EXP2, dtype=dtypes.double, src=(UPat.var("d"),)), xexp2),

View File

@@ -144,8 +144,8 @@ class NIRRenderer(Renderer):
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val")), allow_any_len=True, name="x"),
lambda ctx,x,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val")), allow_any_len=True),
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),

View File

@@ -317,7 +317,7 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
# copy on CONST is CONST
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
# hack if a noop turned to a const
(UPat(Ops.NOOP, src=(UPat.cvar("c"),), name="noop"), lambda c,noop: c),
(UPat(Ops.NOOP, src=(UPat.cvar("c"),)), lambda c: c),
# mstack on CONST is CONST
(UPat(Ops.MSTACK, src=(UPat.var("s"),), allow_any_len=True).f(Ops.INDEX, allow_any_len=True),
lambda s: UOp.const(c.dtype, c.arg) if (c:=s.base).op is Ops.CONST else None),

View File

@@ -69,9 +69,9 @@ shared_spec = PatternMatcher([
# ***** UOp spec in the Tensor graph *****
movement_ops = PatternMatcher([
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.weakint))), lambda mv,x: True),
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint))), lambda mv,x: True),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat(dtype=dtypes.weakint))), lambda: True),
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint))), lambda: True),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
# inputs to movement ops
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.weakint), lambda: True),
@@ -213,7 +213,7 @@ kernel_spec = PatternMatcher([
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True),
# bufferize can be on anything
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True), lambda: True),
# reduce must be on ranges
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.weakint, dtypes.int) for y in x.src[1:])),
@@ -286,7 +286,7 @@ full_spec = PatternMatcher([
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint), (UPat(), UPat()), arg=None), lambda: True),
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
(UPat((Ops.MSELECT, Ops.MSTACK)), lambda: True),
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True),

View File

@@ -28,19 +28,19 @@ z3_renderer = PatternMatcher([
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
(UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx:
create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
(UPat((Ops.LOAD, Ops.INDEX), dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
(UPat((Ops.LOAD, Ops.INDEX), dtypes.bool), lambda ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
# constants
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x,ctx: (z3.Int("Invalid", ctx=ctx[0]), None)),
(UPat(Ops.CONST, arg=Invalid), lambda ctx: (z3.Int("Invalid", ctx=ctx[0]), None)),
(UPat(Ops.CONST, dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0]), None)),
(UPat(Ops.CONST, dtypes.bool, name="x"), lambda x,ctx: (z3.BoolVal(x.arg, ctx=ctx[0]), None)),
# casts from floats create new variables
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx:
create_bounded(f"cast{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
# A comparison between floats introduces a new bool variable
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx: (z3.Bool(f"float_cmp{len(ctx[1])}", ctx=ctx[0]), None)),
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats)), lambda ctx: (z3.Bool(f"float_cmp{len(ctx[1])}", ctx=ctx[0]), None)),
# casts from bool/int to int/bool
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,),src=(UPat.var("x", dtypes.bool),), name="c"), lambda x,c,ctx: (z3.If(ctx[1][x], 1, 0), None)),
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat.var("x", dtypes.ints+(dtypes.weakint,)),), name="c"), lambda x,c,ctx: (ctx[1][x], None)),
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,),src=(UPat.var("x", dtypes.bool),)), lambda x,ctx: (z3.If(ctx[1][x], 1, 0), None)),
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat.var("x", dtypes.ints+(dtypes.weakint,)),)), lambda x,ctx: (ctx[1][x], None)),
(UPat(Ops.CAST, dtypes.bool, name="x"), lambda x,ctx: (ctx[1][x.src[0]]!=0, None)),
(UPat(GroupOp.ALU, name="x"), lambda x,ctx: (z3_alu[x.op](*(ctx[1][s] for s in x.src)), None)),
])