don't const rewrite in cstyle (#7442)

* don't const rewrite in cstyle

* Update cstyle.py

* simple_symbolic

* fix bfloat16 const on AMD
This commit is contained in:
George Hotz
2024-10-31 18:16:49 +07:00
committed by GitHub
parent bdde795239
commit 5dd1ffd5d0
6 changed files with 26 additions and 13 deletions

View File

@@ -34,7 +34,7 @@ def assert_jit_cache_len(fxn, expected_len):
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
return device in {"AMD"} or (device in {"CUDA", "NV", "METAL"} and not CI and not getenv("PTX"))
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function