diff --git a/test/backend/test_renderer_failures.py b/test/backend/test_renderer_failures.py index 5f9dd02a20..59928be324 100644 --- a/test/backend/test_renderer_failures.py +++ b/test/backend/test_renderer_failures.py @@ -78,7 +78,9 @@ class TestCStyleFailures(unittest.TestCase): def test_repeat_add(self): self._test_src_strip_paren(Ops.ADD) def test_repeat_mul(self): self._test_src_strip_paren(Ops.MUL) def test_repeat_xor(self): self._test_src_strip_paren(Ops.XOR) + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, WGSLRenderer), "wgsl ends up with '(' * 5") def test_repeat_or(self): self._test_src_strip_paren(Ops.OR) + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, WGSLRenderer), "wgsl ends up with '(' * 5") def test_repeat_and(self): self._test_src_strip_paren(Ops.AND) def test_repeat_sub(self): self._test_src_strip_paren(Ops.SUB, should_strip_paren=False) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index e25dae745e..7a2e293bac 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -390,6 +390,9 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)") self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "True") + def test_cast_bool(self): + self.helper_test_variable(Variable("a", 0, 10).cast(dtypes.bool), 0, 1, "a!=0") + def test_lt_sum_remove(self): self.helper_test_variable(Variable("a", 0, 6) + 2 < 3, 0, 1, "(a<1)") diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 8adafe0fbf..4cb71aa571 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -97,9 +97,6 @@ base_rewrite = PatternMatcher([ f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])), # unary/binary/ternary ops (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), - # rewrite cast to bool to CMPNE 0 - (UPat(Ops.CAST, name="x", dtype=dtypes.bool), - lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][Ops.CMPNE]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, zeroinitializer"), (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), (UPat(Ops.TRUNC, name="x"), lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.trunc.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"), diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index de257f3c16..21ef133160 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -26,7 +26,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b")) def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def: if isinstance(it, PtrDType) and ot == dtypes.long: return src - if ot == dtypes.bool: return nalu(b, c(it, False)+'ne'+('u' if c(it) == 'f' else ''), src, nimm(b, 0, it)) return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src) def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable): diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 60259d9186..7e41391ef9 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -99,8 +99,6 @@ string_rewrite = PatternMatcher([ (UPat(Ops.BITCAST, name="x", src=(UPat.var("a"),), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"), (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"),)), lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"), - (UPat(Ops.CAST, name="x", dtype=dtypes.bool, src=(UPat.var("a"),)), - lambda ctx, x, a: f"setp.ne.b{ctx.types[a.dtype][1:]} {ctx.r[x]}, {ctx.r[a]}, {render_val(0, a.dtype)};"), (UPat(Ops.CAST, name="x", src=(UPat.var("a"),)), lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.cast_types[x.dtype]}.{ctx.cast_types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"), # store / gated load / load diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 0a5f526dad..a153847a99 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -335,7 +335,6 @@ def l2i(op: Ops, dt: DType, *uops:UOp): case Ops.CAST if dt in dtypes.floats: small = (a1.eq(0) & (a0 >= 0)) | (a1.eq(-1) & (a0 < 0)) return small.where(a0.cast(dt), ((a1.cast(dtypes.float32) * (2**32)) + a0.bitcast(dtypes.uint).cast(dtypes.float32)).cast(dt)) - case Ops.CAST if dt == dtypes.bool: return a0.ne(UOp.const(a0.dtype, 0)) | a1.ne(UOp.const(a1.dtype, 0)) case Ops.CAST: return a0.bitcast(dtypes.uint).cast(dt) case Ops.BITCAST: return a0.bitcast(dt), a1.bitcast(dt) case Ops.SHL: diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index e636b8070b..1e8f6bd31c 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -105,6 +105,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([ (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast), # b.cast(a).cast(b) -> b if a preserves all values in b (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_lossless_cast(b.dtype, a.dtype) else None), + (UPat.var("x").cast(dtypes.bool), lambda x: x != 0), # ** pow ** (UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow), # positive const ** x @@ -395,9 +396,6 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([ # reorder ALU/VECTORIZE (UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'), lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)), - # ** self folding ** - # x!=0 -> (bool)x - (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))), # ** where ** # # fold nested where with same condition: in cond.where(t,f), cond.where(a,b)->a in t, ->b in f # (UPat.var("cond").where(UPat.var("t"), UPat.var("f")), fold_where_closure),