simpler InvalidType [pr] (#13957)

simpler singleton pattern
This commit is contained in:
chenyu
2026-01-01 13:55:51 -05:00
committed by GitHub
parent b8ea0d779c
commit 8e416df438
2 changed files with 15 additions and 11 deletions

View File

@@ -1,6 +1,6 @@
import unittest import unittest, pickle
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, DType, ImageDType, PtrDType, to_dtype from tinygrad.dtype import dtypes, DType, ImageDType, PtrDType, to_dtype, Invalid, InvalidType
class TestImageDType(unittest.TestCase): class TestImageDType(unittest.TestCase):
def test_image_scalar(self): def test_image_scalar(self):
@@ -82,5 +82,12 @@ class TestCanLosslessCast(unittest.TestCase):
self.assertTrue(can_lossless_cast(dtypes.int8, dtypes.half)) self.assertTrue(can_lossless_cast(dtypes.int8, dtypes.half))
self.assertFalse(can_lossless_cast(dtypes.int8, dtypes.bfloat16)) self.assertFalse(can_lossless_cast(dtypes.int8, dtypes.bfloat16))
class TestInvalidSingleton(unittest.TestCase):
def test_singleton(self):
self.assertIs(InvalidType(), InvalidType())
self.assertIs(InvalidType(), Invalid)
def test_pickle(self):
self.assertIs(pickle.loads(pickle.dumps(Invalid)), Invalid)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -5,20 +5,17 @@ from dataclasses import dataclass, fields
from tinygrad.helpers import getenv, prod, round_up, next_power2 from tinygrad.helpers import getenv, prod, round_up, next_power2
from enum import Enum, auto from enum import Enum, auto
class InvalidTypeMetaClass(type): class InvalidType:
instance:None|InvalidType = None _instance: ClassVar[InvalidType|None] = None
def __call__(cls): def __new__(cls):
if (ret:=InvalidTypeMetaClass.instance) is not None: return ret if cls._instance is None: cls._instance = object.__new__(cls)
InvalidTypeMetaClass.instance = ret = super().__call__() return cls._instance
return ret
class InvalidType(metaclass=InvalidTypeMetaClass):
def __eq__(self, other): return self is other def __eq__(self, other): return self is other
def __lt__(self, other): return self is not other def __lt__(self, other): return self is not other
def __gt__(self, other): return self is not other def __gt__(self, other): return self is not other
def __hash__(self): return id(self) def __hash__(self): return id(self)
def __repr__(self): return "Invalid" def __repr__(self): return "Invalid"
def __reduce__(self): return (InvalidType, ()) # Return the global Invalid instance def __reduce__(self): return (InvalidType, ()) # unpickle returns the singleton
Invalid = InvalidType() Invalid = InvalidType()