mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
Support bfloat16 on NULL backend (#12340)
* add failing test * move test * only run test with NULL default * add skip reason * add fix
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -264,7 +264,7 @@ jobs:
|
||||
- name: Run unit tests
|
||||
run: python -m pytest -n=auto test/unit/ --durations=20
|
||||
- name: Run targeted tests on NULL backend
|
||||
run: NULL=1 python3 test/test_multitensor.py TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
run: NULL=1 python3 -m unittest test.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step test/device/test_null.py
|
||||
- name: Run SDXL on NULL backend
|
||||
run: MAX_BUFFER_SIZE=0 NULL=1 DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
|
||||
- name: Run Clip tests for SD MLPerf on NULL backend
|
||||
|
||||
13
test/device/test_null.py
Normal file
13
test/device/test_null.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import unittest
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT=="NULL", "Don't run when testing non-NULL backends")
class TestNULLSupportsDTypes(unittest.TestCase):
  """Checks that the NULL backend reports support for the core scalar dtypes."""
  def test_null_supports_ints_floats_bool(self):
    # Collect every int and float dtype plus bool, then keep only the ones
    # is_dtype_supported rejects for the NULL device; the list must be empty.
    candidates = (*dtypes.ints, *dtypes.floats, dtypes.bool)
    missing = [dt for dt in candidates if not is_dtype_supported(dt, "NULL")]
    self.assertFalse(missing, msg=f"expected these dtypes to be supported by NULL: {missing}")
|
||||
|
||||
# Allow running this test file directly as a script.
if __name__ == "__main__":
  unittest.main()
|
||||
@@ -327,8 +327,8 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
|
||||
if device == "METAL": return not CI
|
||||
if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX")
|
||||
if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
|
||||
return device in {"AMD", "PYTHON"}
|
||||
if dtype in dtypes.fp8s: return device == "PYTHON"
|
||||
return device in {"AMD", "PYTHON", "NULL"}
|
||||
if dtype in dtypes.fp8s: return device in {"PYTHON", "NULL"}
|
||||
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
|
||||
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
|
||||
# for CI GPU and OSX, cl_khr_fp16 isn't supported
|
||||
|
||||
Reference in New Issue
Block a user