Support bfloat16 on NULL backend (#12340)

* add failing test

* move test

* only run test with NULL default

* add skip reason

* add fix
This commit is contained in:
hooved
2025-09-30 00:02:30 -04:00
committed by GitHub
parent af935e7d32
commit 39aae679e4
3 changed files with 16 additions and 3 deletions

View File

@@ -264,7 +264,7 @@ jobs:
- name: Run unit tests
run: python -m pytest -n=auto test/unit/ --durations=20
- name: Run targetted tests on NULL backend
-        run: NULL=1 python3 test/test_multitensor.py TestMultiTensor.test_data_parallel_resnet_train_step
+        run: NULL=1 python3 -m unittest test.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step test/device/test_null.py
- name: Run SDXL on NULL backend
run: MAX_BUFFER_SIZE=0 NULL=1 DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
- name: Run Clip tests for SD MLPerf on NULL backend

13
test/device/test_null.py Normal file
View File

@@ -0,0 +1,13 @@
import unittest
from tinygrad import dtypes, Device
from tinygrad.device import is_dtype_supported
@unittest.skipUnless(Device.DEFAULT=="NULL", "Don't run when testing non-NULL backends")
class TestNULLSupportsDTypes(unittest.TestCase):
  """Checks that the NULL backend reports support for all basic dtypes."""
  def test_null_supports_ints_floats_bool(self):
    # collect every int and float dtype plus bool, then keep only those NULL rejects
    candidates = (*dtypes.ints, *dtypes.floats, dtypes.bool)
    not_supported = [dt for dt in candidates if not is_dtype_supported(dt, "NULL")]
    # the list must come back empty; the message shows any offenders
    self.assertFalse(not_supported, msg=f"expected these dtypes to be supported by NULL: {not_supported}")
# Allow running this file directly: delegate to the unittest CLI runner.
if __name__ == "__main__": unittest.main()

View File

@@ -327,8 +327,8 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
if device == "METAL": return not CI
if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX")
if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
-    return device in {"AMD", "PYTHON"}
-  if dtype in dtypes.fp8s: return device == "PYTHON"
+    return device in {"AMD", "PYTHON", "NULL"}
+  if dtype in dtypes.fp8s: return device in {"PYTHON", "NULL"}
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
# for CI GPU and OSX, cl_khr_fp16 isn't supported