mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
@@ -1,6 +1,6 @@
|
|||||||
import unittest
|
import unittest, pickle
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.dtype import dtypes, DType, ImageDType, PtrDType, to_dtype
|
from tinygrad.dtype import dtypes, DType, ImageDType, PtrDType, to_dtype, Invalid, InvalidType
|
||||||
|
|
||||||
class TestImageDType(unittest.TestCase):
|
class TestImageDType(unittest.TestCase):
|
||||||
def test_image_scalar(self):
|
def test_image_scalar(self):
|
||||||
@@ -82,5 +82,12 @@ class TestCanLosslessCast(unittest.TestCase):
|
|||||||
self.assertTrue(can_lossless_cast(dtypes.int8, dtypes.half))
|
self.assertTrue(can_lossless_cast(dtypes.int8, dtypes.half))
|
||||||
self.assertFalse(can_lossless_cast(dtypes.int8, dtypes.bfloat16))
|
self.assertFalse(can_lossless_cast(dtypes.int8, dtypes.bfloat16))
|
||||||
|
|
||||||
|
class TestInvalidSingleton(unittest.TestCase):
|
||||||
|
def test_singleton(self):
|
||||||
|
self.assertIs(InvalidType(), InvalidType())
|
||||||
|
self.assertIs(InvalidType(), Invalid)
|
||||||
|
def test_pickle(self):
|
||||||
|
self.assertIs(pickle.loads(pickle.dumps(Invalid)), Invalid)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -5,20 +5,17 @@ from dataclasses import dataclass, fields
|
|||||||
from tinygrad.helpers import getenv, prod, round_up, next_power2
|
from tinygrad.helpers import getenv, prod, round_up, next_power2
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
|
||||||
class InvalidTypeMetaClass(type):
|
class InvalidType:
|
||||||
instance:None|InvalidType = None
|
_instance: ClassVar[InvalidType|None] = None
|
||||||
def __call__(cls):
|
def __new__(cls):
|
||||||
if (ret:=InvalidTypeMetaClass.instance) is not None: return ret
|
if cls._instance is None: cls._instance = object.__new__(cls)
|
||||||
InvalidTypeMetaClass.instance = ret = super().__call__()
|
return cls._instance
|
||||||
return ret
|
|
||||||
|
|
||||||
class InvalidType(metaclass=InvalidTypeMetaClass):
|
|
||||||
def __eq__(self, other): return self is other
|
def __eq__(self, other): return self is other
|
||||||
def __lt__(self, other): return self is not other
|
def __lt__(self, other): return self is not other
|
||||||
def __gt__(self, other): return self is not other
|
def __gt__(self, other): return self is not other
|
||||||
def __hash__(self): return id(self)
|
def __hash__(self): return id(self)
|
||||||
def __repr__(self): return "Invalid"
|
def __repr__(self): return "Invalid"
|
||||||
def __reduce__(self): return (InvalidType, ()) # Return the global Invalid instance
|
def __reduce__(self): return (InvalidType, ()) # unpickle returns the singleton
|
||||||
|
|
||||||
Invalid = InvalidType()
|
Invalid = InvalidType()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user