safe changes from new dtype branch [pr] (#7397)

* safe changes from new dtype branch [pr]

* only image test on GPU
This commit is contained in:
George Hotz
2024-10-30 16:18:48 +07:00
committed by GitHub
parent 0ca241693b
commit 4e2895f8d2
9 changed files with 36 additions and 28 deletions

View File

@@ -297,7 +297,7 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool
class TestPtrDType(unittest.TestCase):
def test_vec_double(self):
dt1 = dtypes.float.vec(4).ptr(v=4)
dt1 = dtypes.float.vec(4).ptr().vec(4)
dt2 = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt1, dt2)
self.assertEqual(str(dt1), str(dt2))
@@ -313,7 +313,7 @@ class TestPtrDType(unittest.TestCase):
self.assertEqual(dt, dtypes.float)
def test_serialize(self):
dt = dtypes.float.vec(4).ptr(v=4)
dt = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt, eval(str(dt)))
def test_vcount(self):
@@ -327,18 +327,18 @@ class TestPtrDType(unittest.TestCase):
self.assertEqual(dt.v, 1)
self.assertEqual(dt.count, 4)
dt = dtypes.float.vec(4).ptr(v=4)
dt = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt.vcount, 4)
self.assertEqual(dt.v, 4)
self.assertEqual(dt.count, 4)
class TestImageDType(unittest.TestCase):
def test_image_scalar(self):
assert dtypes.imagef((10,10)).scalar() == dtypes.float32
assert dtypes.imageh((10,10)).scalar() == dtypes.float32
assert dtypes.imagef((10,10)).base.scalar() == dtypes.float32
assert dtypes.imageh((10,10)).base.scalar() == dtypes.float32
def test_image_vec(self):
assert dtypes.imagef((10,10)).vec(4) == dtypes.float32.vec(4)
assert dtypes.imageh((10,10)).vec(4) == dtypes.float32.vec(4)
assert dtypes.imagef((10,10)).base.vec(4) == dtypes.float32.vec(4)
assert dtypes.imageh((10,10)).base.vec(4) == dtypes.float32.vec(4)
class TestEqStrDType(unittest.TestCase):
def test_image_ne(self):