From 39aae679e4f91260646dd952e595ade7eebe72b6 Mon Sep 17 00:00:00 2001 From: hooved <172129504+hooved@users.noreply.github.com> Date: Tue, 30 Sep 2025 00:02:30 -0400 Subject: [PATCH] Support bfloat16 on NULL backend (#12340) * add failing test * move test * only run test with NULL default * add skip reason * add fix --- .github/workflows/test.yml | 2 +- test/device/test_null.py | 13 +++++++++++++ tinygrad/device.py | 4 ++-- 3 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 test/device/test_null.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 24715bd5b4..a78957d54a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -264,7 +264,7 @@ jobs: - name: Run unit tests run: python -m pytest -n=auto test/unit/ --durations=20 - name: Run targetted tests on NULL backend - run: NULL=1 python3 test/test_multitensor.py TestMultiTensor.test_data_parallel_resnet_train_step + run: NULL=1 python3 -m unittest test.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step test/device/test_null.py - name: Run SDXL on NULL backend run: MAX_BUFFER_SIZE=0 NULL=1 DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights - name: Run Clip tests for SD MLPerf on NULL backend diff --git a/test/device/test_null.py b/test/device/test_null.py new file mode 100644 index 0000000000..d20b228ce2 --- /dev/null +++ b/test/device/test_null.py @@ -0,0 +1,13 @@ +import unittest +from tinygrad import dtypes, Device +from tinygrad.device import is_dtype_supported + +@unittest.skipUnless(Device.DEFAULT=="NULL", "Don't run when testing non-NULL backends") +class TestNULLSupportsDTypes(unittest.TestCase): + def test_null_supports_ints_floats_bool(self): + dts = dtypes.ints + dtypes.floats + (dtypes.bool,) + not_supported = [dt for dt in dts if not is_dtype_supported(dt, "NULL")] + self.assertFalse(not_supported, msg=f"expected these dtypes to be supported by NULL: {not_supported}") + +if __name__ == "__main__": + unittest.main() diff --git a/tinygrad/device.py b/tinygrad/device.py index 64fb3bb0ac..c099ac6998 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -327,8 +327,8 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool: if device == "METAL": return not CI if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} - return device in {"AMD", "PYTHON"} - if dtype in dtypes.fp8s: return device == "PYTHON" + return device in {"AMD", "PYTHON", "NULL"} + if dtype in dtypes.fp8s: return device in {"PYTHON", "NULL"} if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half] # for CI GPU and OSX, cl_khr_fp16 isn't supported