mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Bool cast to cmpne (#14544)
* test * rm in llvmir * rm in ptx and nir * hmmmm * rm in decompositions * skip tests * add test * just this * rm comment --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -78,7 +78,9 @@ class TestCStyleFailures(unittest.TestCase):
|
||||
def test_repeat_add(self): self._test_src_strip_paren(Ops.ADD)
|
||||
def test_repeat_mul(self): self._test_src_strip_paren(Ops.MUL)
|
||||
def test_repeat_xor(self): self._test_src_strip_paren(Ops.XOR)
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, WGSLRenderer), "wgsl ends up with '(' * 5")
|
||||
def test_repeat_or(self): self._test_src_strip_paren(Ops.OR)
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, WGSLRenderer), "wgsl ends up with '(' * 5")
|
||||
def test_repeat_and(self): self._test_src_strip_paren(Ops.AND)
|
||||
def test_repeat_sub(self): self._test_src_strip_paren(Ops.SUB, should_strip_paren=False)
|
||||
|
||||
|
||||
@@ -390,6 +390,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)")
|
||||
self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "True")
|
||||
|
||||
def test_cast_bool(self):
|
||||
self.helper_test_variable(Variable("a", 0, 10).cast(dtypes.bool), 0, 1, "a!=0")
|
||||
|
||||
def test_lt_sum_remove(self):
|
||||
self.helper_test_variable(Variable("a", 0, 6) + 2 < 3, 0, 1, "(a<1)")
|
||||
|
||||
|
||||
@@ -97,9 +97,6 @@ base_rewrite = PatternMatcher([
|
||||
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
|
||||
# unary/binary/ternary ops
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
# rewrite cast to bool to CMPNE 0
|
||||
(UPat(Ops.CAST, name="x", dtype=dtypes.bool),
|
||||
lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][Ops.CMPNE]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, zeroinitializer"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
(UPat(Ops.TRUNC, name="x"),
|
||||
lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.trunc.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
|
||||
|
||||
@@ -26,7 +26,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
|
||||
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
|
||||
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
|
||||
if isinstance(it, PtrDType) and ot == dtypes.long: return src
|
||||
if ot == dtypes.bool: return nalu(b, c(it, False)+'ne'+('u' if c(it) == 'f' else ''), src, nimm(b, 0, it))
|
||||
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
|
||||
|
||||
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
|
||||
|
||||
@@ -99,8 +99,6 @@ string_rewrite = PatternMatcher([
|
||||
(UPat(Ops.BITCAST, name="x", src=(UPat.var("a"),), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
|
||||
(UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"),)),
|
||||
lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"),
|
||||
(UPat(Ops.CAST, name="x", dtype=dtypes.bool, src=(UPat.var("a"),)),
|
||||
lambda ctx, x, a: f"setp.ne.b{ctx.types[a.dtype][1:]} {ctx.r[x]}, {ctx.r[a]}, {render_val(0, a.dtype)};"),
|
||||
(UPat(Ops.CAST, name="x", src=(UPat.var("a"),)),
|
||||
lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.cast_types[x.dtype]}.{ctx.cast_types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"),
|
||||
# store / gated load / load
|
||||
|
||||
@@ -335,7 +335,6 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
|
||||
case Ops.CAST if dt in dtypes.floats:
|
||||
small = (a1.eq(0) & (a0 >= 0)) | (a1.eq(-1) & (a0 < 0))
|
||||
return small.where(a0.cast(dt), ((a1.cast(dtypes.float32) * (2**32)) + a0.bitcast(dtypes.uint).cast(dtypes.float32)).cast(dt))
|
||||
case Ops.CAST if dt == dtypes.bool: return a0.ne(UOp.const(a0.dtype, 0)) | a1.ne(UOp.const(a1.dtype, 0))
|
||||
case Ops.CAST: return a0.bitcast(dtypes.uint).cast(dt)
|
||||
case Ops.BITCAST: return a0.bitcast(dt), a1.bitcast(dt)
|
||||
case Ops.SHL:
|
||||
|
||||
@@ -105,6 +105,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
(UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
|
||||
# b.cast(a).cast(b) -> b if a preserves all values in b
|
||||
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_lossless_cast(b.dtype, a.dtype) else None),
|
||||
(UPat.var("x").cast(dtypes.bool), lambda x: x != 0),
|
||||
# ** pow **
|
||||
(UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
|
||||
# positive const ** x
|
||||
@@ -395,9 +396,6 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
|
||||
# reorder ALU/VECTORIZE
|
||||
(UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
|
||||
lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
|
||||
# ** self folding **
|
||||
# x!=0 -> (bool)x
|
||||
(UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
|
||||
# ** where **
|
||||
# # fold nested where with same condition: in cond.where(t,f), cond.where(a,b)->a in t, ->b in f
|
||||
# (UPat.var("cond").where(UPat.var("t"), UPat.var("f")), fold_where_closure),
|
||||
|
||||
Reference in New Issue
Block a user