fix Tensor.all and Tensor.any for PTX (#5335)

supported boolean acc and boolean phi. and rewrite boolean max to uint8 max
This commit is contained in:
chenyu
2024-07-08 18:15:04 -04:00
committed by GitHub
parent 053c706961
commit 0f0940225a
2 changed files with 4 additions and 7 deletions

View File

@@ -839,27 +839,23 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
helper_test_op([()], lambda x: x.max())
@unittest.skipIf(getenv("PTX"), "broken in PTX")
def test_any(self):
helper_test_op([(3,4,5,6)], lambda x: x.any(), forward_only=True)
helper_test_op(None, lambda x: x.any(), vals=[[True, True]], forward_only=True)
helper_test_op(None, lambda x: x.any(), vals=[[True, False]], forward_only=True)
helper_test_op(None, lambda x: x.any(), vals=[[False, False]], forward_only=True)
helper_test_op([()], lambda x: x.any(), forward_only=True)
@unittest.skipIf(getenv("PTX"), "broken in PTX")
def test_any_axis(self):
helper_test_op([(3,4,5,6)], lambda x: x.any(axis=(1,2)), forward_only=True)
def test_any_zero_axis(self):
helper_test_op([(1,0,3,0,5)], lambda x: x.any(axis=(1,3)), forward_only=True)
@unittest.skipIf(getenv("PTX"), "broken in PTX")
def test_all(self):
helper_test_op([(3,4,5,6)], lambda x: x.all(), forward_only=True)
helper_test_op(None, lambda x: x.all(), vals=[[True, True]], forward_only=True)
helper_test_op(None, lambda x: x.all(), vals=[[True, False]], forward_only=True)
helper_test_op(None, lambda x: x.all(), vals=[[False, False]], forward_only=True)
helper_test_op([()], lambda x: x.all(), forward_only=True)
@unittest.skipIf(getenv("PTX"), "broken in PTX")
def test_all_axis(self):
helper_test_op([(3,4,5,6)], lambda x: x.all(axis=(1,2)), forward_only=True)
def test_all_zero_axis(self):

View File

@@ -167,7 +167,7 @@ class PTXRenderer(Renderer):
if dtype.count > 1:
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
elif uop is UOps.SPECIAL:
assert args[1][0] != "i", "idx not supported"
kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
@@ -194,8 +194,7 @@ class PTXRenderer(Renderer):
elif uop is UOps.PHI:
if dtype.count > 1:
for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
else:
kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};")
r[u] = r[src[0]]
elif uop in {UOps.VECTORIZE}:
assert src[0].dtype is not None and dtype.count > 1
@@ -245,6 +244,8 @@ ptx_matcher = PatternMatcher([
*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
(UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
lambda x: UOp(UOps.CAST, dtypes.bool, (UOp(UOps.ALU, dtypes.uint8, tuple(UOp(UOps.CAST, dtypes.uint8, (s,)) for s in x.src), x.arg),))),
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),