mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
fix Tensor.all and Tensor.any for PTX (#5335)
supported boolean acc and boolean phi. and rewrite boolean max to uint8 max
This commit is contained in:
@@ -839,27 +839,23 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
|
||||
helper_test_op([()], lambda x: x.max())
|
||||
|
||||
@unittest.skipIf(getenv("PTX"), "broken in PTX")
|
||||
def test_any(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.any(), forward_only=True)
|
||||
helper_test_op(None, lambda x: x.any(), vals=[[True, True]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x.any(), vals=[[True, False]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x.any(), vals=[[False, False]], forward_only=True)
|
||||
helper_test_op([()], lambda x: x.any(), forward_only=True)
|
||||
@unittest.skipIf(getenv("PTX"), "broken in PTX")
|
||||
def test_any_axis(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.any(axis=(1,2)), forward_only=True)
|
||||
def test_any_zero_axis(self):
|
||||
helper_test_op([(1,0,3,0,5)], lambda x: x.any(axis=(1,3)), forward_only=True)
|
||||
|
||||
@unittest.skipIf(getenv("PTX"), "broken in PTX")
|
||||
def test_all(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.all(), forward_only=True)
|
||||
helper_test_op(None, lambda x: x.all(), vals=[[True, True]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x.all(), vals=[[True, False]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x.all(), vals=[[False, False]], forward_only=True)
|
||||
helper_test_op([()], lambda x: x.all(), forward_only=True)
|
||||
@unittest.skipIf(getenv("PTX"), "broken in PTX")
|
||||
def test_all_axis(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.all(axis=(1,2)), forward_only=True)
|
||||
def test_all_zero_axis(self):
|
||||
|
||||
@@ -167,7 +167,7 @@ class PTXRenderer(Renderer):
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
|
||||
else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
|
||||
else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
assert args[1][0] != "i", "idx not supported"
|
||||
kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
|
||||
@@ -194,8 +194,7 @@ class PTXRenderer(Renderer):
|
||||
elif uop is UOps.PHI:
|
||||
if dtype.count > 1:
|
||||
for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
|
||||
else:
|
||||
kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
|
||||
else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};")
|
||||
r[u] = r[src[0]]
|
||||
elif uop in {UOps.VECTORIZE}:
|
||||
assert src[0].dtype is not None and dtype.count > 1
|
||||
@@ -245,6 +244,8 @@ ptx_matcher = PatternMatcher([
|
||||
*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
|
||||
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
|
||||
for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
|
||||
(UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
|
||||
lambda x: UOp(UOps.CAST, dtypes.bool, (UOp(UOps.ALU, dtypes.uint8, tuple(UOp(UOps.CAST, dtypes.uint8, (s,)) for s in x.src), x.arg),))),
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
|
||||
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
|
||||
|
||||
Reference in New Issue
Block a user