fix and test (#8814)

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Ignacio Sica
2025-01-30 18:35:53 -03:00
committed by GitHub
parent f5da275f46
commit f0924e0857
2 changed files with 10 additions and 3 deletions

View File

@@ -8,6 +8,7 @@ from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup, flatten, prod
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.ops import UOp, Ops
from tinygrad.renderer import ProgramSpec
from tinygrad.tensor import Tensor, _to_np_dtype
@@ -41,7 +42,7 @@ class TestCStyleFailures(unittest.TestCase):
ret = _test_uop_result([Tensor([1])], uops)[0]
self.assertEqual(ret[0], 1)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local and Device.DEFAULT == "PTX", "need local")
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "tests for ptx renderer")
class TestPTXFailures(unittest.TestCase):
def test_gated_store_with_alu(self):
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
@@ -63,5 +64,12 @@ class TestPTXFailures(unittest.TestCase):
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])
def test_gated_define_acc_with_half_dtype(self):
a = Tensor.randn(32, 32, dtype=dtypes.half).realize()
b = Tensor.randn(34, 32, dtype=dtypes.half).realize()
result = a.pad((1,1)).matmul(b, acc_dtype=dtypes.half).numpy()
reference = a.pad((1,1)).matmul(b, acc_dtype=dtypes.float).numpy()
np.testing.assert_allclose(result, reference)
if __name__ == '__main__':
unittest.main()

View File

@@ -176,8 +176,7 @@ class PTXRenderer(Renderer):
if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
r[u] = r[u.src[0]]
continue
if u.op is Ops.DEFINE_ACC and u.dtype in [dtypes.half, dtypes.bool]: r[u.src[0]] = ssa("const", u.src[0])
elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
elif u.op is Ops.LOAD:
assert u.src[0].dtype == dtypes.int64, "load isn't int64"