diff --git a/test/unit/test_dtype.py b/test/unit/test_dtype.py index d429a359fe..e5db4af55e 100644 --- a/test/unit/test_dtype.py +++ b/test/unit/test_dtype.py @@ -1,6 +1,6 @@ -import unittest +import unittest, pickle from tinygrad.tensor import Tensor -from tinygrad.dtype import dtypes, DType, ImageDType, PtrDType, to_dtype +from tinygrad.dtype import dtypes, DType, ImageDType, PtrDType, to_dtype, Invalid, InvalidType class TestImageDType(unittest.TestCase): def test_image_scalar(self): @@ -82,5 +82,12 @@ class TestCanLosslessCast(unittest.TestCase): self.assertTrue(can_lossless_cast(dtypes.int8, dtypes.half)) self.assertFalse(can_lossless_cast(dtypes.int8, dtypes.bfloat16)) +class TestInvalidSingleton(unittest.TestCase): + def test_singleton(self): + self.assertIs(InvalidType(), InvalidType()) + self.assertIs(InvalidType(), Invalid) + def test_pickle(self): + self.assertIs(pickle.loads(pickle.dumps(Invalid)), Invalid) + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 7b37a321c2..817685fb14 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -5,20 +5,17 @@ from dataclasses import dataclass, fields from tinygrad.helpers import getenv, prod, round_up, next_power2 from enum import Enum, auto -class InvalidTypeMetaClass(type): - instance:None|InvalidType = None - def __call__(cls): - if (ret:=InvalidTypeMetaClass.instance) is not None: return ret - InvalidTypeMetaClass.instance = ret = super().__call__() - return ret - -class InvalidType(metaclass=InvalidTypeMetaClass): +class InvalidType: + _instance: ClassVar[InvalidType|None] = None + def __new__(cls): + if cls._instance is None: cls._instance = object.__new__(cls) + return cls._instance def __eq__(self, other): return self is other def __lt__(self, other): return self is not other def __gt__(self, other): return self is not other def __hash__(self): return id(self) def __repr__(self): return "Invalid" - def __reduce__(self): return (InvalidType, ()) # Return the global Invalid instance + def __reduce__(self): return (InvalidType, ()) # unpickle returns the singleton Invalid = InvalidType()