diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index e2cf6f26af..f167e0d5bc 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -8,6 +8,7 @@ from tinygrad.dtype import dtypes from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import dedup, flatten, prod from tinygrad.renderer.cstyle import CStyleLanguage +from tinygrad.renderer.ptx import PTXRenderer from tinygrad.ops import UOp, Ops from tinygrad.renderer import ProgramSpec from tinygrad.tensor import Tensor, _to_np_dtype @@ -41,7 +42,7 @@ class TestCStyleFailures(unittest.TestCase): ret = _test_uop_result([Tensor([1])], uops)[0] self.assertEqual(ret[0], 1) -@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local and Device.DEFAULT == "PTX", "need local") +@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "tests for ptx renderer") class TestPTXFailures(unittest.TestCase): def test_gated_store_with_alu(self): a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) @@ -63,5 +64,12 @@ class TestPTXFailures(unittest.TestCase): ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0] np.testing.assert_equal(ret, [0, 1, 1, 1]) + def test_gated_define_acc_with_half_dtype(self): + a = Tensor.randn(32, 32, dtype=dtypes.half).realize() + b = Tensor.randn(34, 32, dtype=dtypes.half).realize() + result = a.pad((1,1)).matmul(b, acc_dtype=dtypes.half).numpy() + reference = a.pad((1,1)).matmul(b, acc_dtype=dtypes.float).numpy() + np.testing.assert_allclose(result, reference) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 37a46c1348..d16653357c 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -176,8 +176,7 @@ class PTXRenderer(Renderer): if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)): r[u] = r[u.src[0]] continue - if u.op is Ops.DEFINE_ACC and u.dtype in [dtypes.half, dtypes.bool]: r[u.src[0]] = ssa("const", u.src[0]) - elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0] + if u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0] elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype)) elif u.op is Ops.LOAD: assert u.src[0].dtype == dtypes.int64, "load isn't int64"