Restore vcount [pr] (#7390)

* Revert "Revert "add vcount to PtrDtype (#7388)""

This reverts commit 399a5219dd.

* Revert "Revert "add tests to vcount stuff [pr] (#7389)""

This reverts commit cc8d6dbdf3.

* no ptr
This commit is contained in:
George Hotz
2024-10-30 10:27:55 +07:00
committed by GitHub
parent 399a5219dd
commit 1058f9c9ff
5 changed files with 56 additions and 8 deletions

View File

@@ -295,6 +295,43 @@ class TestUint64DType(TestDType):
class TestBoolDType(TestDType): DTYPE = dtypes.bool
class TestPtrDType(unittest.TestCase):
def test_vec_double(self):
dt1 = dtypes.float.vec(4).ptr(v=4)
dt2 = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt1, dt2)
self.assertEqual(str(dt1), str(dt2))
def test_scalar(self):
dt = dtypes.float.vec(4).ptr().scalar()
self.assertEqual(dt.base, dtypes.float.vec(4))
dt = dtypes.float.vec(4).ptr().vec(4).scalar()
self.assertEqual(dt.base, dtypes.float.vec(4))
dt = dtypes.float.vec(4).scalar()
self.assertEqual(dt, dtypes.float)
def test_serialize(self):
dt = dtypes.float.vec(4).ptr(v=4)
self.assertEqual(dt, eval(str(dt)))
def test_vcount(self):
dt = dtypes.float.ptr().vec(4)
self.assertEqual(dt.vcount, 4)
self.assertEqual(dt.v, 4)
self.assertEqual(dt.count, 1)
dt = dtypes.float.vec(4).ptr()
self.assertEqual(dt.vcount, 1)
self.assertEqual(dt.v, 1)
self.assertEqual(dt.count, 4)
dt = dtypes.float.vec(4).ptr(v=4)
self.assertEqual(dt.vcount, 4)
self.assertEqual(dt.v, 4)
self.assertEqual(dt.count, 4)
class TestImageDType(unittest.TestCase):
def test_image_scalar(self):
assert dtypes.imagef((10,10)).scalar() == dtypes.float32