diff --git a/test/test_ops.py b/test/test_ops.py index 622dc7d364..ba3149dd36 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -839,27 +839,23 @@ class TestOps(unittest.TestCase): helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1)) helper_test_op([()], lambda x: x.max()) - @unittest.skipIf(getenv("PTX"), "broken in PTX") def test_any(self): helper_test_op([(3,4,5,6)], lambda x: x.any(), forward_only=True) helper_test_op(None, lambda x: x.any(), vals=[[True, True]], forward_only=True) helper_test_op(None, lambda x: x.any(), vals=[[True, False]], forward_only=True) helper_test_op(None, lambda x: x.any(), vals=[[False, False]], forward_only=True) helper_test_op([()], lambda x: x.any(), forward_only=True) - @unittest.skipIf(getenv("PTX"), "broken in PTX") def test_any_axis(self): helper_test_op([(3,4,5,6)], lambda x: x.any(axis=(1,2)), forward_only=True) def test_any_zero_axis(self): helper_test_op([(1,0,3,0,5)], lambda x: x.any(axis=(1,3)), forward_only=True) - @unittest.skipIf(getenv("PTX"), "broken in PTX") def test_all(self): helper_test_op([(3,4,5,6)], lambda x: x.all(), forward_only=True) helper_test_op(None, lambda x: x.all(), vals=[[True, True]], forward_only=True) helper_test_op(None, lambda x: x.all(), vals=[[True, False]], forward_only=True) helper_test_op(None, lambda x: x.all(), vals=[[False, False]], forward_only=True) helper_test_op([()], lambda x: x.all(), forward_only=True) - @unittest.skipIf(getenv("PTX"), "broken in PTX") def test_all_axis(self): helper_test_op([(3,4,5,6)], lambda x: x.all(axis=(1,2)), forward_only=True) def test_all_zero_axis(self): diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 0e4e2d1eb7..2db5ce0bbb 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -167,7 +167,7 @@ class PTXRenderer(Renderer): if dtype.count > 1: r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};") - else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};") + else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};") elif uop is UOps.SPECIAL: assert args[1][0] != "i", "idx not supported" kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};") @@ -194,8 +194,7 @@ class PTXRenderer(Renderer): elif uop is UOps.PHI: if dtype.count > 1: for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};") - else: - kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};") + else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};") r[u] = r[src[0]] elif uop in {UOps.VECTORIZE}: assert src[0].dtype is not None and dtype.count > 1 @@ -245,6 +244,8 @@ ptx_matcher = PatternMatcher([ *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"), lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),))) for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half], + (UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX), + lambda x: UOp(UOps.CAST, dtypes.bool, (UOp(UOps.ALU, dtypes.uint8, tuple(UOp(UOps.CAST, dtypes.uint8, (s,)) for s in x.src), x.arg),))), (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))), lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)), (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),