Bool cast to cmpne (#14544)

* test

* rm in llvmir

* rm in ptx and nir

* hmmmm

* rm in decompositions

* skip tests

* add test

* just this

* rm comment

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
ttomsa
2026-02-23 15:31:36 +00:00
committed by GitHub
parent 806581f807
commit 0366474089
7 changed files with 6 additions and 10 deletions

View File

@@ -78,7 +78,9 @@ class TestCStyleFailures(unittest.TestCase):
def test_repeat_add(self): self._test_src_strip_paren(Ops.ADD)
def test_repeat_mul(self): self._test_src_strip_paren(Ops.MUL)
def test_repeat_xor(self): self._test_src_strip_paren(Ops.XOR)
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, WGSLRenderer), "wgsl ends up with '(' * 5")
def test_repeat_or(self): self._test_src_strip_paren(Ops.OR)
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, WGSLRenderer), "wgsl ends up with '(' * 5")
def test_repeat_and(self): self._test_src_strip_paren(Ops.AND)
def test_repeat_sub(self): self._test_src_strip_paren(Ops.SUB, should_strip_paren=False)

View File

@@ -390,6 +390,9 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)")
self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "True")
def test_cast_bool(self):
self.helper_test_variable(Variable("a", 0, 10).cast(dtypes.bool), 0, 1, "a!=0")
def test_lt_sum_remove(self):
self.helper_test_variable(Variable("a", 0, 6) + 2 < 3, 0, 1, "(a<1)")

View File

@@ -97,9 +97,6 @@ base_rewrite = PatternMatcher([
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
# unary/binary/ternary ops
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
# rewrite cast to bool to CMPNE 0
(UPat(Ops.CAST, name="x", dtype=dtypes.bool),
lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][Ops.CMPNE]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, zeroinitializer"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.TRUNC, name="x"),
lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.trunc.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),

View File

@@ -26,7 +26,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
if isinstance(it, PtrDType) and ot == dtypes.long: return src
if ot == dtypes.bool: return nalu(b, c(it, False)+'ne'+('u' if c(it) == 'f' else ''), src, nimm(b, 0, it))
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):

View File

@@ -99,8 +99,6 @@ string_rewrite = PatternMatcher([
(UPat(Ops.BITCAST, name="x", src=(UPat.var("a"),), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
(UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"),)),
lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"),
(UPat(Ops.CAST, name="x", dtype=dtypes.bool, src=(UPat.var("a"),)),
lambda ctx, x, a: f"setp.ne.b{ctx.types[a.dtype][1:]} {ctx.r[x]}, {ctx.r[a]}, {render_val(0, a.dtype)};"),
(UPat(Ops.CAST, name="x", src=(UPat.var("a"),)),
lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.cast_types[x.dtype]}.{ctx.cast_types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"),
# store / gated load / load

View File

@@ -335,7 +335,6 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
case Ops.CAST if dt in dtypes.floats:
small = (a1.eq(0) & (a0 >= 0)) | (a1.eq(-1) & (a0 < 0))
return small.where(a0.cast(dt), ((a1.cast(dtypes.float32) * (2**32)) + a0.bitcast(dtypes.uint).cast(dtypes.float32)).cast(dt))
case Ops.CAST if dt == dtypes.bool: return a0.ne(UOp.const(a0.dtype, 0)) | a1.ne(UOp.const(a1.dtype, 0))
case Ops.CAST: return a0.bitcast(dtypes.uint).cast(dt)
case Ops.BITCAST: return a0.bitcast(dt), a1.bitcast(dt)
case Ops.SHL:

View File

@@ -105,6 +105,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
# b.cast(a).cast(b) -> b if a preserves all values in b
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_lossless_cast(b.dtype, a.dtype) else None),
(UPat.var("x").cast(dtypes.bool), lambda x: x != 0),
# ** pow **
(UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
# positive const ** x
@@ -395,9 +396,6 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
# reorder ALU/VECTORIZE
(UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
# ** self folding **
# x!=0 -> (bool)x
(UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
# ** where **
# # fold nested where with same condition: in cond.where(t,f), cond.where(a,b)->a in t, ->b in f
# (UPat.var("cond").where(UPat.var("t"), UPat.var("f")), fold_where_closure),