This commit is contained in:
Christopher Milan
2026-01-30 17:12:30 -08:00
parent 51423a4cd0
commit 90474d076f

View File

@@ -247,6 +247,18 @@ class TestBFloat16DTypeCast(unittest.TestCase):
class TestHalfDType(TestDType): DTYPE = dtypes.half
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "half decomp requires bitshift")
class TestEmulatedHalf(TestHalfDType):
@classmethod
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="half"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
@classmethod
def tearDownClass(cls): cls.stack.close()
class TestFloatDType(TestDType):
DTYPE = dtypes.float
@@ -372,17 +384,6 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool
class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "half decomp requires bitshift")
class TestEmulatedFloat16(TestFloat16Type):
@classmethod
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="half"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
@classmethod
def tearDownClass(cls): cls.stack.close()
class TestFp8e4m3(TestDType): DTYPE = dtypes.fp8e4m3
class TestFp8e5m2(TestDType): DTYPE = dtypes.fp8e5m2