mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
@@ -8,6 +8,7 @@ from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import dedup, flatten, prod
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
@@ -41,7 +42,7 @@ class TestCStyleFailures(unittest.TestCase):
|
||||
ret = _test_uop_result([Tensor([1])], uops)[0]
|
||||
self.assertEqual(ret[0], 1)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local and Device.DEFAULT == "PTX", "need local")
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "tests for ptx renderer")
|
||||
class TestPTXFailures(unittest.TestCase):
|
||||
def test_gated_store_with_alu(self):
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
@@ -63,5 +64,12 @@ class TestPTXFailures(unittest.TestCase):
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
def test_gated_define_acc_with_half_dtype(self):
|
||||
a = Tensor.randn(32, 32, dtype=dtypes.half).realize()
|
||||
b = Tensor.randn(34, 32, dtype=dtypes.half).realize()
|
||||
result = a.pad((1,1)).matmul(b, acc_dtype=dtypes.half).numpy()
|
||||
reference = a.pad((1,1)).matmul(b, acc_dtype=dtypes.float).numpy()
|
||||
np.testing.assert_allclose(result, reference)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -176,8 +176,7 @@ class PTXRenderer(Renderer):
|
||||
if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
|
||||
r[u] = r[u.src[0]]
|
||||
continue
|
||||
if u.op is Ops.DEFINE_ACC and u.dtype in [dtypes.half, dtypes.bool]: r[u.src[0]] = ssa("const", u.src[0])
|
||||
elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
|
||||
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
|
||||
elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
|
||||
elif u.op is Ops.LOAD:
|
||||
assert u.src[0].dtype == dtypes.int64, "load isn't int64"
|
||||
|
||||
Reference in New Issue
Block a user