mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix
This commit is contained in:
@@ -247,6 +247,18 @@ class TestBFloat16DTypeCast(unittest.TestCase):
|
||||
|
||||
class TestHalfDType(TestDType): DTYPE = dtypes.half
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "half decomp requires bitshift")
|
||||
class TestEmulatedHalf(TestHalfDType):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="half"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
|
||||
|
||||
class TestFloatDType(TestDType):
|
||||
DTYPE = dtypes.float
|
||||
|
||||
@@ -372,17 +384,6 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool
|
||||
|
||||
class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "half decomp requires bitshift")
|
||||
class TestEmulatedFloat16(TestFloat16Type):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="half"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
|
||||
class TestFp8e4m3(TestDType): DTYPE = dtypes.fp8e4m3
|
||||
class TestFp8e5m2(TestDType): DTYPE = dtypes.fp8e5m2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user