From 90474d076fff9181b00fea6aa1208f8b555f060b Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Fri, 30 Jan 2026 17:12:30 -0800 Subject: [PATCH] fix --- test/test_dtype.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index e1bbe695cc..803b65cdb6 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -247,6 +247,18 @@ class TestBFloat16DTypeCast(unittest.TestCase): class TestHalfDType(TestDType): DTYPE = dtypes.half +@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "half decomp requires bitshift") +class TestEmulatedHalf(TestHalfDType): + @classmethod + def setUpClass(cls): + cls.stack = contextlib.ExitStack() + cls.stack.enter_context(Context(EMULATED_DTYPES="half")) + cls.DATA = rand_for_dtype(cls.DTYPE, 10) + + @classmethod + def tearDownClass(cls): cls.stack.close() + + class TestFloatDType(TestDType): DTYPE = dtypes.float @@ -372,17 +384,6 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16 -@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "half decomp requires bitshift") -class TestEmulatedFloat16(TestFloat16Type): - @classmethod - def setUpClass(cls): - cls.stack = contextlib.ExitStack() - cls.stack.enter_context(Context(EMULATED_DTYPES="half")) - cls.DATA = rand_for_dtype(cls.DTYPE, 10) - - @classmethod - def tearDownClass(cls): cls.stack.close() - class TestFp8e4m3(TestDType): DTYPE = dtypes.fp8e4m3 class TestFp8e5m2(TestDType): DTYPE = dtypes.fp8e5m2